更新时间:2026-02-06 GMT+08:00
分享

设置断点续训练

什么是断点续训练

断点续训练是指因为某些原因(例如容错重启、资源抢占、作业卡死等)导致训练作业还未完成就被中断,下一次训练可以在上一次的训练基础上继续进行。这种方式对于需要长时间训练的模型而言比较友好

断点续训练是通过checkpoint机制实现。

checkpoint的机制是:在模型训练的过程中,不断地保存训练结果(包括但不限于EPOCH、模型权重、优化器状态、调度器状态)。即便模型训练中断,也可以基于checkpoint接续训练。

当需要从训练中断的位置接续训练,只需要加载checkpoint,并用checkpoint信息初始化训练状态即可。用户需要在代码里加上reload ckpt的代码,使能读取前一次训练保存的预训练模型。

ModelArts中设置训练输出实现断点续训练

新版:

在ModelArts训练中实现断点续训练或增量训练,建议使用存储挂载功能。

在创建训练作业时,通过挂载存储路径来保存和读取Checkpoint文件。具体操作如下:

  1. 在训练作业中配置存储路径挂载,将存储Checkpoint的目录挂载到训练容器的本地目录。
  2. 训练过程中,将Checkpoint文件保存到挂载的本地目录中,数据会自动同步到挂载的存储位置。
  3. 对于断点续训练,确保挂载的存储目录中包含之前的Checkpoint文件,训练脚本会自动加载最新的Checkpoint继续训练。

通过存储挂载方式,可以实现训练数据的持久化存储和跨作业的模型复用。

在ModelArts中创建训练作业时,可以选择以下两种存储挂载选项。以下是它们的详细对比,帮助您根据需求选择合适的存储方案。

表1 两种存储挂载

存储类型

性能

容量

适用场景

价格

备注

SFS Turbo

适用于多种应用场景,包括AI训练、AIGC、自动驾驶、渲染、EDA仿真、企业NAS应用等

较高

通用

OBS

基于对象存储服务作为统一数据存储的大数据场景。

适中

高频读取,低频写入。

旧版:

在ModelArts训练中实现断点续训练或增量训练,建议使用“训练输出”功能。

在创建训练作业时,设置训练“输出”参数名称为“train_output”,用户可通过环境变量或超参方式获取该参数。设置成功后可在指定的训练输出的数据存储位置中保存Checkpoint,“预下载至本地目录”选择“下载”。选择预下载至本地目录时,系统在训练作业启动前,自动将数据存储位置中的Checkpoint文件下载到训练容器的本地目录。

图1 训练输出设置

断点续训练建议和训练容错检查(即自动重启)功能同时使用。在创建训练作业页面,开启“自动重启”开关。训练环境预检测失败、或者训练容器硬件检测故障、或者训练作业失败时会自动重新下发并运行训练作业。

VeRL框架reload ckpt

VeRL是一个灵活、高效且被广泛使用的强化学习(RL)训练库,后训练的事实标准框架。VeRL是论文 HybridFlow: A Flexible and Efficient RLHF Framework 的开源实现。

  1. VeRL的训练yaml中配置trainer.save_freq参数和trainer.default_local_dir参数

    VeRL通过trainer.default_local_dir参数配置输出目录,该目录下会存在多个global_steps_xx权重目录,通过trainer.save_freq参数配置权重保存的频率,每隔一定的步数保存权重。

  2. VeRL的训练yaml中配置trainer.resume_mode参数
    当trainer.resume_mode设置为auto时,VeRL会自动遍历trainer.default_local_dir路径,加载最新且有效的ckpt。以ModelArts中设置训练输出实现断点续训练中的参数名称为train_output为例,参数设置如下所示:
    trainer.default_local_dir="${train_output}" 
    trainer.resume_mode=auto

MindSpeed-LLM框架reload ckpt

MindSpeed LLM是基于昇腾生态的大语言模型分布式训练框架,旨在为华为 昇腾芯片 生态合作伙伴提供端到端的大语言模型训练方案,包含分布式预训练、分布式指令微调以及对应的开发工具链,如:数据预处理、权重转换、在线推理、基线评估。作为昇腾计算主打的训练框架,在性能上做了极致的优化,特别在大参数、大集群和MOE类型模型的训练性能突出,且兼容Megatron-LM框架,对于Megatron客户可以平滑迁移。

  1. MindSpeed-LLM中配置--save参数和--save-interval参数

    MindSpeed-LLM训练启动脚本中通过--save参数配置输出目录,该目录下会存在多个iter_xx权重目录和一个记录最新权重保存步数的文件latest_checkpointed_iteration.txt,每次save会更新latest_checkpointed_iteration.txt。通过--save-interval参数配置权重保存的频率,每隔一定的步数保存权重。

  2. MindSpeed-LLM中配置--load参数,并且和--save参数路径保持一致

    MindSpeed-LLM训练启动脚本中通过--load参数配置输入目录,当--save参数和--load参数设置的参数保持一致时,每次重启训练业务即可加载最新的权重,以ModelArts中设置训练输出实现断点续训练中的参数名称为train_output为例,参数设置如下所示:

    --save-interval 1000 
    --save ${train_output} 
    --load ${train_output} 

