文档首页 > > AI工程师用户指南> 训练管理> 自动化搜索作业> 示例:使用MBNAS算法搜索网络结构

示例:使用MBNAS算法搜索网络结构

分享
更新时间:2020/06/17 GMT+08:00

MBNAS是Model-based Neural Architecture Search的缩写,是华为自研的NAS算法。

图1 MBNAS算法介绍

其思想是:

  • 在搜索空间内,用随机补集的方式,采样200个结构,并行训练这200个结构,得到200个结构对应的reward。
  • 利用这200个数据,训练一个Evaluator,这个Evaluator,能预测一个采样到的结构对应的reward。
  • 基于上面的Evaluator,训练一个基于强化学习或者是进化算法的Controller。
  • 由训练好的Controller提出若干个优秀的结构。

为了方便起见,我们以类似MNIST的模拟数据来演示,您也可以将下面的代码修改成使用示例:使用更优秀的网络结构替换原生ResNet50中的MNIST数据集。

样例代码

样例代码中,以行末注释的方式,标明了对于原来的代码进行的额外修改工作。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import logging
import time
import tensorflow as tf
from tensorflow import keras
import autosearch    # 改动1: 导入包

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_steps", type=int, default=10 ** 2, help="Number of steps to run trainer."
)
FLAGS, unparsed = parser.parse_known_args()

batch_size = 1
learning_rate = 0.001


def train(config, reporter):
    nas_code_extra = config["nas_code_extra"]    # 改动2: 获得框架下发的参数

    with tf.Graph().as_default():
        sess = tf.InteractiveSession()
        x = tf.ones([batch_size, 784], name="x-input")
        y_true = tf.ones([batch_size, 10], name="y-input", dtype=tf.int64)

        for i, stage in enumerate(nas_code_extra):
            x = keras.layers.Dense(
                units=stage["hidden_units"], activation=stage["activation"],
            )(x)    # 改动3:根据传入的模型参数,来对模型构图

        y = tf.layers.dense(x, 10)
        cross_entropy = tf.losses.softmax_cross_entropy(y_true, y)
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_true, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.global_variables_initializer().run()
        max_acc = 0
        latencies = []
        for i in range(FLAGS.max_steps):
            logging.info("current step: {}".format(i))
            if i % 10 == 0:  # Record summaries and test-set accuracy
                loss, acc = sess.run([cross_entropy, accuracy])
                # print('Accuracy at step %s: %s' % (i, acc))
                if acc > max_acc:
                    max_acc = acc
                if i == (FLAGS.max_steps - 1):
                    autosearch.reporter(
                        loss=loss, acc=max_acc, mean_loss=loss, done=True
                    )    # 改动4: 反馈精度指标
                else:
                    autosearch.reporter(loss=loss, acc=acc, mean_loss=loss)    # 同改动4
            else:
                start = time.time()
                loss, _ = sess.run([cross_entropy, train_step])
                end = time.time()
                if i % 10 != 1:
                    latencies.append(end - start)
        latency = sum(latencies) / len(latencies)
        print(max_acc, latency)

配置文件编写

repeat_discrete用来配合MBNAS使用,按这个yaml配置,前面的代码中的for循环中的内容就会执行4次。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
general:
  gpu_per_instance: 0

search_space:
  - type: repeat_discrete
    name: mbnas
    repeat: 4
    params:
      - name: activation
        values: ['sigmoid', 'relu']
      - name: hidden_units
        values: [1, 2, 4, 8]

search_algorithm:
  type: mbnas
  reward_attr: acc
  num_of_arcs: 20

启动搜索作业

将样例代码的脚本和yaml文件上传至OBS后,即可在页面上启动作业。由于不需要实际数据,因此任意选择已有数据集,或者空的OBS目录即可。其他配置的选择参考示例:使用经典超参算法搜索超参中的启动搜索作业步骤。

分享:

    相关文档

    相关产品

文档是否有解决您的问题?

提交成功!非常感谢您的反馈,我们会继续努力做到更好!
反馈提交失败,请稍后再试!

*必选

请至少选择或填写一项反馈信息

字符长度不能超过200

提交反馈 取消

如您有其它疑问,您也可以通过华为云社区问答频道来与我们联系探讨

智能客服提问云社区提问