Updated on 2024-12-31 GMT+08:00

Resumable Training

Overview

Resumable training means that an interrupted training job can automatically resume from the most recent checkpoint saved before the interruption. It is suited to long-running model training.

The checkpoint mechanism enables resumable training.

During model training, the training state (including but not limited to the epoch, model weights, optimizer state, and scheduler state) is saved periodically as a checkpoint. If the job is interrupted, training can resume from the latest checkpoint instead of starting over.

To resume a training job, load a checkpoint and use its contents to initialize the training state. To do so, add checkpoint reloading logic (reload ckpt) to the training code.
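The save-and-reload cycle described above can be sketched without any deep learning framework. The following is a minimal illustration only: the function names (`save_state`, `load_state`, `train`), the JSON file format, and the `stop_after` parameter used to simulate an interruption are all assumptions made for this example and are not part of ModelArts.

```python
import json
import os


def save_state(path, state):
    """Write to a temporary file first, then rename it, so that an
    interruption during the write never leaves a partial checkpoint."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)


def load_state(path):
    """Return the saved state, or an initial state if no checkpoint exists."""
    if not os.path.exists(path):
        return {"epoch": -1, "weights": 0.0}
    with open(path) as f:
        return json.load(f)


def train(path, total_epochs, stop_after=None):
    """Toy training loop. stop_after simulates an interruption (hypothetical)."""
    state = load_state(path)
    # Resume from the epoch after the last one that was saved.
    for epoch in range(state["epoch"] + 1, total_epochs):
        if stop_after is not None and epoch > stop_after:
            break  # simulated interruption
        # Stand-in for one epoch of real training.
        state = {"epoch": epoch, "weights": state["weights"] + 1.0}
        save_state(path, state)
    return state
```

Calling `train` a second time with the same path continues from the saved epoch rather than from epoch 0, which is the behavior that reload ckpt provides for a real model.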

Implementing Resumable Training in ModelArts Standard

To resume model training or incrementally train a model in ModelArts Standard, configure training output.

When creating a training job, configure a data path for the training output, save checkpoints to this path, and set Predownload to Yes. With Predownload set to Yes, the system automatically downloads the checkpoint files from the training output data path to a local directory in the training container before the training job starts.

Figure 1 Configuring training output

Enable the fault tolerance check (auto restart) for resumable training. On the training job creation page, enable Auto Restart. If the environment pre-check fails, the hardware malfunctions, or the training job fails, ModelArts automatically issues the training job again.

reload ckpt for PyTorch

  • Use either of the following methods to save a PyTorch model.
    • Save model parameters only.
      state_dict = model.state_dict()
      torch.save(state_dict, path)
    • Save the entire model (not recommended).
      torch.save(model, path)
  • Save the training state at regular intervals, for example, every fixed number of steps or after a fixed amount of time.

    The saved state includes the network weights, optimizer state, and epoch, which are used to resume the interrupted training.

      checkpoint = {
          "net": model.state_dict(),
          "optimizer": optimizer.state_dict(),
          "epoch": epoch
      }
      # Create the directory if it does not exist yet.
      os.makedirs('model_save_dir', exist_ok=True)
      torch.save(checkpoint, 'model_save_dir/ckpt_{}.pth'.format(epoch))
  • Check the complete code example below.
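    import os
    import argparse
    from datetime import datetime

    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument("--train_url", type=str)
    # Total number of epochs, read by the training loop as args.epochs. The default value is a placeholder.
    parser.add_argument("--epochs", type=int, default=10)
    args, unparsed = parser.parse_known_args()
    # train_url is set to /home/ma-user/modelarts/outputs/train_url_0.
    train_url = args.train_url
    # model, optimizer, criterion, and train_loader are assumed to be defined before this point.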
    
    # Check whether the output path contains checkpoint files. If it does not, the model is trained from the beginning. If it does, the CKPT file with the largest epoch value is loaded and training resumes from it.
    start_epoch = -1  # No checkpoint: the loop below starts at epoch 0.
    ckpt_files = [file for file in os.listdir(train_url) if file.endswith(".pth")] if os.path.isdir(train_url) else []
    if ckpt_files:
        print('> load last ckpt and continue training!!')
        # Compare the numeric epoch in file names such as ckpt_best_<epoch>.pth. A plain string sort would place ckpt_best_10.pth before ckpt_best_9.pth.
        last_ckpt = max(ckpt_files, key=lambda file: int(file[:-len(".pth")].rsplit("_", 1)[-1]))
        local_ckpt_file = os.path.join(train_url, last_ckpt)
        print('last_ckpt:', last_ckpt)
        # Load the checkpoint.
        checkpoint = torch.load(local_ckpt_file)
        # Load the learnable parameters of the model.
        model.load_state_dict(checkpoint['net'])
        # Load the optimizer state.
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Obtain the saved epoch. Training continues from the epoch after this one.
        start_epoch = checkpoint['epoch']
    start = datetime.now()
    total_step = len(train_loader)
    for epoch in range(start_epoch + 1, args.epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ...
    
        # Save the network weights, optimizer state, and epoch at the end of each epoch.
        checkpoint = {
            "net": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch
        }
        # Create the output directory if needed, then always write the checkpoint.
        os.makedirs(train_url, exist_ok=True)
        torch.save(checkpoint, os.path.join(train_url, 'ckpt_best_{}.pth'.format(epoch)))
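
The step that selects the newest checkpoint in the example above can also be isolated into a small function and checked without PyTorch. The following is a sketch under stated assumptions: the helper name `find_latest_checkpoint` is hypothetical, and the pattern assumes file names of the form `ckpt_best_<epoch>.pth`, as written by the example. Comparing the parsed epoch as an integer avoids the string-ordering problem in which `ckpt_best_10.pth` sorts before `ckpt_best_9.pth`.

```python
import os
import re

# Assumed naming scheme: ckpt_best_<epoch>.pth, as in the example above.
CKPT_PATTERN = re.compile(r"^ckpt_best_(\d+)\.pth$")


def find_latest_checkpoint(directory):
    """Return (path, epoch) for the checkpoint with the largest epoch,
    or None if the directory is missing or holds no matching file."""
    if not os.path.isdir(directory):
        return None
    candidates = []
    for name in os.listdir(directory):
        match = CKPT_PATTERN.match(name)
        if match:
            # Store the epoch as an int so that 10 ranks above 9.
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    epoch, name = max(candidates)
    return os.path.join(directory, name), epoch
```

A training script could call this helper on `train_url` at startup and fall back to training from the beginning when it returns `None`. Files that do not match the pattern are ignored instead of causing an error.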