文档首页 > > AI Gallery用户指南> 预置AI算法(官方发布)> 文本分类>

BERT(文本分类/TensorFlow)

BERT(文本分类/TensorFlow)

分享
更新时间:2020/12/14 GMT+08:00

概述

基于BERT预训练模型的文本分类算法,支持单标签文本分类。预训练模型基于BERT BASE模型。用户需要在数据管理平台完成标注,该算法会载入预训练模型在用户数据集上做迁移学习。训练后生成的模型可直接在ModelArts平台部署为在线服务或批量服务,同时支持使用CPU、GPU或Ascend 310进行推理。

单击此处订阅算法。

训练

  • 算法基本信息
    • 适用场景:文本分类
    • 支持的框架引擎:Tensorflow-1.13.1-python3.6
    • 算法输入:
      • ModelArts数据管理平台发布的文本分类数据集(数据集必须设置“训练验证比例”),建议用户以8:2或9:1的比例进行切分,即“训练验证比例”设置为0.8或0.9。
    • 算法输出:
      • 用于TF-Serving推理的saved_model模型。
  • 训练参数说明

    名称

    默认值

    类型

    是否必填

    是否可修改

    描述

    num_train_epochs

    5

    int

    训练数据的次数。

    max_training_time

    1

    int

    最大训练小时数(单位:小时)。

    train_batch_size

    32

    int

    每次迭代训练的单卡输入句子数量。

    eval_batch_size

    32

    int

    验证时每步读取的单卡输入句子数量。

    max_seq_length

    256

    int

    输入句子最大长度,同数据集相关,实际输入少于该值会补0,多于该值会被截断。

    export_d_model

    False

    bool

    是否导出用于Ascend推理模型。

    task_type

    bert_classifier

    string

    适用场景。

    model_name

    bert_base_chinese

    string

    模型名称。

  • 训练输出文件

    训练完成后的输出文件如下。

      |- om
        |- model
          |- index
          |- customize_service_d310.py
      |- model
        |- variables
          |- variables.data-00000-of-00001
          |- variables.index
        |- customize_service.py
        |- index
        |- config.json
        |- saved_model.pb
      |- frozen_graph
        |- model_d.pb
      |- checkpoint
      |- model.ckpt-xxx
      |- ...
      |- best_checkpoint
      |- best_model.ckpt-xxx
      |- ...
      |- events...
      |- graph.pbtxt

Ascend 310推理

  • 模型转换
    • 转换模板:TF-FrozenGraph-To-Ascend-C32
    • 转换输入目录:选择“训练输出目录”中的frozen_graph
    • 转换输出目录:选择“训练输出目录”中的om/model
    • 输入张量形状:input_ids:1,256;input_mask:1,256;segment_ids:1,256
    • 输入数据格式:NCHW

    其他参数均使用默认值。

  • 模型导入
    • 从模板中选择:ARM-Ascend模板
    • 模型目录:选择“训练输出目录”中的om/model
    • 输入输出模式:预置预测分析模式

GPU/CPU推理

“元模型来源”选择“从训练中选择”,选择训练作业及版本。

推理配置文件“model/config.json”,默认使用CPU推理镜像(runtime:tf1.xx-python3.x-cpu)。若使用GPU推理,导入模型之前需修改“model/config.json”文件,将runtime字段修改为“tf1.xx-python3.x-gpu”

推理输入、输出格式

使用curl命令发送预测请求的命令。

  • content为实际要推理的文本内容,最大长度为“max_seq_length”的参数值,大于该字段的会被截断。
  • “-H”是post命令的headers,Headers的Key值为“X-Auth-Token”,此名称为固定的, Token值是用户获取到的token值(关于如何获取token,请参考获取请求认证)。
  • POST后面跟随的是在线服务的调用地址。
1
curl -d '{"text":"content"}' -H 'X-Auth-Token:Token值' -H 'Content-type: application/json' -X POST 在线服务地址​

以json字符串的形式返回请求结果,如下所示。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
{
  "predicted_label": 置信度最高的分类标签,            # top1结果
  "scores":                                           # top5结果
         [
                 score1,
                 score2,
                 score3,
                 score4,
                 score5,
         ]
}

案例指导

GPU训练+Ascend 310推理,可参考图像分类案例

分享:

    相关文档

    相关产品