计算
弹性云服务器 ECS
Flexus云服务
裸金属服务器 BMS
弹性伸缩 AS
镜像服务 IMS
专属主机 DeH
函数工作流 FunctionGraph
云手机服务器 CPH
Huawei Cloud EulerOS
网络
虚拟私有云 VPC
弹性公网IP EIP
虚拟专用网络 VPN
弹性负载均衡 ELB
NAT网关 NAT
云专线 DC
VPC终端节点 VPCEP
云连接 CC
企业路由器 ER
企业交换机 ESW
全球加速 GA
安全与合规
安全技术与应用
Web应用防火墙 WAF
企业主机安全 HSS
云防火墙 CFW
安全云脑 SecMaster
DDoS防护 AAD
数据加密服务 DEW
数据库安全服务 DBSS
云堡垒机 CBH
数据安全中心 DSC
云证书管理服务 CCM
边缘安全 EdgeSec
威胁检测服务 MTD
CDN与智能边缘
内容分发网络 CDN
CloudPond云服务
智能边缘云 IEC
迁移
主机迁移服务 SMS
对象存储迁移服务 OMS
云数据迁移 CDM
迁移中心 MGC
大数据
MapReduce服务 MRS
数据湖探索 DLI
表格存储服务 CloudTable
云搜索服务 CSS
数据接入服务 DIS
数据仓库服务 GaussDB(DWS)
数据治理中心 DataArts Studio
数据可视化 DLV
数据湖工厂 DLF
湖仓构建 LakeFormation
企业应用
云桌面 Workspace
应用与数据集成平台 ROMA Connect
云解析服务 DNS
专属云
专属计算集群 DCC
IoT物联网
IoT物联网
设备接入 IoTDA
智能边缘平台 IEF
用户服务
账号中心
费用中心
成本中心
资源中心
企业管理
工单管理
国际站常见问题
ICP备案
我的凭证
支持计划
客户运营能力
合作伙伴支持计划
专业服务
区块链
区块链服务 BCS
Web3节点引擎服务 NES
解决方案
SAP
高性能计算 HPC
视频
视频直播 Live
视频点播 VOD
媒体处理 MPC
实时音视频 SparkRTC
数字内容生产线 MetaStudio
存储
对象存储服务 OBS
云硬盘 EVS
云备份 CBR
存储容灾服务 SDRS
高性能弹性文件服务 SFS Turbo
弹性文件服务 SFS
云硬盘备份 VBS
云服务器备份 CSBS
数据快递服务 DES
专属分布式存储服务 DSS
容器
云容器引擎 CCE
容器镜像服务 SWR
应用服务网格 ASM
华为云UCS
云容器实例 CCI
管理与监管
云监控服务 CES
统一身份认证服务 IAM
资源编排服务 RFS
云审计服务 CTS
标签管理服务 TMS
云日志服务 LTS
配置审计 Config
资源访问管理 RAM
消息通知服务 SMN
应用运维管理 AOM
应用性能管理 APM
组织 Organizations
优化顾问 OA
IAM 身份中心
云运维中心 COC
资源治理中心 RGC
应用身份管理服务 OneAccess
数据库
云数据库 RDS
文档数据库服务 DDS
数据管理服务 DAS
数据复制服务 DRS
云数据库 GeminiDB
云数据库 GaussDB
分布式数据库中间件 DDM
数据库和应用迁移 UGO
云数据库 TaurusDB
人工智能
人脸识别服务 FRS
图引擎服务 GES
图像识别 Image
内容审核 Moderation
文字识别 OCR
AI开发平台ModelArts
图像搜索 ImageSearch
对话机器人服务 CBS
华为HiLens
视频智能分析服务 VIAS
语音交互服务 SIS
应用中间件
分布式缓存服务 DCS
API网关 APIG
微服务引擎 CSE
分布式消息服务Kafka版
分布式消息服务RabbitMQ版
分布式消息服务RocketMQ版
多活高可用服务 MAS
事件网格 EG
企业协同
华为云会议 Meeting
云通信
消息&短信 MSGSMS
云生态
合作伙伴中心
云商店
开发者工具
SDK开发指南
API签名指南
Terraform
华为云命令行工具服务 KooCLI
其他
产品价格详情
系统权限
管理控制台
客户关联华为云合作伙伴须知
消息中心
公共问题
开发与运维
应用管理与运维平台 ServiceStage
软件开发生产线 CodeArts
需求管理 CodeArts Req
部署 CodeArts Deploy
性能测试 CodeArts PerfTest
编译构建 CodeArts Build
流水线 CodeArts Pipeline
制品仓库 CodeArts Artifact
测试计划 CodeArts TestPlan
代码检查 CodeArts Check
代码托管 CodeArts Repo
云应用引擎 CAE
开天aPaaS
云消息服务 KooMessage
云手机服务 KooPhone
云空间服务 KooDrive

查看模型评估结果

更新时间:2024-10-24 GMT+08:00

