文档首页 > > AI工程师用户指南> 自定义脚本代码示例> XGBoost

XGBoost

分享
更新时间:2020/07/06 GMT+08:00

训练并保存模型

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split

# Prepare training data and setting parameters
iris = pd.read_csv('/data/iris.csv')
X = iris.drop(['virginica'],axis=1)
y = iris[['virginica']]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234565)
params = {
    'booster': 'gbtree',
    'objective': 'multi:softmax',
    'num_class': 3,
    'gamma': 0.1,
    'max_depth': 6,
    'lambda': 2,
    'subsample': 0.7,
    'colsample_bytree': 0.7,
    'min_child_weight': 3,
    'silent': 1,
    'eta': 0.1,
    'seed': 1000,
    'nthread': 4,
}
plst = params.items()
dtrain = xgb.DMatrix(X_train, y_train)
num_rounds = 500
model = xgb.train(plst, dtrain, num_rounds)
model.save_model('/tmp/xgboost.m')

保存完模型后,需要上传到OBS目录才能发布。发布时需要带上config.json配置以及customize_service.py,定义方式参考模型包规范介绍

推理代码

# coding:utf-8
import collections
import json
import xgboost as xgb
from model_service.python_model_service import XgSklServingBaseService
class user_Service(XgSklServingBaseService):

    # request data preprocess
    def _preprocess(self, data):
        list_data = []
        json_data = json.loads(data, object_pairs_hook=collections.OrderedDict)
        for element in json_data["data"]["req_data"]:
            array = []
            for each in element:
                array.append(element[each])
                list_data.append(array)
        return list_data

    #   predict
    def _inference(self, data):
        xg_model = xgb.Booster(model_file=self.model_path)
        pre_data = xgb.DMatrix(data)
        pre_result = xg_model.predict(pre_data)
        pre_result = pre_result.tolist()
        return pre_result

    # predict result process
    def _postprocess(self,data):
        resp_data = []
        for element in data:
            resp_data.append({"predictresult": element})
        return resp_data
分享:

    相关文档

    相关产品

文档是否有解决您的问题?

提交成功!非常感谢您的反馈,我们会继续努力做到更好!
反馈提交失败,请稍后再试!

*必选

请至少选择或填写一项反馈信息

字符长度不能超过200

提交反馈 取消

如您有其它疑问,您也可以通过华为云社区问答频道来与我们联系探讨

智能客服提问云社区提问