文档首页 > > AI工程师用户指南> 管理模型> 评估和诊断模型> 模型评估的优化建议>

利用模型评估功能进行数据增强

利用模型评估功能进行数据增强

分享
更新时间:2021/06/04 GMT+08:00

场景描述

数据集是训练中最重要的一个环节,modelarts平台虽然给出了每类5张图片就能训练的限制,但是这种限制对一个工业级的应用场景往往是远远不够的。这里介绍其中一种带标签扩充数据集的方法。

原理说明

数据集情况

首先,这是一个分类的问题,需要检测出工业零件表面的瑕疵,判断是否为残次品,如下是样例图片。

图1 样例图片

这是两块太阳能电板的表面,左侧是正常的,右侧是有残缺和残次现象的,需要用一个模型来区分这两类的图片,帮助定位哪些太阳能电板存在问题。左侧的正常样本754张,右侧的残次样本358张,验证集同样,正常样本754张,残次样本357张。总样本在2000张左右,对于一般工业要求的95%以上准确率模型而言属于一个非常小的样本。先直接拿这个数据集用Pytorch加载imagenet的resnet50模型训练,整体精度ACC在86.06%左右,召回率正常类为97.3%,但非正常类为62.9%,还不能达到用户预期。

小样本学习few-shot fewshot learning (FSFSL)的常见方法,基本都是从两个方向入手。一是数据本身,二是从模型训练本身,也就是对图像提取的特征做文章。这里从数据本身入手,首先观察数据集,都是300*300的灰度图像,而且都已太阳能电板表面的正面俯视为整张图片。这属于预先处理的很好的图片。那么针对这种图片,翻转镜像对图片整体结构影响不大,所以我们首先可以做的就是flip操作,增加数据的多样性。flip效果如下。

图2 flip效果

这样数据集就从1100张扩增到了2200张,还是不是很多,但是直接观察数据集已经没什么太好的扩充办法了。这时想使用Modelarts模型评估功能来评估模型对数据的泛化能力。这里调用了提供的模型评估接口,deep_moxing.model_analysis下面的analyse接口。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def validate(val_loader, model, criterion, args):
batch_time = AverageMeter('Time', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
progress = ProgressMeter(
len(val_loader),
[batch_time, losses, top1, top5],
prefix='Test: ')
pred_list = []
target_list = []
# switch to evaluate mode
model.eval()
with torch.no_grad():
end = time.time()
for i, (images, target) in enumerate(val_loader):
if args.gpu is not None:
images = images.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True)

