文档首页/ AI开发平台ModelArts/ 模型训练/ 训练进阶/ 断点续训练和增量训练
更新时间:2024-07-25 GMT+08:00

断点续训练和增量训练

什么是断点续训练和增量训练

断点续训练是指因为某些原因(例如容错重启、资源抢占、作业卡死等)导致训练作业还未完成就被中断,下一次训练可以在上一次的训练基础上继续进行。这种方式对于需要长时间训练的模型而言比较友好。

增量训练是指增加新的训练数据到当前训练流程中,扩展当前模型的知识和能力。

断点续训练和增量训练均是通过checkpoint机制实现。

checkpoint的机制是:在模型训练的过程中,不断地保存训练结果(包括但不限于EPOCH、模型权重、优化器状态、调度器状态)。即便模型训练中断,也可以基于checkpoint接续训练。

当需要从训练中断的位置接续训练,只需要加载checkpoint,并用checkpoint信息初始化训练状态即可。用户需要在代码里加上reload ckpt的代码,使能读取前一次训练保存的预训练模型。

ModelArts中如何实现断点续训练和增量训练

在ModelArts训练中实现断点续训练或增量训练,建议使用“训练输出”功能。

在创建训练作业时,设置训练“输出”参数为“train_url”,在指定的训练输出的数据存储位置中保存checkpoint,“预下载至本地目录”选择“下载”。选择预下载至本地目录时,系统在训练作业启动前,自动将数据存储位置中的checkpoint文件下载到训练容器的本地目录。

图1 训练输出设置

断点续训练建议和训练容错检查(即自动重启)功能同时使用。在创建训练作业页面,开启“自动重启”开关。训练环境预检测失败、或者训练容器硬件检测故障、或者训练作业失败时会自动重新下发并运行训练作业。

PyTorch版reload ckpt

  1. PyTorch模型保存有两种方式。
    • 仅保存模型参数
      state_dict = model.state_dict()
      torch.save(state_dict, path)
    • 保存整个Model(不推荐)
      torch.save(model, path)
  2. 可根据step步数、时间等周期性保存模型的训练过程的产物。

    将模型训练过程中的网络权重、优化器权重、以及epoch进行保存,便于中断后继续训练恢复。

       checkpoint = {
               "net": model.state_dict(),
               "optimizer": optimizer.state_dict(),
               "epoch": epoch   
       }
       if not os.path.isdir('model_save_dir'):
           os.makedirs('model_save_dir')
       torch.save(checkpoint,'model_save_dir/ckpt_{}.pth'.format(str(epoch)))
  3. 完整代码示例。
    import os
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_url", type=str)
    args, unparsed = parser.parse_known_args()
    args = parser.parse_known_args()
    # train_url 将被赋值为"/home/ma-user/modelarts/outputs/train_url_0" 
    train_url = args.train_url
    
    # 判断输出路径中是否有模型文件。如果无文件则默认从头训练,如果有模型文件,则加载epoch值最大的ckpt文件当做预训练模型。
    if os.listdir(train_url):
        print('> load last ckpt and continue training!!')
        last_ckpt = sorted([file for file in os.listdir(train_url) if file.endswith(".pth")])[-1]
        local_ckpt_file = os.path.join(train_url, last_ckpt)
        print('last_ckpt:', last_ckpt)
        # 加载断点
        checkpoint = torch.load(local_ckpt_file)  
        # 加载模型可学习参数
        model.load_state_dict(checkpoint['net'])  
        # 加载优化器参数
        optimizer.load_state_dict(checkpoint['optimizer'])  
        # 获取保存的epoch,模型会在此epoch的基础上继续训练
        start_epoch = checkpoint['epoch']  
    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(start_epoch + 1, args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ...
    
        # 保存模型训练过程中的网络权重、优化器权重、以及epoch
        checkpoint = {
              "net": model.state_dict(),
              "optimizer": optimizer.state_dict(),
              "epoch": epoch
            }
        if not os.path.isdir(train_url):
            os.makedirs(train_url)
            torch.save(checkpoint, os.path.join(train_url, 'ckpt_best_{}.pth'.format(epoch)))