训练作业运行结束后,ModelArts可为您的模型进行评估,并且给出调优诊断和建议。

  • 针对使用预置算法创建训练作业,无需任何配置,即可查看此评估结果(由于每个模型情况不同,系统将自动根据您的模型指标情况,给出一些调优建议,请仔细阅读界面中的建议和指导,对您的模型进行进一步的调优)。
  • 针对用户自己编写训练脚本或自定义镜像方式创建的训练作业,则需要在您的训练代码中添加评估代码,才可以在训练作业结束后查看相应的评估诊断建议。
    说明:
    • 只支持验证集的数据格式为图片。
    • 目前,仅如下常用框架的训练脚本支持添加评估代码。
      • TF-1.13.1-python3.6
      • TF-2.1.0-python3.6
      • PyTorch-1.4.0-python3.6

下文将介绍如何在训练中使用评估代码。对训练代码做一定的适配和修正,分为三个方面:添加输出目录复制数据集到本地映射数据集路径到OBS

添加输出目录

添加输出目录的代码比较简单,即在代码中添加一个输出评估结果文件的目录,被称为train_url,也就是页面上的训练输出位置。并把train_url添加到使用的函数analysis中,使用save_path来获取train_url。示例代码如下所示:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_url', '', 'path to saved model')
tf.app.flags.DEFINE_string('data_url', '', 'path to output files')
tf.app.flags.DEFINE_string('train_url', '', 'path to output files')
tf.app.flags.DEFINE_string('adv_param_json',
                           '{"attack_method":"FGSM","eps":40}',
                           'params for adversarial attacks')
FLAGS(sys.argv, known_only=True)

...

# analyse
res = analyse(
    task_type=task_type,
    pred_list=pred_list,
    label_list=label_list,
    name_list=file_name_list,
    label_map_dict=label_dict,
    save_path=FLAGS.train_url)

复制数据集到本地

复制数据集到本地主要是为了防止长时间访问OBS容易导致OBS连接中断使得作业卡住,所以一般先将数据复制到本地再进行操作。

数据集复制有两种方式,推荐使用OBS路径复制。

  • OBS路径(推荐)

    直接使用moxing的copy_parallel接口,复制对应的OBS路径。

  • ModelArts数据管理中的数据集(即manifest文件格式)

    使用moxing的copy_manifest接口将文件复制到本地并获取新的manifest文件路径,然后使用SDK解析新的manifest文件。

说明:

ModelArts数据管理模块在重构升级中,对未使用过数据管理的用户不可见。建议新用户将训练数据存放至OBS桶中使用。

1
2
3
4
5
6
7
8
if data_path.startswith('obs://'):
    if '.manifest' in data_path:
        new_manifest_path, _ = mox.file.copy_manifest(data_path, '/cache/data/')
        data_path = new_manifest_path
    else:
        mox.file.copy_parallel(data_path, '/cache/data/')
        data_path = '/cache/data/'
    print('------------- download dataset success ------------')

映射数据集路径到OBS

由于最终JSON体中需要填写的是图片文件的真实路径,也就是OBS对应的路径,所以在复制到本地做完分析和评估操作后,需要将原来的本地数据集路径映射到OBS路径,然后将新的list送入analysis接口。

如果使用的是OBS路径作为输入的data_url,则只需要替换本地路径的字符串即可。

1
2
3
if FLAGS.data_url.startswith('obs://'):
    for idx, item in enumerate(file_name_list):
        file_name_list[idx] = item.replace(data_path, FLAGS.data_url)

如果使用manifest文件,需要再解析一遍原版的manifest文件获取list,然后再送入analysis接口。

1
2
3
4
5
6
7
8
if or FLAGS.data_url.startswith('obs://'):
    if 'manifest' in FLAGS.data_url:
            file_name_list = []
            manifest, _ = get_sample_list(
                manifest_path=FLAGS.data_url, task_type='image_classification')
            for item in manifest:
                if len(item[1]) != 0:
                    file_name_list.append(item[0])

完整的适配了训练作业创建的图像分类样例代码如下:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import json
import logging
import os
import sys
import tempfile

import h5py
import numpy as np
from PIL import Image

import moxing as mox
import tensorflow as tf
from deep_moxing.framework.manifest_api.manifest_api import get_sample_list
from deep_moxing.model_analysis.api import analyse, tmp_save
from deep_moxing.model_analysis.common.constant import TMP_FILE_NAME

logging.basicConfig(level=logging.DEBUG)

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_url', '', 'path to saved model')
tf.app.flags.DEFINE_string('data_url', '', 'path to output files')
tf.app.flags.DEFINE_string('train_url', '', 'path to output files')
tf.app.flags.DEFINE_string('adv_param_json',
                           '{"attack_method":"FGSM","eps":40}',
                           'params for adversarial attacks')
FLAGS(sys.argv, known_only=True)


def _preprocess(data_path):
    img = Image.open(data_path)
    img = img.convert('RGB')
    img = np.asarray(img, dtype=np.float32)
    img = img[np.newaxis, :, :, :]
    return img


