文档首页 > > AI工程师用户指南> 模型包规范> 模型推理代码编写说明

模型推理代码编写说明

分享
更新时间:2020/07/10 GMT+08:00

本章节介绍了在ModelArts中模型推理代码编写的通用方法及说明,针对常用AI引擎的自定义脚本代码示例(包含推理代码示例),请参见自定义脚本代码示例。本文在编写说明下方提供了一个TensorFlow引擎的推理代码示例以及一个在推理脚本中自定义推理逻辑的示例。

推理代码编写指导

  1. 所有的自定义Python代码必须继承自BaseService类,不同类型的模型父类导入语句如表1所示。
    表1 BaseService类导入语句

    模型类型

    父类

    导入语句

    TensorFlow

    TfServingBaseService

    from model_service.tfserving_model_service import TfServingBaseService

    MXNet

    MXNetBaseService

    from mms.model_service.mxnet_model_service import MXNetBaseService

    PyTorch

    PTServingBaseService

    from model_service.pytorch_model_service import PTServingBaseService

    Pyspark

    SparkServingBaseService

    from model_service.spark_model_service import SparkServingBaseService

    Caffe

    CaffeBaseService

    from model_service.caffe_model_service import CaffeBaseService

    XGBoost

    XgSklServingBaseService

    from model_service.python_model_service import XgSklServingBaseService

    Scikit_Learn

    XgSklServingBaseService

    from model_service.python_model_service import XgSklServingBaseService

  2. 可以重写的方法有以下几种。
    表2 重写方法

    方法名

    说明

    __init__(self, model_name, model_path)

    初始化方法,该方法内加载模型及标签等(pytorch和caffe类型模型必须重写,实现模型加载逻辑)。

    _preprocess(self, data)

    预处理方法,在推理请求前调用,用于将API接口用户原始请求数据转换为模型期望输入数据。

    _inference(self, data)

    实际推理请求方法(不建议重写,重写后会覆盖modelarts内置的推理过程,运行自定义的推理逻辑)

    _postprocess(self, data)

    后处理方法,在推理请求完成后调用,用于将模型输出转换为API接口输出

    • 通常,用户可以选择重写preprocess和postprocess方法,以实现数据的API输入的预处理和推理结果输出的后处理。
    • 重写BaseService继承类的初始化方法init可能导致模型“运行异常”
  3. 当前支持两种content-type的接口传入,“multipart/form-data”“application/json”
    • “multipart/form-data”请求
      curl -X POST \
        <modelarts-inference-endpoint> \
        -F image1=@cat.jpg \
        -F images2=@horse.jpg

      对应的传入data为

      [
         {
            "image1":{
               "cat.jpg":"<cat..jpg file io>"
            }
         },
         {
            "image2":{
               "horse.jpg":"<horse.jpg file io>"
            }
         }
      ]
    • “application/json”请求
       curl -X POST \
         <modelarts-inference-endpoint> \
         -d '{
          "images":"base64 encode image"
          }'

      对应的传入data为python dict

       {
          "images":"base64 encode image"
      
       }
  4. 可以使用的属性为模型所在的本地目录,属性名为“self.model_path”。另外pyspark模型在“customize_service.py”中可以使用“self.spark”获取SparkSession对象。

TensorFlow的推理脚本示例

TensorFlow MnistService示例如下。
  • 推理代码
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    from PIL import Image
    import numpy as np
    from model_service.tfserving_model_service import TfServingBaseService
    
    class mnist_service(TfServingBaseService):
    
        def _preprocess(self, data):
            preprocessed_data = {}
    
            for k, v in data.items():
                for file_name, file_content in v.items():
                    image1 = Image.open(file_content)
                    image1 = np.array(image1, dtype=np.float32)
                    image1.resize((1, 784))
                    preprocessed_data[k] = image1
    
            return preprocessed_data
    
        def _postprocess(self, data):
    
            infer_output = {}
    
            for output_name, result in data.items():
    
                infer_output["mnist_result"] = result[0].index(max(result[0]))
    
            return infer_output
    
  • 请求
    curl -X POST \ 在线服务地址 \ -F images=@test.jpg
  • 返回
    {"mnist_result": 7}

