深度学习模型预测
深度学习已经广泛应用于图像分类、图像识别和语音识别等不同领域,DLI服务中提供了若干函数实现加载深度学习模型并进行预测的能力。
目前可支持的模型包括DeepLearning4j 模型和Keras模型。由于Keras它能够以 TensorFlow、CNTK或者 Theano 作为后端运行,导入来自Keras的神经网络模型,可以借此导入Theano、Tensorflow、Caffe、CNTK等主流学习框架的模型。
语法格式
1 2 3 4 5 6 7 |
-- 图像分类, 返回预测图像分类的类别id DL_IMAGE_MAX_PREDICTION_INDEX(field_name, model_path, is_dl4j_model) DL_IMAGE_MAX_PREDICTION_INDEX(field_name, keras_model_config_path, keras_weights_path) -- 适用于Keras模型 -- 文本分类,返回预测文本分类的类别id DL_TEXT_MAX_PREDICTION_INDEX(field_name, model_path, is_dl4j_model) -- 采用默认word2vec模型 DL_TEXT_MAX_PREDICTION_INDEX(field_name, word2vec_path, model_path, is_dl4j_model) |
模型及配置文件等需存储在用户的OBS中,路径格式为"obs://your_ak:your_sk@obs.your_obs_region.xxx.com:443/your_model_path"。
参数说明
参数 |
是否必选 |
说明 |
---|---|---|
field_name |
是 |
数据在数据流中的字段名。 图像分类中field_name类型需声明为ARRAY[TINYINT]。 文本分类中field_name类型需声明为String。 |
model_path |
是 |
模型存放在OBS上的完整路径,包括模型结构和模型权值。 |
is_dl4j_model |
是 |
是否是deeplearning4j的模型。 true代表是deeplearning4j,false代表是keras模型。 |
keras_model_config_path |
是 |
模型结构存放在OBS上的完整路径。在keras中通过model.to_json()可得到模型结构。 |
keras_weights_path |
是 |
模型权值存放在OBS上的完整路径。在keras中通过model.save_weights(filepath)可得到模型权值。 |
word2vec_path |
是 |
word2vec模型存放在OBS上的完整路径。 |
示例
图片分类预测我们采用Mnist数据集作为流的输入,通过加载预训练的deeplearning4j模型或者keras模型,可以实时预测每张图片代表的数字。
1 2 3 4 5 6 |
CREATE SOURCE STREAM Mnist( image Array[TINYINT] ) SELECT DL_IMAGE_MAX_PREDICTION_INDEX(image, 'your_dl4j_model_path', false) FROM Mnist SELECT DL_IMAGE_MAX_PREDICTION_INDEX(image, 'your_keras_model_path', true) FROM Mnist SELECT DL_IMAGE_MAX_PREDICTION_INDEX(image, 'your_keras_model_config_path', 'keras_weights_path') FROM Mnist |
文本分类预测我们采用一组新闻标题数据作为流的输入,通过加载预训练的deeplearning4j模型或者keras模型,可以实时预测每个新闻标题所属的类别,比如经济,体育,娱乐等。
1 2 3 4 5 6 7 |
CREATE SOURCE STREAM News( title String ) SELECT DL_TEXT_MAX_PREDICTION_INDEX(title, 'your_dl4j_word2vec_model_path','your_dl4j_model_path', false) FROM News SELECT DL_TEXT_MAX_PREDICTION_INDEX(title, 'your_keras_word2vec_model_path','your_keras_model_path', true) FROM News SELECT DL_TEXT_MAX_PREDICTION_INDEX(title, 'your_dl4j_model_path', false) FROM New SELECT DL_TEXT_MAX_PREDICTION_INDEX(title, 'your_keras_model_path', true) FROM New |