文档首页 > > AI工程师用户指南> 训练管理> 自动化搜索作业> 示例:使用更优秀的网络结构替换原生ResNet50

示例:使用更优秀的网络结构替换原生ResNet50

分享
更新时间:2020/07/06 GMT+08:00

以使用resnet50在MNIST数据集上的分类任务为例。

数据准备

ModelArts在公共OBS桶中提供了MNIST数据集,命名为“Mnist-Data-Set”,本文的操作示例可使用此数据集。请执行如下操作,将数据集上传至您的OBS目录下,例如上传至“test-modelarts/dataset-mnist”

  1. 单击数据集下载链接,将“Mnist-Data-Set”数据集下载至本地。
  2. 在本地,将“Mnist-Data-Set.zip”压缩包解压。例如,解压至本地“Mnist-Data-Set”文件夹下。
  3. 参考上传文件,使用批量上传方式将“Mnist-Data-Set”文件夹下的所有文件上传至“test-modelarts/dataset-mnist”OBS路径下。

    “Mnist-Data-Set”数据集包含的内容如下所示,其中“.gz”为对应的压缩包。

    本示例需使用“.gz”压缩包格式,请务必将数据集的4个压缩包上传至OBS目录。

    • “t10k-images-idx3-ubyte.gz”:验证集,共包含10000个样本。
    • “t10k-labels-idx1-ubyte.gz”:验证集标签,共包含10000个样本的类别标签。
    • “train-images-idx3-ubyte.gz”:训练集,共包含60000个样本。
    • “train-labels-idx1-ubyte.gz”:训练集标签,共包含60000个样本的类别标签。

样例代码

假设我们手上有一份用ResNet50做MNIST手写数字识别的图像分类任务的TensorFlow代码,只需进行五行的改动,就能使用自动化搜索作业替换其中的ResNet50结构。下面就是一份修改后的代码,修改的地方都用注释作了说明。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
import argparse
import time
import os
import logging

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import autosearch    # 改动1:导入autosearch包
from autosearch.client.nas.backbone.resnet import ResNet50    # 改动2: 导入预置的ResNet50模块,能把结构编码解码成TensorFlow的结构

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_steps", type=int, default=100, help="Number of steps to run trainer."
)
parser.add_argument("--data_url", type=str, default="MNIST_data")

parser.add_argument(
    "--learning_rate",
    type=float,
    default=0.01,  # st2
    help="Number of steps to run trainer.",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
FLAGS, unparsed = parser.parse_known_args()

logger = logging.getLogger(__name__)


def train():
    if is_cloud():
        FLAGS.data_url = cloud_init(FLAGS.data_url)
    mnist = input_data.read_data_sets(FLAGS.data_url, one_hot=True)
    with tf.Graph().as_default():
        sess = tf.InteractiveSession()
        with tf.name_scope("input"):
            x = tf.placeholder(tf.float32, [None, 784], name="x-input")
            y_ = tf.placeholder(tf.int64, [None, 10], name="y-input")
        image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
        y = ResNet50(image_shaped_input, include_top=True, mode="train")  # 改动3: 使用导入的ResNet50解码模块替换原来代码中的ResNet50
        with tf.name_scope("cross_entropy"):
            y = tf.reduce_mean(y, [1, 2])
            y = tf.layers.dense(y, 10)
            with tf.name_scope("total"):
                cross_entropy = tf.losses.softmax_cross_entropy(y_, y)

        with tf.name_scope("train"):
            train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(  # st2
                cross_entropy
            )

        with tf.name_scope("accuracy"):
            with tf.name_scope("correct_prediction"):
                correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            with tf.name_scope("accuracy"):
                accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.global_variables_initializer().run()

        def feed_dict(train):
            if train:
                xs, ys = mnist.train.next_batch(100)
            else:
                xs, ys = mnist.test.images, mnist.test.labels
            return {x: xs, y_: ys}

        max_acc = 0
        latencies = []
        for i in range(FLAGS.max_steps):
            if i % 10 == 0:  # Record summaries and test-set accuracy
                loss, acc = sess.run(
                    [cross_entropy, accuracy], feed_dict=feed_dict(False)
                )
                print("loss at step %s: %s" % (i, loss))
                print("acc step %s: %s" % (i, acc))
                if acc > max_acc:
                    max_acc = acc
                autosearch.reporter(loss=loss, mean_accuracy=max_acc)    # 改动4: 反馈精度给AutoSearch框架
            else:
                start = time.time()
                loss, _ = sess.run(
                    [cross_entropy, train_step], feed_dict=feed_dict(True)
                )
                end = time.time()
                if i % 10 != 1:
                    latencies.append(end - start)
        latency = sum(latencies) / len(latencies)
        autosearch.reporter(mean_accuracy=max_acc, latency=latency)    # 同改动4: 反馈精度给AutoSearch框架
        sess.close()


def is_cloud():
    return True if os.path.exists("/home/work/user-job-dir") else False


def cloud_init(s3_url):
    local_data_dir = "/cache/mnist"

    import moxing as mox

    logger.info(
        'Copying from s3_url({})" to local path({})'.format(s3_url, local_data_dir)
    )
    mox.file.copy_parallel(s3_url, local_data_dir)

    return local_data_dir

配置文件编写

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
general:
  gpu_per_instance: 1

search_space:
  - type: discrete
    params:
      - name: resnet50
        values: ["1-11111111-2111121111-211111", "1-1112-1111111111121-11111112111", "1-11111121-12-11111211", "11-111112-112-11111211", "1-1-111111112-11212", "1-1-1-2112112", "1-111211-1111112-21111111", "1-1111111111-21112112-11111","1-111111111112-121111111121-11","11-211-121-11111121", "111-2111-211111-211"]

search_algorithm:
  type: grid_search
  reward_attr: mean_accuracy

scheduler:
  type: FIFOScheduler

启动搜索作业

参考创建自动化搜索作业操作指导,创建一个自动化搜索作业,将启动文件设置为样例代码中的示例代码文件,将“config_path”设置为示例yaml文件的OBS路径,例如“obs://bucket_name/config.yaml”。配置完成后,提交作业启动搜索作业。

  • 样例代码需存储为.py文件,为搜索作业的启动脚本。
  • yaml配置文件,必须以.yaml结尾。
  • 启动脚本和yaml配置文件的命名可根据实际业务进行命名。
  • 启动脚本和yaml配置文件需提前上传至OBS,且此OBS桶与当前使用的ModelArts在同一区域。
图1 设置自动化搜索作业

查看搜索结果

等待自动化搜索作业运行结束后,单击作业名称进入作业详情页面,单击“搜索结果”页签,查看搜索结果。

分享:

    相关文档

    相关产品

文档是否有解决您的问题?

提交成功!

非常感谢您的反馈,我们会继续努力做到更好!

反馈提交失败,请稍后再试!

*必选

请至少选择或填写一项反馈信息

字符长度不能超过200

提交反馈 取消

如您有其它疑问,您也可以通过华为云社区问答频道来与我们联系探讨

智能客服提问云社区提问