AI开发平台ModelArtsAI开发平台ModelArts

更新时间:2021/09/18 GMT+08:00
分享

数据增强(图像生成)

图像生成算子概述

图像生成算子利用Gan网络依据已知的数据集生成新的数据集。Gan是一个包含生成器和判别器的网络,生成器从潜在空间中随机取样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别器的输入则为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能分辨出来。而生成网络则要尽可能地欺骗判别网络。两个网络相互对抗、不断调整参数,最终目的是使判别网络无法判断生成网络的输出结果是否真实。训练中获得的生成器网络可用于生成与输入图片相似的图片,用作新的数据集参与训练。基于Gan网络生成新的数据集不会生成相应的标签。图像生成过程不会改动原始数据,新生成的图片或xml文件保存在指定的输出路径下。

ModelArts提供两种类型的图像生成算子:

  • CycleGan算子:基于CycleGAN用于生成域迁移的图像,即将一类图片转换成另一类图片,把X空间中的样本转换成Y空间中的样本。CycleGAN可以利用非成对数据进行训练。模型训练时运行支持两个输入,分别代表数据的原域和目标域,在训练结束时会生成所有原域向目标域迁移的图像。
    图1 CycleGan算子
表1 CycleGan算子高级参数

参数名

默认值

参数说明

do_validation

True

是否做数据校验,默认为True,表示数据生成前需要做数据校验,否则只做数据生成。

image_channel

3

生成图像的通道数。

image_height

256

图像相关参数:生成图像的高,大小需要是2的次方。

image_width

256

图像相关参数:生成图像的宽,大小需要是2的次方

batch_size

1

训练相关参数:批量训练样本个数。

max_epoch

100

训练相关参数:训练遍历数据集次数。

g_learning_rate

0.0001

训练相关参数:生成器训练学习率。

d_learning_rate

0.0001

训练相关参数:判别器训练学习率。

log_frequency

5

训练相关参数:日志打印频率(按step计数)。

save_frequency

5

训练相关参数:模型保存频率(按epoch计数)。

predict

False

是否进行推理预测,默认为False。如果设置True,需要在resume参数设置已经训练完成的模型的obs路径。

resume

empty

如果predict设置为True,需要填写Tensorflow模型文件的obs路径用于推理预测。当前仅支持“.pb”格式的模型。示例:obs://xxx/xxxx.pb。

默认值为empty。

  • StyleGan算子:基于StyleGan2用于在数据集较小的情形下,随机生成相似图像。StyleGAN提出了一个新的生成器结构,能够控制所生成图像的高层级属性(high-level attributes),如发型、雀斑等;并且生成的图像在一些评价标准上得分更好。而本算法又增加了数据增强算法,可以在较少样本的情况下也能生成较好的新样本,但是样本数尽量在70张以上,样本太少生成出来的新图像不会有太多的样式。
    图2 StyleGan算子
表2 StyleGan算子高级参数

参数名

默认值

参数说明

resolution

256

生成正方形图像的高宽,大小需要是2的次方。

batch-size

8

批量训练样本个数。

total-kimg

300

总共训练的图像数量为total_kimg*1000。

generate_num

300

生成的图像数量,如果是多个类的,则为每类生成的数量。

predict

False

是否进行推理预测,默认为False。如果设置True,需要在resume参数设置已经训练完成的模型的obs路径。

resume

empty

如果predict设置为True,需要填写Tensorflow模型文件的obs路径用于推理预测。当前仅支持“.pb”格式的模型。示例:obs://xxx/xxxx.pb。

默认值为empty。

do_validation

True

是否做数据校验,默认为True,表示数据生成前需要做数据校验,否则只做数据生成。

数据输入

算子输入分为两种,“数据集”“OBS目录”

  • 选择“数据集”,请从下拉框中选择ModelArts中管理的数据集及其版本。要求数据集类型与您在本任务中选择的场景类别一致。
  • 选择“OBS目录”,图像生成算子不需要标注信息,输入支持单层级或双层级目录,存放结构支持“单层级”“双层级”模式。

单层级目录结构如下所示:

image_folder----0001.jpg           
            ----0002.jpg            
            ----0003.jpg            
            ...            
            ----1000.jpg

双层级目录结构如下所示:

image_folder----sub_folder_1----0001.jpg                            
                            ----0002.jpg                            
                            ----0003.jpg                            
                            ...                            
                            ----0500.jpg            
            ----sub_folder_2----0001.jpg                            
                            ----0002.jpg                           
                            ----0003.jpg                            
                            ...                            
                            ----0500.jpg
                            ...            
            ----sub_folder_100----0001.jpg                            
                              ----0002.jpg                            
                              ----0003.jpg                            
                              ...                            
                              ----0500.jpg

输出说明

输出目录的结构如下所示。其中“model”文件夹存放用于推理的“frozen pb”模型,“samples”文件夹存放训练过程中输出图像,“Data”文件夹存放训练模型生成的图像。

train_url----model----CYcleGan_epoch_10.pb                  
                  ----CYcleGan_epoch_20.pb                  
                  ...                 
                  ----CYcleGan_epoch_1000.pb         
         ----samples----0000_0.jpg                   
                   ----0000_1.jpg                  
                   ...                   
                   ----0100_15.jpg         
         ----Data----CYcleGan_0_0.jpg                 
                 ----CYcleGan_0_1.jpg                 
                 ...                 
                 ----CYcleGan_16_8.jpg         
         ----output_0.manifest

其中manifest文件内容示例如下所示。

{
	"id": "xss",
	"source": "obs://home/fc8e2688015d4a1784dcbda44d840307_14.jpg",
	"usage": "train", 
	"annotation": [
		{
			"name": "Cat", 
			"type": "modelarts/image_classification"
		}
	]
}

分享:

    相关文档

    相关产品