def softmax(x):
    x = np.array(x)
    orig_shape = x.shape
    if len(x.shape) > 1:
        # Matrix
        x = np.apply_along_axis(lambda x: np.exp(x - np.max(x)), 1, x)
        denominator = np.apply_along_axis(lambda x: 1.0 / np.sum(x), 1, x)
        if len(denominator.shape) == 1:
            denominator = denominator.reshape((denominator.shape[0], 1))
        x = x * denominator
    else:
        # Vector
        x_max = np.max(x)
        x = x - x_max
        numerator = np.exp(x)
        denominator = 1.0 / np.sum(numerator)
        x = numerator.dot(denominator)
    assert x.shape == orig_shape
    return x


def get_dataset(data_path, label_map_dict):
    label_list = []
    img_name_list = []
    if 'manifest' in data_path:
        manifest, _ = get_sample_list(
            manifest_path=data_path, task_type='image_classification')
        for item in manifest:
            if len(item[1]) != 0:
                label_list.append(label_map_dict.get(item[1][0]))
                img_name_list.append(item[0])
            else:
                continue
    else:
        label_name_list = os.listdir(data_path)
        label_dict = {}
        for idx, item in enumerate(label_name_list):
            label_dict[str(idx)] = item
            sub_img_list = os.listdir(os.path.join(data_path, item))
            img_name_list += [
                os.path.join(data_path, item, img_name) for img_name in sub_img_list
            ]
            label_list += [label_map_dict.get(item)] * len(sub_img_list)
    return img_name_list, label_list


def deal_ckpt_and_data_with_obs():
    pb_dir = FLAGS.model_url
    data_path = FLAGS.data_url

    if pb_dir.startswith('obs://'):
        mox.file.copy_parallel(pb_dir, '/cache/ckpt/')
        pb_dir = '/cache/ckpt'
        print('------------- download success ------------')
    if data_path.startswith('obs://'):
        if '.manifest' in data_path:
            new_manifest_path, _ = mox.file.copy_manifest(data_path, '/cache/data/')
            data_path = new_manifest_path
        else:
            mox.file.copy_parallel(data_path, '/cache/data/')
            data_path = '/cache/data/'
        print('------------- download dataset success ------------')
    assert os.path.isdir(pb_dir), 'Error, pb_dir must be a directory'
    return pb_dir, data_path


def evalution():
    pb_dir, data_path = deal_ckpt_and_data_with_obs()
    index_file = os.path.join(pb_dir, 'index')
    try:
        label_file = h5py.File(index_file, 'r')
        label_array = label_file['labels_list'][:].tolist()
        label_array = [item.decode('utf-8') for item in label_array]
    except Exception as e:
        logging.warning(e)
        logging.warning('index file is not a h5 file, try json.')
        with open(index_file, 'r') as load_f:
            label_file = json.load(load_f)
        label_array = label_file['labels_list'][:]
    label_map_dict = {}
    label_dict = {}
    for idx, item in enumerate(label_array):
        label_map_dict[item] = idx
        label_dict[idx] = item
    print(label_map_dict)
    print(label_dict)

    data_file_list, label_list = get_dataset(data_path, label_map_dict)

    assert len(label_list) > 0, 'missing valid data'
    assert None not in label_list, 'dataset and model not match'

    pred_list = []
    file_name_list = []
    img_list = []

    for img_path in data_file_list:
        img = _preprocess(img_path)
        img_list.append(img)
        file_name_list.append(img_path)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = '0'
    with tf.Session(graph=tf.Graph(), config=config) as sess:
        meta_graph_def = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], pb_dir)
        signature = meta_graph_def.signature_def
        signature_key = 'predict_object'
        input_key = 'images'
        output_key = 'logits'
        x_tensor_name = signature[signature_key].inputs[input_key].name
        y_tensor_name = signature[signature_key].outputs[output_key].name
        x = sess.graph.get_tensor_by_name(x_tensor_name)
        y = sess.graph.get_tensor_by_name(y_tensor_name)
        for img in img_list:
            pred_output = sess.run([y], {x: img})
            pred_output = softmax(pred_output[0])
            pred_list.append(pred_output[0].tolist())

    label_dict = json.dumps(label_dict)
    task_type = 'image_classification'

    if FLAGS.data_url.startswith('obs://'):
        if 'manifest' in FLAGS.data_url:
            file_name_list = []
            manifest, _ = get_sample_list(
                manifest_path=FLAGS.data_url, task_type='image_classification')
            for item in manifest:
                if len(item[1]) != 0:
                    file_name_list.append(item[0])
        for idx, item in enumerate(file_name_list):
            file_name_list[idx] = item.replace(data_path, FLAGS.data_url)
    # analyse
    res = analyse(
        task_type=task_type,
        pred_list=pred_list,
        label_list=label_list,
        name_list=file_name_list,
        label_map_dict=label_dict,
        save_path=FLAGS.train_url)

if __name__ == "__main__":
    evalution()

我们使用cookie来确保您的高速浏览体验。继续浏览本站,即表示您同意我们使用cookie。 详情

文档反馈

文档反馈

意见反馈

0/500

标记内容

同时提交标记内容