
Example: Automatic Data Augmentation Using Preset Augmentation Policies

Updated: 2020/11/17 GMT+08:00

The Auto Augment algorithm was used to search for the best-performing data augmentation policies on the CIFAR-10 dataset. This example demonstrates how to apply those searched policies.

Sample Code

The following code trains an image classification model on MNIST with a ResNet50 backbone. The parts that must be modified to use the automatic data augmentation policies are marked in the comments.

import argparse
import time

import tensorflow as tf
from autosearch.client.augment.offline_search.preprocessor_builder import (
    ImageClassificationTensorflowBuilder,
)    # Modification 1: import the decoder module
from autosearch.client.nas.backbone.resnet import ResNet50    
from tensorflow.examples.tutorials.mnist import input_data

import autosearch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_steps", type=int, default=100, help="Number of steps to run trainer."
)
parser.add_argument("--data_url", type=str, default="MNIST_data")

parser.add_argument(
    "--learning_rate",
    type=float,
    default=0.01,  
    help="Initial learning rate.",
)
FLAGS, unparsed = parser.parse_known_args()


def train():
    mnist = input_data.read_data_sets(FLAGS.data_url, one_hot=True)
    with tf.Graph().as_default():
        sess = tf.InteractiveSession()
        with tf.name_scope("input"):
            x = tf.placeholder(tf.float32, [None, 784], name="x-input")
            y_ = tf.placeholder(tf.int64, [None, 10], name="y-input")
        image_shaped_input = tf.multiply(x, 255)
        image_shaped_input = tf.cast(image_shaped_input, tf.uint8)
        image_shaped_input = tf.reshape(image_shaped_input, [-1, 784, 1])
        image_shaped_input = tf.concat([image_shaped_input, image_shaped_input, image_shaped_input], axis=2)
        image_shaped_input = ImageClassificationTensorflowBuilder("offline")(image_shaped_input)    # Modification 2: the decoder module automatically parses the parameters issued by the framework and converts them into the corresponding augmentation operations
        image_shaped_input = tf.cast(image_shaped_input, tf.float32)
        image_shaped_input = tf.reshape(image_shaped_input, [-1, 28, 28, 3])
        image_shaped_input = tf.multiply(image_shaped_input, 1 / 255.0)
        y = ResNet50(image_shaped_input, include_top=True, mode="train")
        with tf.name_scope("cross_entropy"):
            y = tf.reduce_mean(y, [1, 2])
            y = tf.layers.dense(y, 10)
            with tf.name_scope("total"):
                cross_entropy = tf.losses.softmax_cross_entropy(y_, y)

        with tf.name_scope("train"):
            train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(  
                cross_entropy
            )

        with tf.name_scope("accuracy"):
            with tf.name_scope("correct_prediction"):
                correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            with tf.name_scope("accuracy"):
                accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.global_variables_initializer().run()

        def feed_dict(train):
            if train:
                xs, ys = mnist.train.next_batch(100)
            else:
                xs, ys = mnist.test.next_batch(10000)
            return {x: xs, y_: ys}

        max_acc = 0
        latencys = []
        for i in range(FLAGS.max_steps):
            if i % 10 == 0:  # Record summaries and test-set accuracy
                loss, acc = sess.run(
                    [cross_entropy, accuracy], feed_dict=feed_dict(False)
                )
                # print('loss at step %s: %s' % (i, loss))
                print("Accuracy at step %s: %s" % (i, acc))
                if acc > max_acc:
                    max_acc = acc
                # autosearch.reporter(loss=loss)
                autosearch.reporter(mean_accuracy=acc)    # Modification 3: report accuracy back to the AutoSearch framework
            else:
                start = time.time()
                loss, _ = sess.run(
                    [cross_entropy, train_step], feed_dict=feed_dict(True)
                )
                end = time.time()
                if i % 10 != 1:
                    latencys.append(end - start)
        latency = sum(latencys) / len(latencys)
        autosearch.reporter(mean_accuracy=max_acc, latency=latency)    # Same as modification 3
        sess.close()

def cloud_init(s3_url):
    local_data_url = "/cache/mnist"
    import moxing as mox
    print(
        'Copying from s3_url({}) to local path({})'.format(s3_url, local_data_url)
    )
    mox.file.copy_parallel(s3_url, local_data_url)
    return local_data_url

Because the policies were searched on CIFAR-10, the input data must be in CIFAR-10 format: three RGB channels with pixel values in the range 0-255. MNIST data is single-channel by default, with pixel values already normalized to the range 0-1. The sample code above therefore contains extra operations that convert the data into CIFAR-10 format; see the following code snippet.

        image_shaped_input = tf.multiply(x, 255)
        image_shaped_input = tf.cast(image_shaped_input, tf.uint8)
        image_shaped_input = tf.reshape(image_shaped_input, [-1, 784, 1])
        image_shaped_input = tf.concat([image_shaped_input, image_shaped_input, image_shaped_input], axis=2)
        image_shaped_input = ImageClassificationTensorflowBuilder("offline")(image_shaped_input)    # Modification 2: the decoder module automatically parses the parameters issued by the framework and converts them into the corresponding augmentation operations
        image_shaped_input = tf.cast(image_shaped_input, tf.float32)
        image_shaped_input = tf.reshape(image_shaped_input, [-1, 28, 28, 3])
        image_shaped_input = tf.multiply(image_shaped_input, 1 / 255.0)
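For reference, the same shape and range conversion can be sketched in plain NumPy, outside the TensorFlow graph. This is only an illustration of what the snippet above does to each batch; the function name is our own:

```python
import numpy as np


def mnist_to_cifar_format(batch):
    """Convert MNIST images (floats in [0, 1], shape [N, 784]) into
    CIFAR-10-style uint8 RGB images (shape [N, 28, 28, 3])."""
    imgs = (batch * 255).astype(np.uint8)   # rescale to the 0-255 range
    imgs = imgs.reshape(-1, 28, 28, 1)      # restore the spatial layout
    return np.repeat(imgs, 3, axis=3)       # replicate the gray channel as RGB


batch = np.random.rand(4, 784).astype(np.float32)
out = mnist_to_cifar_format(batch)
print(out.shape, out.dtype)  # (4, 28, 28, 3) uint8
```

After the augmentation operations run on the uint8 RGB images, the sample code converts them back to floats in [0, 1] for training.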

Writing the Configuration File

Here, grid_search only serves to pass the parameters to the decoder module embedded in the code; there is actually only one policy.

general:
  gpu_per_instance: 1

search_space:
  - type: discrete
    params:
      - name: image_classification_auto_augment
        values: [
            ["4-4-3", "6-6-7", "7-3-9", "6-7-9", "1-6-5", "1-5-1", "5-6-7", "7-6-5", "6-3-7", "0-5-8", "0-9-4", "0-5-6", "14-3-5", "1-6-5", "6-0-8", "4-8-8", "14-2-6", "4-8-6", "14-2-6", "0-8-1", "14-4-1", "1-6-5", "6-0-0", "14-5-2", "0-9-5", "6-5-3", "5-7-5", "6-0-2", "14-2-8", "14-1-5", "0-9-4", "1-8-4", "6-0-7", "1-4-7", "14-2-5", "1-7-5", "1-6-8", "4-6-2", "4-3-7", "4-2-4", "0-5-2", "14-7-2", "0-2-0", "1-1-0", "6-9-3", "0-4-1", "1-8-8", "1-7-7", "1-7-7", "14-5-0", "1-3-7", "0-4-8", "6-9-6", "4-2-8", "0-1-5", "6-0-0", "8-2-4", "1-1-1", "1-7-7", "0-6-4", "1-8-2", "0-9-5", "1-5-0", "14-6-6", "1-9-5", "4-7-0", "0-7-3", "1-7-0", "6-5-1", "5-1-7", "5-1-4", "14-6-5", "0-3-9", "8-5-3", "0-9-2", "2-0-3", "14-4-3", "4-2-4", "1-1-4", "1-7-6", "1-3-8", "0-4-3", "14-6-4", "0-7-6", "0-2-9", "6-4-8", "1-1-0", "1-0-6", "1-8-4", "1-0-4", "1-5-5", "0-1-2", "14-5-5", "0-9-5", "0-6-1", "0-7-8", "1-2-0", "0-1-2", "1-6-9", "1-4-4"]
        ]

search_algorithm:
  type: grid_search
  reward_attr: mean_accuracy

scheduler:
  type: FIFOScheduler
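Each entry in the values list is a string of three integer fields separated by hyphens; the meaning of the fields is defined by the decoder module (ImageClassificationTensorflowBuilder), not by this configuration file. Purely to illustrate the string structure — this parser is our own sketch, not part of the AutoSearch API:

```python
def parse_policy(policy):
    """Split each "a-b-c" entry into a tuple of three integers.
    How the fields are interpreted is up to the decoder module;
    this only shows the encoding used in the values list above."""
    return [tuple(int(field) for field in entry.split("-")) for entry in policy]


print(parse_policy(["4-4-3", "6-6-7", "14-3-5"]))
# [(4, 4, 3), (6, 6, 7), (14, 3, 5)]
```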

Starting the Search Job

This example uses the MNIST dataset. Upload and configure the dataset as described in "Example: Replacing the Native ResNet50 with a Better Network Architecture", then upload the Python script and the YAML configuration file, and start the search job by following the "Creating an Automated Search Job" section.
