AI开发平台ModelArtsAI开发平台ModelArts

更新时间:2021/08/06 GMT+08:00
分享

模型优化中常见优化模型精度的方法

前言

在很多深度学习的比赛项目中,各种方法trick层出不穷,其中有一种颇受争议的方法就是在测试时使用增强的手段,将输入的源图片生成多份分别送入模型,然后对所有的推理结果做一个综合整合。这种方法被称为测试时增强(test time augmentation,TTA),本章节介绍测试时增强的原理及建议。

原理说明

  • TTA流程

    TTA的基本流程是通过对原图做增强操作,获得很多份增强后的样本与原图组成一个数据组,然后用这些样本获取推理结果,最后把多份的推理结果按一定方法合成得到最后的推理结果再进行精度指标计算。

    图1 TTA流程图

    这么看上去需要确认很多问题:

    1. 原图片需要用什么增强方法来生成新的样本。
    2. 生成的样本在获取推理结果之后应该使用什么样的方法进行合成。

    我们举个简单的例子来说明TTA的作用以及如何利用ModelArts平台提供的功能来使用TTA。

  • TTA使用实例
    • 数据集:数据集样例图片如下所示。其中左侧为正常样本图像,共754张,右侧为有瑕疵的电板图像,共358张,经过一定的增强手段后扩充至1508张正常类图片,1432张有瑕疵类图片。
      图2 数据集样例
    • 使用框架及算法:pytorch官方提供的训练imagenet开源代码
    • 训练策略:50个epoch,初始学习率lr0.001,batchsize16用Adam的优化器训练。
    原模型精度信息

    精度信息

    正常类

    有瑕疵非正常类

    召回率recall

    97.2%

    71.3%

    准确率accuracy

    89.13%

  • TTA过程
    1. 首先,需选定使用的增强方法来获取多样本。这里有两种方法:
      1. 从训练中使用的增强手段入手,用训练中使用的增强手段获取多样本。

        如pytorch训练imagenet的代码中,使用了算子transforms.RandomHorizontalFlip()做垂直方向的翻转操作。那么对于模型而言,应该也见过很多经过垂直翻转的图片,所以我们可以用垂直方向的翻转来作为增强手段的一种。

      2. 进行模型评估,从模型评估的结果中分析该使用什么样的增强方法。

        对原模型进行评估,评估代码如下,这里是修改了开源代码中validate部分做前向部分推理的代码:

        with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
        if args.gpu is not None:
        images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)
        
        # compute output
        output_origin = model(images)
        output = output_origin
        loss = criterion(output, target)
        pred_list += output.cpu().numpy()[:, :2].tolist()
        target_list += target.cpu().numpy().tolist()
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5), i=i)
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))
        
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        
        if i % args.print_freq == 0:
        progress.display(i)
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
        .format(top1=top1, top5=top5))
        name_list = val_loader.dataset.samples
        for idx in range(len(name_list)):
        name_list[idx] = name_list[idx][0]
        analyse(task_type='image_classification', save_path='./', pred_list=pred_list, label_list=target_list, name_list=name_list)

        评估就是需要获得三个list,推理的直接的结果logits组合成的pred_list,存储的是每一张图片直接的预测结果,如[[8.725419998168945, 21.92235565185547]...[xxx, xxx]]。一个真实的label值组成的target_list,存储的是每一张图片的真实标签,如[0, 1, 0, 1, 1..., 1, 0]。还有原图像文件存储的路径组合成的name_list,如[xxx.jpg, ... xxx.jpg],这里是从pytorch度数据模块的类中通过val_loader.dataset.samples获取到后重新组合的。然后调用deep_moxing库中的analyse接口,在save_path下会生成一个model_analysis_results.json的文件,将这个文件上传到页面上任意一个训练任务的输出目录下,就能在页面的评估界面上看到对模型评估的结果。

        图3 评估结果

        这结果中需要分析模型的敏感度。

        表1 图像清晰度敏感度分析

        特征值分布

        0

        1

        0% - 20%

        0.7929

        0.8727

        20% - 40%

        0.8816

        0.7429

        40% - 60%

        0.9363

        0.7229

        60% - 80%

        0.9462

        0.7912

        80% - 100%

        0.9751

        0.7619

        标准差

        0.0643

        0.0523

        上述结果中能看到,0类(正常类)随着图像清晰度的增大F1-score会提升,也就是说,模型在清晰的图片上,对正常类的检测表现更好,而在1类(瑕疵类)随着图像清晰读增大精度会下降,说明对模型而言,模糊的图片能让它检测有瑕疵类更加准确。由于该模型侧重于对瑕疵类的鉴别,所以可以使用图像模糊的手段作为TTA的增强方法。

    2. 接下来可以看看在pytorch中,如何加入TTA。

      PyTorch的好处在于,可以直接获取到输入模型前的tensor并进行想要的操作。如在eval中。

      1
      2
      3
      4
      5
      with torch.no_grad():         
          end = time.time()         
          for i, (images, target) in enumerate(val_loader):             
              if args.gpu is not None:                 
              images = images.cuda(args.gpu, non_blocking=True)
      

      这里拿到的images就是已经做好前处理的一个batch的图片数据。由步骤1中确定了两种增强方法,竖直方向的翻转和模糊。

      PyTorch中的翻转,在版本大于0.4.0时,可以使用如下代码:

      1
      2
      3
      4
      def flip(x, dim):     
          indices = [slice(None)] * x.dim()     
          indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)    
          return x[tuple(indices)]
      

      dim为模式,这里使用2为竖直方向的翻转,3为水平方向,1为做通道翻转。使用img_flip = flip(images, 2)就能得到竖直方向翻转的图片。

      模糊稍多一些操作,可以利用cv2中自带的blur操作。

      1
      2
      3
      img = images.numpy() 
      img[0] = cv2.blur(img[0], (3, 3)) 
      images_blur = torch.from_numpy(img.copy())
      
    3. 结果合成。

      现在得到了三个输出,原图的推理结果origin_result、竖直方向翻转后得到的结果flip_output、模糊后得到的blur_output。

      那么该如何合成呢?

      先看flip_output,一个想法是,原训练中见过的做过翻转的图片所占的比例是多少,在最终的输出一张做过翻转的图片对结果的贡献权重就是多少。那么相信很多有深度学习经验的同学们知道,一般模型做FLIP的概率为0.5,也就是模型见过的做过翻转的图片,大致比例上为0.5,那么flip的结果最终结果的贡献就也是0.5,可得:

      logits = 0.5*origin_result + 0.5*flip_result

      此时,模型的精度结果为:

      表2 模型精度结果

      操作

      acc

      norm类recall

      abnorm类recall

      原版

      89.13%

      97.2%

      71.3%

      flip结果合成

      87.74%

      93.7%

      72.7%

      可以看到,虽然损失了norm类的精度,但是相对而言更重要的指标abnorm类的recall有提升。

      然后分析blur_output,可以看到,位于最低的0-20%时,瑕疵类的精度是最高的,但是norm类的精度掉的太多,而且模糊本身就是提升abnorm类精度的,所以我们做一个折中,同样取blur图片的贡献值为0.5,可得公式:

      logit = 0.5*origin_result + 0.5*blur_output

      此时,模型的精度结果为:

      表3 模型精度结果

      操作

      acc

      norm类recall

      abnorm类recall

      原版

      89.13%

      97.2%

      71.3%

      blur结果合成

      88.117%

      94.8%

      73.3%

      可以看到,norm类的精度下降较多,abnorm类增长明显,与模型评估的分析结果一致。

      综上,我们调整的结果虽然对norm类的损失较多导致整体精度下降,但是这是符合模型分析的结果的,我们需要的指标就是abnorm类recall的提升,而且可以看到,模型评估的结果要稍好于使用原版增强的合成结果。

总结

这里实验了两种使用test time augmatation的方法,一种是根据训练过程自带的增强方法来选择测试前增强,另一种是通过对模型进行敏感度分析,分析图片什么样的特征范围对于模型的判别最有帮助。当然,这里很重要的一点:TTA会增加模型推理的时间,对推理时延要求很高的人工智能算法应用请仔细抉择选择合适的解决方法。

分享:

    相关文档

    相关产品