更新时间:2022-08-12 GMT+08:00

深度学习模型预测

深度学习已经广泛应用于图像分类、图像识别和语音识别等不同领域,DLI服务中提供了若干函数实现加载深度学习模型并进行预测的能力。

目前可支持的模型包括DeepLearning4j 模型和Keras模型。由于Keras它能够以 TensorFlow、CNTK或者 Theano 作为后端运行,导入来自Keras的神经网络模型,可以借此导入Theano、Tensorflow、Caffe、CNTK等主流学习框架的模型。

语法格式

1
2
3
4
5
6
7
-- 图像分类, 返回预测图像分类的类别id
DL_IMAGE_MAX_PREDICTION_INDEX(field_name, model_path, is_dl4j_model)
DL_IMAGE_MAX_PREDICTION_INDEX(field_name, keras_model_config_path, keras_weights_path) -- 适用于Keras模型

-- 文本分类,返回预测文本分类的类别id
DL_TEXT_MAX_PREDICTION_INDEX(field_name, model_path, is_dl4j_model) -- 采用默认word2vec模型
DL_TEXT_MAX_PREDICTION_INDEX(field_name, word2vec_path, model_path, is_dl4j_model)

模型及配置文件等需存储在用户的OBS中,路径格式为"obs://your_ak:your_sk@obs.your_obs_region.xxx.com:443/your_model_path"。

参数说明

表1 参数说明

参数

是否必选

说明

field_name

数据在数据流中的字段名。

图像分类中field_name类型需声明为ARRAY[TINYINT]。

文本分类中field_name类型需声明为String。

model_path

模型存放在OBS上的完整路径,包括模型结构和模型权值。

is_dl4j_model

是否是deeplearning4j的模型。

true代表是deeplearning4j,false代表是keras模型。

keras_model_config_path

模型结构存放在OBS上的完整路径。在keras中通过model.to_json()可得到模型结构。

keras_weights_path

模型权值存放在OBS上的完整路径。在keras中通过model.save_weights(filepath)可得到模型权值。

word2vec_path

word2vec模型存放在OBS上的完整路径。

示例

图片分类预测我们采用Mnist数据集作为流的输入,通过加载预训练的deeplearning4j模型或者keras模型,可以实时预测每张图片代表的数字。

1
2
3
4
5
6
CREATE SOURCE STREAM Mnist(
    image Array[TINYINT]
)
SELECT DL_IMAGE_MAX_PREDICTION_INDEX(image, 'your_dl4j_model_path', false) FROM Mnist
SELECT DL_IMAGE_MAX_PREDICTION_INDEX(image, 'your_keras_model_path', true) FROM Mnist
SELECT DL_IMAGE_MAX_PREDICTION_INDEX(image, 'your_keras_model_config_path', 'keras_weights_path') FROM Mnist

文本分类预测我们采用一组新闻标题数据作为流的输入,通过加载预训练的deeplearning4j模型或者keras模型,可以实时预测每个新闻标题所属的类别,比如经济,体育,娱乐等。

1
2
3
4
5
6
7
CREATE SOURCE STREAM News(
    title String
)
SELECT DL_TEXT_MAX_PREDICTION_INDEX(title, 'your_dl4j_word2vec_model_path','your_dl4j_model_path', false) FROM News
SELECT DL_TEXT_MAX_PREDICTION_INDEX(title, 'your_keras_word2vec_model_path','your_keras_model_path', true) FROM News
SELECT DL_TEXT_MAX_PREDICTION_INDEX(title, 'your_dl4j_model_path', false) FROM New
SELECT DL_TEXT_MAX_PREDICTION_INDEX(title, 'your_keras_model_path', true) FROM New