LLaMa-Factory框架reload ckpt

LLaMa-Factory是开源社区中一个非常活跃的大模型训练框架,它的主打特点是简单易用,通过命令行或者WebUI界面可以轻松微调数百种大模型,包括大语言和多模态模型。LLaMa-Factory的底层是基于Transformers+DeepSpeed构建,对开源模型就有非常好的兼容性。

  1. LLaMa-Factory的训练yaml中配置output_dir参数和save_steps参数

    LLaMa-Factory通过output_dir参数配置输出目录,该目录下会存在多个checkpoint-xxx权重目录,通过save_steps参数配置权重保存频率。

  2. LLaMa-Factory的训练yaml中配置resume_from_checkpoint参数,并且和output_dir参数路径保持一致

    LLaMa-Factory通过resume_from_checkpoint显式指定本次训练使用的ckpt,如果指定了有效的ckpt则从指定的ckpt中恢复训练。当resume_from_checkpoint参数和output_dir参数保持一致时,但output_dir参数不是一个有效ckpt目录,因此还需要步骤3和步骤4。以ModelArts中设置训练输出实现断点续训练中的参数名称为train_output为例,参数设置如下所示:

    ### output
    output_dir: ${train_output}
    save_steps: 500 
    
    ### train
    resume_from_checkpoint: ${train_output}
  3. 首先需要创建一个resume.py脚本,该脚本需要传入训练配置yaml的绝对路径,具体代码如下所示:
    import os
    import re
    import sys
    
    
    def update_resume_config(config_file):  # 接收传入的配置文件路径
        # 读取配置文件内容
        with open(config_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    
        resume_line_num = None
        resume_path = None
    
        # 查找 resume_from_checkpoint 行
        for i, line in enumerate(lines):
            if line.strip().startswith('resume_from_checkpoint:'):
                resume_line_num = i
                # 提取值
                parts = line.split(':', 1)
                if len(parts) > 1:
                    resume_path = parts[1].strip().strip('"\'')  # 去除引号
                break
    
        # 如果没找到或值为null,不处理
        if resume_line_num is None or resume_path in (None, 'null', ''):
            return
    
        # 检查路径并找到最新的checkpoint
        new_resume_path = None
        if os.path.isdir(resume_path):
            # 查找所有 checkpoint-数字 文件夹
            checkpoint_pattern = re.compile(r'^checkpoint-(\d+)$')
            checkpoints = []
    
            for item in os.listdir(resume_path):
                item_path = os.path.join(resume_path, item)
                if os.path.isdir(item_path):
                    match = checkpoint_pattern.match(item)
                    if match:
                        step = int(match.group(1))
                        checkpoints.append((step, item_path))
    
            # 如果找到checkpoint,使用最新的那个
            if checkpoints:
                checkpoints.sort(key=lambda x: x[0])
                new_resume_path = checkpoints[-1][1]
    
        # 修改配置行
        indent = len(line) - len(line.lstrip())  # 保留原缩进
        if new_resume_path:
            lines[resume_line_num] = f'{" " * indent}resume_from_checkpoint: {new_resume_path}\n'
        else:
            lines[resume_line_num] = f'{" " * indent}resume_from_checkpoint: null\n'
    
        # 写回文件
        with open(config_file, 'w', encoding='utf-8') as f:
            f.writelines(lines)
    
    
    if __name__ == "__main__":
        # 从命令行参数中获取配置文件路径
        if len(sys.argv) < 2:
            print("用法:python resume.py <配置文件路径>")
            sys.exit(1)
        config_file = sys.argv[1]  # 接收命令行传入的 abc.yaml
        update_resume_config(config_file)  # 传入函数执行
  4. 修改训练启动脚本,在执行llamafactory-cli train命令之前,通过resume.py脚本修改训练配置中yaml中resume_from_checkpoint参数。resume.py脚本会获取resume_from_checkpoint参数中的路径,去该路径下查找步数最大的checkpoint-xxx权重目录,然后将resume_from_checkpoint参数修改为该权重目录的绝对路径。以train_lora/deepseek3_lora_sft_kt.yaml为例,其中WORK_DIR为工作路径,shell脚本修改如下所示:
    #!/bin/bash
    ...
    ...
    
    python $WORK_DIR/resume.py $WORK_DIR/LLaMA-Factory/examples/train_lora/deepseek3_lora_sft_kt.yaml
    llamafactory-cli train $WORK_DIR/LLaMA-Factory/examples/train_lora/deepseek3_lora_sft_kt.yaml 

PyTorch版reload ckpt

  • PyTorch模型保存有两种方式。
    • 仅保存模型参数
      state_dict = model.state_dict()
      torch.save(state_dict, path)
    • 保存整个Model(不推荐)
      torch.save(model, path)
  • 可根据step步数、时间等周期性保存模型的训练过程的产物。

    将模型训练过程中的网络权重、优化器权重、以及epoch进行保存,便于中断后继续训练恢复。

       checkpoint = {
               "net": model.state_dict(),
               "optimizer": optimizer.state_dict(),
               "epoch": epoch   
       }
       if not os.path.isdir('model_save_dir'):
           os.makedirs('model_save_dir')
       torch.save(checkpoint,'model_save_dir/ckpt_{}.pth'.format(str(epoch)))
  • 完整代码示例。
    import os
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_output", type=str)
    args, unparsed = parser.parse_known_args()
    args = parser.parse_known_args()
    # train_output 将被赋值为"/home/ma-user/modelarts/outputs/train_output_0" 
    train_output = args.train_output
    
    # 判断输出路径中是否有模型文件。如果无文件则默认从头训练,如果有模型文件,则加载epoch值最大的ckpt文件当做预训练模型。
    if os.listdir(train_output):
        print('> load last ckpt and continue training!!')
        last_ckpt = sorted([file for file in os.listdir(train_output) if file.endswith(".pth")])[-1]
        local_ckpt_file = os.path.join(train_output, last_ckpt)
        print('last_ckpt:', last_ckpt)
        # 加载断点
        checkpoint = torch.load(local_ckpt_file)  
        # 加载模型可学习参数
        model.load_state_dict(checkpoint['net'])  
        # 加载优化器参数
        optimizer.load_state_dict(checkpoint['optimizer'])  
        # 获取保存的epoch,模型会在此epoch的基础上继续训练
        start_epoch = checkpoint['epoch']  
    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(start_epoch + 1, args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ...
    
        # 保存模型训练过程中的网络权重、优化器权重、以及epoch
        checkpoint = {
              "net": model.state_dict(),
              "optimizer": optimizer.state_dict(),
              "epoch": epoch
            }
        if not os.path.isdir(train_output):
            os.makedirs(train_output)
            torch.save(checkpoint, os.path.join(train_output, 'ckpt_best_{}.pth'.format(epoch)))
    

MindSpore版reload ckpt

import os
import argparse
from resnet import resnet50
from mindspore.nn.optim.momentum import Momentum 
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
from mindspore.train.callback import LossMonitor

parser = argparse.ArgumentParser()
parser.add_argument("--train_output", type=str)
parser.add_argument("--batch_size", type=int, default=32, help="Batch size.") 
parser.add_argument("--num_classes", type=int, default=10, help="Num classes.") 
parser.add_argument("--do_train", type=bool, default=True, help="Do train or not.") 
args_opt, unparsed = parser.parse_known_args()
# train_output 将被赋值为"/home/ma-user/modelarts/outputs/train_output_0" 。
train_output = args_opt.train_output

# 初始定义的网络、损失函数及优化器,详细请参见MindSpore保存与加载。
# 1.初始定义的网络,以“ResNet50”为例。详细请参见ResNet50。
net = resnet50(args_opt.batch_size, args_opt.num_classes)
# 2.定义损失函数,详细请参见MindSpore自定义损失函数。
ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# 3.定义优化器,详细请参见MindSpore自定义优化器。
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
# 首次训练的epoch初始值,mindspore1.3及以后版本会支持定义epoch_size初始值。
# cur_epoch_num = 0
# 判断输出obs路径中是否有模型文件。如果无文件则默认从头训练,如果有模型文件,则加载epoch值最大的ckpt文件当做预训练模型。
if os.listdir(train_output):
    last_ckpt = sorted([file for file in os.listdir(train_output) if file.endswith(".ckpt")])[-1]
    print('last_ckpt:', last_ckpt)
    last_ckpt_file = os.path.join(train_output, last_ckpt)
     # 加载断点,详细请参见mindspore.load_checkpoint。
    param_dict = load_checkpoint(last_ckpt_file) 
    print('> load last ckpt and continue training!!')
    # 加载模型参数到net。
    load_param_into_net(net, param_dict)
    # 加载模型参数到opt。
    load_param_into_net(opt, param_dict)

    # 获取保存的epoch值,模型会在此epoch的基础上继续训练,此参数在mindspore1.3及以后版本会支持。
    # if param_dict.get("epoch_num"):
    #     cur_epoch_num = int(param_dict["epoch_num"].data.asnumpy())
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
# as for train, users could use model.train
if args_opt.do_train:
    dataset = create_dataset()
    batch_num = dataset.get_dataset_size()
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num,
                                     keep_checkpoint_max=35)
    # append_info=[{"epoch_num": cur_epoch_num}],mindspore1.3及以后版本会支持append_info参数,保存当前时刻的epoch值。
    # 保存网络参数,详细请参见mindspore.train.ModelCheckpoint。
    ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10",
                                     directory=args_opt.train_output,
                                     config=config_ck)
    loss_cb = LossMonitor()
    model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
    # model.train(epoch_size-cur_epoch_num, dataset, callbacks=[ckpoint_cb, loss_cb]),mindspore1.3及以后版本支持从断点恢复训练。

相关文档