# compute output
output = model(images)
loss = criterion(output, target)
# 获取logits输出结果pred和实际目标的结果target
pred_list += output.cpu().numpy()[:, :2].tolist()
target_list += target.cpu().numpy().tolist()
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5), i=i)
losses.update(loss.item(), images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if i % args.print_freq == 0:
progress.display(i)
# TODO: this should also be done with the ProgressMeter
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
# 获取图片的存储路径name
name_list = val_loader.dataset.samples
for idx in range(len(name_list)):
name_list[idx] = name_list[idx][0]
analyse(task_type='image_classification', save_path='/home/image_labeled/',
pred_list=pred_list, label_list=target_list, name_list=name_list)
return top1.avg

上段代码大部分都是Pytorch训练ImageNet中的验证部分代码,需要获取三个list,模型pred直接结果logits、图片实际类别target和图片存储路径name。然后按如上的调用方法调用analyse接口,会在save_path的目录下生成一个json文件,放到Modelarts训练输出目录里,就能在评估结果里看到对模型的分析结果。这里是线下生成的json文件再上传到线上看可视化结果。关于敏感度分析结果如表1表2所示。

表1 图像亮度敏感度分析

特征值分布

0

1

0% - 20%

0.7273

0.8864

20% - 40%

0.8446

0.6892

40% - 60%

0.9077

0.4615

60% - 80%

0.9496

0.5116

80% - 100%

0.966

0.5625

标准差

0.0864

0.1516

表2 图像亮度敏感度分析

特征值分布

0

1

0% - 20%

0.7556

0.8333

20% - 40%

0.8489

0.6466

40% - 60%

0.9239

0.6316

60% - 80%

0.9492

0.8

80% - 100%

0.9631

0.5946

标准差

0.0771

0.0963

上述两个表的意思是,不同的特征值范围图片分别测试的精度是多少。比如亮度敏感度分析的第一项0%-20%,可以理解为,在图片亮度较低的场景下对与0类和其他亮度条件的图片相比,精度要低很多。整体来看,主要是为了检测1类,1类在图片的亮度和清晰度两项上显得都很敏感,也就是模型不能很好地处理图片的这两项特征变化的图片。那这不就是我要扩增数据集的方向吗?

同时,ModelArts平台还提供了使用“数据处理>数据扩增”功能,可以直接扩充数据集。

那么我们就得到一个正常类2210张,瑕疵类1174张图片的数据集,用同样的策略扔进pytorch中训练,得到的结果。

表3 数据扩增后的结果

方法

accuracy

recall norm类

recall abnorm类

原版

86.06%

97.3%

62.9%

从1100张扩增到2940张

86.31%

97.6%

62.5%

从上述结果中,发现精度并没有明显提升。重新分析一下数据集,这种工业类的数据集往往都存在一个样本不均匀的问题,这里虽然接近2:1,但是检测的要求针对有瑕疵的类别的比较高,应该让模型倾向于有瑕疵类去学习,而且看到1类的也就是有瑕疵类的结果比较敏感,所以其实还是存在样本不均衡的情况。由此后面的这两种增强方法只针对了1类也就是有问题的破损类做,最终得到3000张左右,1508张正常类图片,1432张有瑕疵类图片,这样样本就相对平衡了。用同样的策略扔进resnet50中训练。最终得到的精度信息。

表4 修改扩增数据后的结果

方法

accuracy

recall norm类

recall abnorm类

原版

86.06%

97.3%

62.9%

从1100张扩增到2940张

89.13%

97.2%

71.3%

总结

可以看到,同样在验证集,正常样本754张,残次样本357张的样本上,Acc1的精度整体提升了接近3%,重要指标残次类的recall提升了8.4%。所以直接扩充数据集的方法很有效,而且结合模型评估能让您参考哪些扩增的方法是有意义的。当然还有很重要的一点,要排除原始数据集存在的问题,比如这里存在的样本不均衡问题,具体情况具体分析,这个扩增的方法就会变得简单实用。

之后基于这个实验的结果和数据集。帮助用户改了一些训练策略,换了个更厉害的网络,就达到了用户的要求,当然这都是定制化分析的结果,这里不详细展开说明了。

数据集引自:

Buerhop-Lutz, C.; Deitsch, S.; Maier, A.; Gallwitz, F.; Berger, S.; Doll, B.; Hauch, J.; Camus, C. & Brabec, C. J. A Benchmark for Visual Identification of Defective Solar Cells in Electroluminescence Imagery. European PV Solar Energy Conference and Exhibition (EU PVSEC), 2018. DOI: 10.4229/35thEUPVSEC20182018-5CV.3.15

Deitsch, S.; Buerhop-Lutz, C.; Maier, A. K.; Gallwitz, F. & Riess, C. Segmentation of Photovoltaic Module Cells in Electroluminescence Images. CoRR, 2018, abs/1806.06530

Deitsch, S.; Christlein, V.; Berger, S.; Buerhop-Lutz, C.; Maier, A.; Gallwitz, F. & Riess, C. Automatic classification of defective photovoltaic module cells in electroluminescence images. Solar Energy, Elsevier BV, 2019, 185, 455-468. DOI: 10.1016/j.solener.2019.02.067

分享:

    相关文档

    相关产品