在上面的代码示例中,完成了将用户表单输入的图片的大小调整,转换为可以适配模型输入的shape。首先通过Pillow库读取“32×32”的图片,调整图片大小为“1×784”以匹配模型输入。在后续处理中,转换模型输出为列表,用于Restful接口输出展示。

自定义推理逻辑的推理脚本示例

首先,需要在配置文件中,定义自己的依赖包,详细示例请参见使用自定义依赖包的模型配置文件示例。然后通过如下所示示例代码,实现了“saved_model”格式模型的加载推理。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
import json
import os
import threading

import numpy as np
import tensorflow as tf
from PIL import Image

from model_service.tfserving_model_service import TfServingBaseService
import logging

logger = logging.getLogger(__name__)


class MnistService(TfServingBaseService):

    def __init__(self, model_name, model_path):
        self.model_name = model_name
        self.model_path = model_path
        self.model_inputs = {}
        self.model_outputs = {}

        # label文件可以在这里加载,在后处理函数里使用
        # label.txt放在OBS和模型包的目录

        # with open(os.path.join(self.model_path, 'label.txt')) as f:
        #     self.label = json.load(f)

        # 非阻塞方式加载saved_model模型,防止阻塞超时
        thread = threading.Thread(target=self.get_tf_sess)
        thread.start()

    def get_tf_sess(self):
        # 加载saved_model格式的模型

        # session要重用,建议不要用with语句
        sess = tf.Session(graph=tf.Graph())
        meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], self.model_path)
        signature_defs = meta_graph_def.signature_def

        self.sess = sess

        signature = []

        # only one signature allowed
        for signature_def in signature_defs:
            signature.append(signature_def)
        if len(signature) == 1:
            model_signature = signature[0]
        else:
            logger.warning("signatures more than one, use serving_default signature")
            model_signature = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

        logger.info("model signature: %s", model_signature)

        for signature_name in meta_graph_def.signature_def[model_signature].inputs:
            tensorinfo = meta_graph_def.signature_def[model_signature].inputs[signature_name]
            name = tensorinfo.name
            op = self.sess.graph.get_tensor_by_name(name)
            self.model_inputs[signature_name] = op

        logger.info("model inputs: %s", self.model_inputs)

        for signature_name in meta_graph_def.signature_def[model_signature].outputs:
            tensorinfo = meta_graph_def.signature_def[model_signature].outputs[signature_name]
            name = tensorinfo.name
            op = self.sess.graph.get_tensor_by_name(name)

            self.model_outputs[signature_name] = op

        logger.info("model outputs: %s", self.model_outputs)

    def _preprocess(self, data):
        # https两种请求形式
        # 1. form-data文件格式的请求对应:data = {"请求key值":{"文件名":<文件io>}}
        # 2. json格式对应:data = json.loads("接口传入的json体")
        preprocessed_data = {}

        for k, v in data.items():
            for file_name, file_content in v.items():
                image1 = Image.open(file_content)
                image1 = np.array(image1, dtype=np.float32)
                image1.resize((1, 28, 28))
                preprocessed_data[k] = image1

        return preprocessed_data

    def _inference(self, data):

        feed_dict = {}
        for k, v in data.items():
            if k not in self.model_inputs.keys():
                logger.error("input key %s is not in model inputs %s", k, list(self.model_inputs.keys()))
                raise Exception("input key %s is not in model inputs %s" % (k, list(self.model_inputs.keys())))
            feed_dict[self.model_inputs[k]] = v

        result = self.sess.run(self.model_outputs, feed_dict=feed_dict)
        logger.info('predict result : ' + str(result))

        return result

    def _postprocess(self, data):
        infer_output = {"mnist_result": []}
        for output_name, results in data.items():

            for result in results:
                infer_output["mnist_result"].append(np.argmax(result))

        return infer_output

    def __del__(self):
        self.sess.close()
分享:

    相关文档

    相关产品

文档是否有解决您的问题?

提交成功!

非常感谢您的反馈,我们会继续努力做到更好!

反馈提交失败,请稍后再试!

*必选

请至少选择或填写一项反馈信息

字符长度不能超过200

提交反馈 取消

如您有其它疑问,您也可以通过华为云社区问答频道来与我们联系探讨

智能客服提问云社区提问