文档首页 > > AI工程师用户指南> 算法管理> 使用常用框架创建算法>

开发自定义脚本

开发自定义脚本

分享
更新时间:2021/06/04 GMT+08:00

当您使用常用框架创建算法时,您需要完成训练的自定义脚本开发。

使用常见框架的训练代码开发

当您使用常用框架创建自定义算法时,您需要在创建页面提供代码目录路径、代码目录路径中的启动文件、输入路径参数和训练输出路径参数。这四种输入搭建了用户和ModelArts后台交互的桥梁。

  • 代码目录路径

    您需要在OBS桶中指定代码目录,并将训练代码、依赖安装包或者预生成模型等训练所需文件上载至该代码目录下。训练作业创建完成后,ModelArts会将代码目录及其子目录下载至后台容器中。

    请注意不要将训练数据放在代码目录路径下。训练数据比较大,训练代码目录在训练作业启动后会下载至后台,可能会有下载失败的风险。

  • 代码目录路径中的启动文件

    代码目录路径中的启动文件作为训练启动的入口,当前只支持python格式。

  • 输入路径参数

    训练数据需上传至OBS桶。在训练代码中,用户需通过解析输入路径参数“data_url”下载训练数据至“/cache”目录。请保证您设置的桶路径有读取权限。在训练作业启动后,ModelArts会挂载硬盘至“/cache”目录,用户可以使用此目录来存储临时文件。“/cache”目录大小请参考训练环境中不同规格资源“/cache”目录的大小

  • 训练输出路径参数

    建议设置一个空目录为训练输出路径。在训练代码中,您需要解析训练输出路径参数“train_url”上载训练输出至指定的训练输出路径,请保证您设置的桶路径有写入权限和读取权限。

当您使用常用框架创建算法时,您需要实现训练代码的开发。在ModelArts中,训练代码需包含以下步骤:

图1 训练代码开发说明
  1. (可选)引入依赖

    当您使用常见框架创建算法的时候,如果您的模型引用了其他依赖,您需要在创建算法的“代码目录”下放置相应的文件或安装包。

    图2 选择代码目录并指定模型启动文件
  2. 解析输入路径参数“data_url”、输出路径参数“train_url”
    • 在使用常用框架创建自定义算法时,您需要在创建算法页面填写输入输出路径参数配置。

      训练数据是算法开发中必不可少的输入。输入数据默认配置为“数据来源”,在算法代码中对应的参数为“data_url”

      图3 解析输入路径参数“data_url”
    • 模型训练结束后,训练模型以及相关输出信息需保存在OBS路径。输出数据默认配置为“模型输出”,在算法代码中对应的参数为“train_url”
      图4 解析输出路径参数“train_url”

    在训练代码中需解析“data_url”“train_url”,ModelArts推荐以下方式实现参数解析。

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    import argparse
    # 创建解析
    parser = argparse.ArgumentParser(description="train mnist",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # 添加参数
    parser.add_argument('--train_url', type=str, default='obs://obs-test/ckpt/mnist',
                        help='the path model saved')
    parser.add_argument('--data_url', type=str, default='obs://obs-test/data/', help='the training data')
    # 解析参数
    args, unkown = parser.parse_known_args()
    
  3. “data_url”导入训练数据

    已知训练输入路径为“data_url”,ModelArts推荐采用Moxing接口实现训练数据下载到“cache”目录。因为在训练作业启动后,ModelArts会挂载硬盘至“/cache”目录,用户可以使用此目录来存储临时文件。“/cache”目录大小请参考训练环境中不同规格资源“/cache”目录的大小

    1
    2
    import moxing as mox 
    mox.file.copy_parallel(args.data_url, /cache)
    
  4. 训练代码正文和保存模型

    训练代码正文和保存模型涉及的代码与您使用的AI引擎密切相关。以下案例以Tensorflow框架为例,训练代码中解析参数方式采用tensorflow接口tf.flags.FLAGS接受命令行参数:

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import os
    
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    import moxing as mox
    
    tf.flags.DEFINE_integer('max_steps', 1000, 'number of training iterations.')
    tf.flags.DEFINE_string('data_url', '/home/jnn/nfs/mnist', 'dataset directory.')
    tf.flags.DEFINE_string('train_url', '/home/jnn/temp/delete', 'saved model directory.')
    
    FLAGS = tf.flags.FLAGS
    
    
    def main(*args):
        mox.file.copy_parallel(FLAGS.data_url, '/cache/data_url')
    
        # Train model
        print('Training model...')
        mnist = input_data.read_data_sets('/cache/data_url', one_hot=True)
        sess = tf.InteractiveSession()
        serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
        feature_configs = {'x': tf.FixedLenFeature(shape=[784], dtype=tf.float32),}
        tf_example = tf.parse_example(serialized_tf_example, feature_configs)
        x = tf.identity(tf_example['x'], name='x')
        y_ = tf.placeholder('float', shape=[None, 10])
        w = tf.Variable(tf.zeros([784, 10]))
        b = tf.Variable(tf.zeros([10]))
        sess.run(tf.global_variables_initializer())
        y = tf.nn.softmax(tf.matmul(x, w) + b, name='y')
        cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    
        tf.summary.scalar('cross_entropy', cross_entropy)
    
        train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
        tf.summary.scalar('accuracy', accuracy)
        merged = tf.summary.merge_all()
        test_writer = tf.summary.FileWriter('/cache/train_url', flush_secs=1)
    
        for step in range(FLAGS.max_steps):
            batch = mnist.train.next_batch(50)
            train_step.run(feed_dict={x: batch[0], y_: batch[1]})
            if step % 10 == 0:
                summary, acc = sess.run([merged, accuracy], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                test_writer.add_summary(summary, step)
                print('training accuracy is:', acc)
        print('Done training!')
    
        builder = tf.saved_model.builder.SavedModelBuilder(os.path.join('/cache/train_url', 'model'))
    
        tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
        tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
    
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'images': tensor_info_x},
                outputs={'scores': tensor_info_y},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict_images':
                    prediction_signature,
            },
            main_op=tf.tables_initializer(),
            strip_default_attrs=True)
    
        builder.save()
    
        print('Done exporting!')
    
        mox.file.copy_parallel('/cache/train_url', FLAGS.train_url)
    
    
    if __name__ == '__main__':
        tf.app.run(main=main)

  5. 导出训练模型至“train_url”

    已知训练输出位置为“train_url”,ModelArts推荐采用Moxing接口实现输出结果从后台自定义目录“/cache/train_url”目录导出至“train_url”目录。

    mox.file.copy_parallel(“/cache/train_url”, args.train_url)
分享:

    相关文档

    相关产品