高可靠性前置准备:断点续训练
通过本章,您将了解到断点续训练的相关原理、配置方法、配置建议和注意事项。
由于内容较多,为便于您找到关心的内容,本节开头先做一个简单的内容分类。请参照如下分类,阅读您关心的内容。
- 了解断点续训的概念及为什么需要断点续训,请您参考什么是断点续训练和为什么需要配置断点续训练。
- 了解如何配置断点续训,请您参考如何配置CheckPoint。
- 了解配置CheckPoint需要的注意点,请您参考分布式训练CheckPoint注意事项。
- 本节为您提供了配置断点的一些建议,请您参考CheckPoint保存频率建议。
什么是断点续训练
断点续训练是指因为某些原因(例如容错重启、资源抢占、作业卡死等)导致训练作业还未完成就被中断,下一次训练可以在上一次的训练基础上继续进行。这种方式对于需要长时间训练的模型而言比较友好。
断点续训练是通过CheckPoint机制实现。
CheckPoint的机制是:在模型训练的过程中,不断地保存训练结果(包括但不限于EPOCH、模型权重、优化器状态、调度器状态)。即便模型训练中断,也可以基于CheckPoint继续训练。
当需要从训练中断的位置继续训练,只需要加载CheckPoint,并用CheckPoint信息初始化训练状态即可。用户需要在代码里加上reload ckpt的代码,使能读取前一次训练保存的预训练模型。具体参见常用训练框架reload ckpt配置示例修改对应框架ckpt配置。
为什么需要配置断点续训练
ModelArts提供的自动重启功能只能保证作业失败后重新拉起,不能保证训练进度不丢失。如果训练脚本不支持断点续训练,作业重启后可能会从头开始训练,导致如下问题:
- 已完成训练进度丢失。
- 算力资源重复消耗。
- 训练结果不符合预期。
- 多次自动重启后仍然无法恢复到故障前状态。
因此,对于耗时较长的训练任务,应优先完成CheckPoint适配。
如何配置CheckPoint
ModelArts提供了训练作业CheckPoint存储路径及自动重启机制。训练作业断点续训需要您完成如下配置才能实现。
- 配置CheckPoint保存内容。根据训练作业配置CheckPoint要保存的内容。
- 配置CheckPoint保存路径。在ModelArts上选择外置存储路径,保证CheckPoint产物能够安全持久存放,快速恢复。
- 常用训练框架reload ckpt配置示例。根据不同训练框架,在训练代码中配置reload CheckPoint配置。
配置CheckPoint保存内容
在模型训练作业执行过程中,训练作业会产生许多中间态产物,为保证断点续训能够顺利恢复作业现场,建议CheckPoint至少保存如表1所示内容。
本节给出相关配置项保存内容并非所有断点续训都需配置内容,请根据实际训练作业酌情参考。
配置CheckPoint保存路径
ModalArts提供了多种存储挂载路径方案,保证训练作业CheckPoint能安全稳定存储,即使训练作业重启后仍能够方便快速恢复训练。
断点续训练建议和训练容错检查(即自动重启)功能同时使用。在创建训练作业页面,开启“自动重启”开关。训练环境预检测失败、或者训练容器硬件检测故障、或者训练作业失败时会自动重新下发并运行训练作业。
建议将CheckPoint保存在稳定的共享存储或训练输出路径中,避免仅保存在容器本地目录。
为改善用户体验,ModelArts在部分站点提供了新版的控制台界面。以下分别介绍ModelArts新版控制台和旧版控制台如何设置CheckPoint存储路径。
新版控制台
在ModelArts训练中实现断点续训练或增量训练,建议使用存储挂载功能。
在创建训练作业时,通过挂载存储路径来保存和读取Checkpoint文件。具体操作如下:
- 在训练作业中配置存储路径挂载,将存储Checkpoint的目录挂载到训练容器的本地目录。
- 训练过程中,将Checkpoint文件保存到挂载的本地目录中,数据会自动同步到挂载的存储位置。
- 对于断点续训练,确保挂载的存储目录中包含之前的Checkpoint文件,训练脚本会自动加载最新的Checkpoint继续训练。
通过存储挂载方式,可以实现训练数据的持久化存储和跨作业的模型复用。
在ModelArts中创建训练作业时,可以选择以下两种存储挂载选项。以下是它们的详细对比,帮助您根据需求选择合适的存储方案。
旧版控制台
在ModelArts训练中实现断点续训练或增量训练,建议使用“训练输出”功能。
在创建训练作业时,设置训练“输出”参数名称为“train_output”,用户可通过环境变量或超参方式获取该参数。设置成功后可在指定的训练输出的数据存储位置中保存Checkpoint,且“预下载至本地目录”选择“下载”。选择预下载至本地目录时,系统在训练作业启动前,自动将数据存储位置中的Checkpoint文件下载到训练容器的本地目录。
常用训练框架reload ckpt配置示例
以下示例仅展示核心逻辑,生产环境建议增加异常处理、CheckPoint完整性校验和分布式同步。示例中设置CheckPoint保存地址变量为train_output。
| 训练框架 | reload ckpt配置示例 |
|---|---|
| VeRL | VeRL是一个灵活、高效且被广泛使用的强化学习(RL)训练库,后训练的事实标准框架。VeRL是论文HybridFlow: A Flexible and Efficient RLHF Framework 的开源实现。
|
| MindSpeed-LLM | MindSpeed LLM是基于昇腾生态的大语言模型分布式训练框架,旨在为华为昇腾芯片生态合作伙伴提供端到端的大语言模型训练方案,包含分布式预训练、分布式指令微调以及对应的开发工具链,如:数据预处理、权重转换、在线推理、基线评估。作为昇腾计算主打的训练框架,在性能上做了极致的优化,特别在大参数、大集群和MOE类型模型的训练性能突出,且兼容Megatron-LM框架,对于Megatron客户可以平滑迁移。
|
| LLaMa-Factory | LLaMa-Factory是开源社区中一个非常活跃的大模型训练框架,它的主打特点是简单易用,通过命令行或者WebUI界面可以轻松微调数百种大模型,包括大语言和多模态模型。LLaMa-Factory的底层是基于Transformers+DeepSpeed构建,对开源模型就有非常好的兼容性。
|
| Pytorch |
|
| MindSpore | import os
import argparse
from resnet import resnet50
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
from mindspore.train.callback import LossMonitor
parser = argparse.ArgumentParser()
parser.add_argument("--train_output", type=str)
parser.add_argument("--batch_size", type=int, default=32, help="Batch size.")
parser.add_argument("--num_classes", type=int, default=10, help="Num classes.")
parser.add_argument("--do_train", type=bool, default=True, help="Do train or not.")
args_opt, unparsed = parser.parse_known_args()
# train_output 将被赋值为"/home/ma-user/modelarts/outputs/train_output_0" 。
train_output = args_opt.train_output
# 初始定义的网络、损失函数及优化器,详细请参见MindSpore保存与加载。
# 1.初始定义的网络,以“ResNet50”为例。详细请参见ResNet50。
net = resnet50(args_opt.batch_size, args_opt.num_classes)
# 2.定义损失函数,详细请参见MindSpore自定义损失函数。
ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# 3.定义优化器,详细请参见MindSpore自定义优化器。
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
# 首次训练的epoch初始值,mindspore1.3及以后版本会支持定义epoch_size初始值。
# cur_epoch_num = 0
# 判断输出obs路径中是否有模型文件。如果无文件则默认从头训练,如果有模型文件,则加载epoch值最大的ckpt文件当做预训练模型。
if os.listdir(train_output):
last_ckpt = sorted([file for file in os.listdir(train_output) if file.endswith(".ckpt")])[-1]
print('last_ckpt:', last_ckpt)
last_ckpt_file = os.path.join(train_output, last_ckpt)
# 加载断点,详细请参见mindspore.load_checkpoint。
param_dict = load_checkpoint(last_ckpt_file)
print('> load last ckpt and continue training!!')
# 加载模型参数到net。
load_param_into_net(net, param_dict)
# 加载模型参数到opt。
load_param_into_net(opt, param_dict)
# 获取保存的epoch值,模型会在此epoch的基础上继续训练,此参数在mindspore1.3及以后版本会支持。
# if param_dict.get("epoch_num"):
# cur_epoch_num = int(param_dict["epoch_num"].data.asnumpy())
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
# as for train, users could use model.train
if args_opt.do_train:
dataset = create_dataset()
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num,
keep_checkpoint_max=35)
# append_info=[{"epoch_num": cur_epoch_num}],mindspore1.3及以后版本会支持append_info参数,保存当前时刻的epoch值。
# 保存网络参数,详细请参见mindspore.train.ModelCheckpoint。
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10",
directory=args_opt.train_output,
config=config_ck)
loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
# model.train(epoch_size-cur_epoch_num, dataset, callbacks=[ckpoint_cb, loss_cb]),mindspore1.3及以后版本支持从断点恢复训练。 |