更新时间:2024-12-10 GMT+08:00
分享

增量模型训练

什么是增量训练

增量训练(Incremental Learning)是机器学习领域中的一种训练方法,它允许人工智能(AI)模型在已经学习了一定知识的基础上,增加新的训练数据到当前训练流程中,扩展当前模型的知识和能力,而不需要从头开始。

增量训练不需要一次性存储所有的训练数据,缓解了存储资源有限的问题;另一方面,增量训练节约了重新训练中需要消耗大量算力、时间以及经济成本。

增量训练特别适用于以下情况:

  • 数据流更新:在实际应用中,数据可能会持续更新,增量训练允许模型适应新的数据而不必重新训练。
  • 资源限制:如果重新训练一个大型模型成本过高,增量训练可以是一个更经济的选择。
  • 避免灾难性遗忘:在传统训练中,新数据可能会覆盖旧数据的知识,导致模型忘记之前学到的内容。增量训练通过保留旧知识的同时学习新知识来避免这个问题。

增量训练在很多领域都有应用,比如自然语言处理、计算机视觉和推荐系统等。它使得AI系统能够更加灵活和适应性强,更好地应对现实世界中不断变化的数据环境。

ModelArts Standard中如何实现增量训练

增量训练是通过Checkpoint机制实现。

Checkpoint的机制是:在模型训练的过程中,不断地保存训练结果(包括但不限于EPOCH、模型权重、优化器状态、调度器状态)。当需要增加新的数据继续训练时,只需要加载Checkpoint,并用Checkpoint信息初始化训练状态即可。用户需要在代码里加上reload ckpt的代码,使能读取前一次训练保存的预训练模型。

在ModelArts训练中实现增量训练,建议使用“训练输出”功能。

在创建训练作业时,设置训练“输出”参数为“train_url”,在指定的训练输出的数据存储位置中保存Checkpoint,“预下载至本地目录”选择“下载”。选择预下载至本地目录时,系统在训练作业启动前,自动将数据存储位置中的Checkpoint文件下载到训练容器的本地目录。

图1 训练输出设置

PyTorch版reload ckpt

  1. PyTorch模型保存有两种方式。
    • 仅保存模型参数
      state_dict = model.state_dict()
      torch.save(state_dict, path)
    • 保存整个Model(不推荐)
      torch.save(model, path)
  2. 可根据step步数、时间等周期性保存模型的训练过程的产物。

    将模型训练过程中的网络权重、优化器权重、以及epoch进行保存,便于中断后继续训练恢复。

       checkpoint = {
               "net": model.state_dict(),
               "optimizer": optimizer.state_dict(),
               "epoch": epoch   
       }
       if not os.path.isdir('model_save_dir'):
           os.makedirs('model_save_dir')
       torch.save(checkpoint,'model_save_dir/ckpt_{}.pth'.format(str(epoch)))
  3. 完整代码示例。
    import os
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_url", type=str)
    args, unparsed = parser.parse_known_args()
    args = parser.parse_known_args()
    # train_url 将被赋值为"/home/ma-user/modelarts/outputs/train_url_0" 
    train_url = args.train_url
    
    # 判断输出路径中是否有模型文件。如果无文件则默认从头训练,如果有模型文件,则加载epoch值最大的ckpt文件当做预训练模型。
    if os.listdir(train_url):
        print('> load last ckpt and continue training!!')
        last_ckpt = sorted([file for file in os.listdir(train_url) if file.endswith(".pth")])[-1]
        local_ckpt_file = os.path.join(train_url, last_ckpt)
        print('last_ckpt:', last_ckpt)
        # 加载断点
        checkpoint = torch.load(local_ckpt_file)  
        # 加载模型可学习参数
        model.load_state_dict(checkpoint['net'])  
        # 加载优化器参数
        optimizer.load_state_dict(checkpoint['optimizer'])  
        # 获取保存的epoch,模型会在此epoch的基础上继续训练
        start_epoch = checkpoint['epoch']  
    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(start_epoch + 1, args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ...
    
        # 保存模型训练过程中的网络权重、优化器权重、以及epoch
        checkpoint = {
              "net": model.state_dict(),
              "optimizer": optimizer.state_dict(),
              "epoch": epoch
            }
        if not os.path.isdir(train_url):
            os.makedirs(train_url)
            torch.save(checkpoint, os.path.join(train_url, 'ckpt_best_{}.pth'.format(epoch)))

MindSpore版reload ckpt

import os
import argparse
from resnet import resnet50
from mindspore.nn.optim.momentum import Momentum 
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
from mindspore.train.callback import LossMonitor

parser = argparse.ArgumentParser()
parser.add_argument("--train_url", type=str)
parser.add_argument("--batch_size", type=int, default=32, help="Batch size.") 
parser.add_argument("--num_classes", type=int, default=10, help="Num classes.") 
parser.add_argument("--do_train", type=bool, default=True, help="Do train or not.") 
args_opt, unparsed = parser.parse_known_args()
# train_url 将被赋值为"/home/ma-user/modelarts/outputs/train_url_0" 。
train_url = args_opt.train_url

# 初始定义的网络、损失函数及优化器,详细请参见MindSpore保存与加载。
# 1.初始定义的网络,以“ResNet50”为例。详细请参见ResNet50。
net = resnet50(args_opt.batch_size, args_opt.num_classes)
# 2.定义损失函数,详细请参见MindSpore自定义损失函数。
ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# 3.定义优化器,详细请参见MindSpore自定义优化器。
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
# 首次训练的epoch初始值,mindspore1.3及以后版本会支持定义epoch_size初始值。
# cur_epoch_num = 0
# 判断输出obs路径中是否有模型文件。如果无文件则默认从头训练,如果有模型文件,则加载epoch值最大的ckpt文件当做预训练模型。
if os.listdir(train_url):
    last_ckpt = sorted([file for file in os.listdir(train_url) if file.endswith(".ckpt")])[-1]
    print('last_ckpt:', last_ckpt)
    last_ckpt_file = os.path.join(train_url, last_ckpt)
     # 加载断点,详细请参见mindspore.load_checkpoint。
    param_dict = load_checkpoint(last_ckpt_file) 
    print('> load last ckpt and continue training!!')
    # 加载模型参数到net。
    load_param_into_net(net, param_dict)
    # 加载模型参数到opt。
    load_param_into_net(opt, param_dict)

    # 获取保存的epoch值,模型会在此epoch的基础上继续训练,此参数在mindspore1.3及以后版本会支持。
    # if param_dict.get("epoch_num"):
    #     cur_epoch_num = int(param_dict["epoch_num"].data.asnumpy())
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
# as for train, users could use model.train
if args_opt.do_train:
    dataset = create_dataset()
    batch_num = dataset.get_dataset_size()
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num,
                                     keep_checkpoint_max=35)
    # append_info=[{"epoch_num": cur_epoch_num}],mindspore1.3及以后版本会支持append_info参数,保存当前时刻的epoch值。
    # 保存网络参数,详细请参见mindspore.train.ModelCheckpoint。
    ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10",
                                     directory=args_opt.train_url,
                                     config=config_ck)
    loss_cb = LossMonitor()
    model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
    # model.train(epoch_size-cur_epoch_num, dataset, callbacks=[ckpoint_cb, loss_cb]),mindspore1.3及以后版本支持从断点恢复训练。

相关文档