Creating Multi-Node Multi-GPU Distributed Training (DistributedDataParallel)
In deep learning, as models grow more complex and training scales keep expanding, traditional single-machine training can no longer meet practical needs. Against this backdrop, running efficient distributed training across multiple nodes and GPUs has become a key problem to solve.
To address it, this section describes multi-node multi-GPU data-parallel training with the PyTorch engine. It walks through the concrete code adaptation steps and provides complete code examples to make them easier to understand and reproduce. As a running example, it shows how to adapt ResNet18 image classification on the CIFAR10 dataset for distributed training (DDP), giving a practical solution that can be referenced and reproduced directly.
Training Process Overview
Compared with DP, DDP launches multiple processes for computation, which greatly improves the utilization of compute resources. True distributed computing can be implemented on top of torch.distributed; the underlying principles are not covered here. The overall flow is as follows (a conceptual sketch of the gradient-averaging step follows the list):
- Initialize the process group.
- Create the distributed parallel model; every process holds an identical copy of the model and its parameters.
- Create a data-distribution Sampler so that each process loads a different part of each mini-batch.
- Group adjacent parameters of the network into buckets, typically corresponding to each layer whose parameters need to be updated.
- Each process runs the forward pass and computes its own gradients.
- As soon as the gradients of a layer's parameters are ready, they are communicated across processes and averaged.
- Each GPU updates the model parameters.
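The gradient-communication step can be pictured with the following conceptual sketch of what DDP does for each bucket during the backward pass. DDP registers autograd hooks that perform this automatically; this is an illustration, not code you need to add:

import torch.distributed as dist

def average_gradients(model):
    # Conceptual illustration only: DDP does this per bucket via autograd hooks
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum the gradients of all processes, then average them
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size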
Code Adaptation Points
- Introduce the multi-process launch mechanism: initialize the processes.
- Introduce several variables: the tcp init method, rank (index of the current process), and world_size (number of processes launched).
- Distribute the data: DataLoader takes an additional Sampler argument so that different processes do not load duplicate data.
- Distribute the model: DistributedDataParallel(model).
- Save the model: save checkpoints only in the process with rank 0 (a sketch covering these adaptation points follows the snippet below).
import torch

class Net(torch.nn.Module):
    pass

model = Net().cuda()

### DistributedDataParallel Begin ###
model = torch.nn.parallel.DistributedDataParallel(Net().cuda())
### DistributedDataParallel End ###
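The snippet above shows only the model-wrapping point. Below is a minimal sketch of the remaining adaptation points, condensed from the complete example later in this section; args stands for the parsed command-line arguments, and Net, train_set and batch_size are placeholders for the corresponding objects in the full code:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# 1. Initialize the process group; init_method, rank and world_size are injected by the platform
dist.init_process_group(init_method=args.init_method, backend="nccl",
                        world_size=args.world_size, rank=args.rank)

# 2. Distribute the data: the sampler keeps the batches of different processes disjoint
sampler = DistributedSampler(train_set, num_replicas=args.world_size, rank=args.rank)
loader = DataLoader(train_set, batch_size=batch_size, sampler=sampler, drop_last=True)

# 3. Distribute the model
model = torch.nn.parallel.DistributedDataParallel(Net().cuda())

# 4. Save checkpoints only in the process with rank 0
if args.rank == 0:
    torch.save(model.state_dict(), 'ckpt.pth')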
Multi-Node Distributed Debugging Adaptation and Code Examples
In distributed training, DistributedDataParallel (DDP) is a widely used data-parallel training method. Under this mechanism, each process loads its own batch from the original data, and the gradients computed by the individual processes are averaged to produce the final gradient. Because DDP effectively computes gradients from a larger number of samples, the resulting gradients are more reliable, which means the learning rate can be increased appropriately in distributed training to speed up convergence.
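As a concrete illustration, using the per-GPU batch size and the learning-rate scaling formula from the complete example below, and assuming a two-node, 16-GPU job with one task per node (8 GPUs per node):

# Rough arithmetic for the effective batch size and scaled learning rate;
# the node/task layout here is an assumption made for illustration
base_lr = 0.01
batch_per_gpu = 128
gpus_per_node = 8      # assumed: 16 GPUs spread over 2 nodes
world_size = 2         # assumed: one task (process) per node

global_batch = batch_per_gpu * gpus_per_node * world_size   # 2048 samples per step
scaled_lr = base_lr * gpus_per_node * world_size            # 0.16
print(global_batch, scaled_lr)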
To help readers understand and practice this, we use ResNet18 image classification on the CIFAR10 dataset as an example and provide complete code for both single-node training and the distributed training adaptation (DDP). The code supports multi-node distributed training and is compatible with both CPU and GPU distributed environments. Note that you can switch to single-node single-GPU training simply by commenting out the distributed adaptation points in the code.
The training code takes three groups of input parameters: basic training parameters, distributed parameters, and data-related parameters. The distributed parameters are injected automatically by the platform and do not need to be defined by the user. Among the data-related parameters, a custom_data switch controls whether training uses custom data. When custom_data is set to "true", random data generated with PyTorch is used for training and validation, which provides a flexible experimental environment.
CIFAR10 Dataset
In a Notebook environment, the dataset cannot be fetched directly with the default version of torchvision, so the example code provides three ways to load the training data.
CIFAR-10 dataset download link: on the download page, click "CIFAR-10 python version".
- Try to fetch the cifar10 dataset via torchvision.
- Download the data from the download link, extract it, and place it in the specified directory; the training set and test set have shapes (50000, 3, 32, 32) and (10000, 3, 32, 32), respectively.
- Because downloading the cifar10 dataset can be slow, a random cifar10-like dataset can instead be generated with torch; the training set and test set have shapes (5000, 3, 32, 32) and (1000, 3, 32, 32). With custom_data = 'true', the training job runs directly without loading any external data (a sketch of this random dataset follows the list; note that the random labels in the provided code take only two of the ten class values, while the classifier still outputs 10 classes).
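A condensed sketch of how this random dataset is built, taken from the custom_data branch of get_data() in the complete example below:

import random
import torch
from torch.utils.data import TensorDataset

# Random CIFAR-like training data: 5000 images of shape (3, 32, 32) with shuffled binary labels
train_samples, img_size = 5000, 32
tr_label = [1] * (train_samples // 2) + [0] * (train_samples // 2)
random.shuffle(tr_label)
tr_set = TensorDataset(torch.randn(train_samples, 3, img_size, img_size).float(),
                       torch.tensor(tr_label).long())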
Training Code
In the code below, the lines wrapped in "### Distributed adaptation ... ###" comments are the adaptation points required for multi-node distributed training.
Without modifying the example code, you only need to adapt the data path to run multi-node distributed training on ModelArts.
Commenting out the distributed adaptation points turns the example into single-node single-GPU training. For the full code, see Complete Code Example for Distributed Training.
- Import the dependencies
import datetime
import inspect
import os
import pickle
import random
import argparse

import numpy as np
import torch
import torch.distributed as dist
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from sklearn.metrics import accuracy_score
- Define the data-loading helper and the random-seed setup; the data-loading code is long, so it is omitted here
def setup_seed(seed):
    # Fix all random sources to make runs reproducible
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def get_data(path, custom_data=False):
    # Omitted here; see the complete code example below for the full implementation
    pass
- Define the network structure (a quick shape check follows the code block)
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.residual_function(x) + self.shortcut(x)
        return nn.ReLU(inplace=True)(out)


class ResNet(nn.Module):
    def __init__(self, block, num_classes=10):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.conv2 = self.make_layer(block, 64, 64, 2, 1)
        self.conv3 = self.make_layer(block, 64, 128, 2, 2)
        self.conv4 = self.make_layer(block, 128, 256, 2, 2)
        self.conv5 = self.make_layer(block, 256, 512, 2, 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dense_layer = nn.Linear(512, num_classes)

    def make_layer(self, block, in_channels, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(in_channels, out_channels, stride))
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.dense_layer(out)
        return out
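As a quick sanity check of the network definition (not part of the original example), a random CIFAR-sized batch can be run through the model to confirm the output shape:

# A batch of 4 RGB 32x32 images should map to 4 x 10 class logits
model = ResNet(Block)
dummy = torch.randn(4, 3, 32, 32)
print(model(dummy).shape)  # expected: torch.Size([4, 10])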
- Run training and validation (a note on the saved checkpoint format follows the code block)
def main():
    file_dir = os.path.dirname(inspect.getframeinfo(inspect.currentframe()).filename)
    seed = datetime.datetime.now().year
    setup_seed(seed)

    parser = argparse.ArgumentParser(description='Pytorch distribute training',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--enable_gpu', default='true')
    parser.add_argument('--lr', default='0.01', help='learning rate')
    parser.add_argument('--epochs', default='100', help='training iteration')

    parser.add_argument('--init_method', default=None, help='tcp_port')
    parser.add_argument('--rank', type=int, default=0, help='index of current task')
    parser.add_argument('--world_size', type=int, default=1, help='total number of tasks')

    parser.add_argument('--custom_data', default='false')
    parser.add_argument('--data_url', type=str, default=os.path.join(file_dir, 'input_dir'))
    parser.add_argument('--output_dir', type=str, default=os.path.join(file_dir, 'output_dir'))
    args, unknown = parser.parse_known_args()

    args.enable_gpu = args.enable_gpu == 'true'
    args.custom_data = args.custom_data == 'true'
    args.lr = float(args.lr)
    args.epochs = int(args.epochs)

    if args.custom_data:
        print('[warning] you are training on custom random dataset, '
              'validation accuracy may range from 0.4 to 0.6.')

    ### Distributed adaptation: initialize the DDP process group; init_method, rank and world_size are injected automatically by the platform ###
    dist.init_process_group(init_method=args.init_method, backend="nccl",
                            world_size=args.world_size, rank=args.rank)
    ### Distributed adaptation: initialize the DDP process group; init_method, rank and world_size are injected automatically by the platform ###

    tr_set, val_set = get_data(args.data_url, custom_data=args.custom_data)

    batch_per_gpu = 128
    gpus_per_node = torch.cuda.device_count() if args.enable_gpu else 1
    batch = batch_per_gpu * gpus_per_node

    tr_loader = DataLoader(tr_set, batch_size=batch, shuffle=False)

    ### Distributed adaptation: build the DDP DistributedSampler so that different processes load different parts of the data ###
    tr_sampler = DistributedSampler(tr_set, num_replicas=args.world_size, rank=args.rank)
    tr_loader = DataLoader(tr_set, batch_size=batch, sampler=tr_sampler, shuffle=False, drop_last=True)
    ### Distributed adaptation: build the DDP DistributedSampler so that different processes load different parts of the data ###

    val_loader = DataLoader(val_set, batch_size=batch, shuffle=False)

    lr = args.lr * gpus_per_node * args.world_size
    max_epoch = args.epochs
    model = ResNet(Block).cuda() if args.enable_gpu else ResNet(Block)

    ### Distributed adaptation: wrap the model with DistributedDataParallel ###
    model = nn.parallel.DistributedDataParallel(model)
    ### Distributed adaptation: wrap the model with DistributedDataParallel ###

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_func = torch.nn.CrossEntropyLoss()

    os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(1, max_epoch + 1):
        model.train()
        train_loss = 0

        ### Distributed adaptation: seed the DDP sampler with the current epoch to avoid loading duplicate data ###
        tr_sampler.set_epoch(epoch)
        ### Distributed adaptation: seed the DDP sampler with the current epoch to avoid loading duplicate data ###

        for step, (tr_x, tr_y) in enumerate(tr_loader):
            if args.enable_gpu:
                tr_x, tr_y = tr_x.cuda(), tr_y.cuda()
            out = model(tr_x)
            loss = loss_func(out, tr_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        print('train | epoch: %d | loss: %.4f' % (epoch, train_loss / len(tr_loader)))

        val_loss = 0
        pred_record = []
        real_record = []
        model.eval()
        with torch.no_grad():
            for step, (val_x, val_y) in enumerate(val_loader):
                if args.enable_gpu:
                    val_x, val_y = val_x.cuda(), val_y.cuda()
                out = model(val_x)
                pred_record += list(np.argmax(out.cpu().numpy(), axis=1))
                real_record += list(val_y.cpu().numpy())
                val_loss += loss_func(out, val_y).item()
        val_accu = accuracy_score(real_record, pred_record)
        print('val | epoch: %d | loss: %.4f | accuracy: %.4f' % (epoch, val_loss / len(val_loader), val_accu), '\n')

        if args.rank == 0:
            # save ckpt every epoch
            torch.save(model.state_dict(), os.path.join(args.output_dir, f'epoch_{epoch}.pth'))


if __name__ == '__main__':
    main()
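Note that because the model saved above is wrapped in DistributedDataParallel, the keys of the saved state_dict carry a "module." prefix. If you plan to load a checkpoint into a plain ResNet later, one option (a variation on the example, not part of the original code) is to save the unwrapped parameters instead:

if args.rank == 0:
    # Save the underlying model so the checkpoint loads directly into ResNet(Block)
    torch.save(model.module.state_dict(), os.path.join(args.output_dir, f'epoch_{epoch}.pth'))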
- Results comparison
The cifar-10 dataset was trained for 100 epochs on two resource configurations, single-node single-GPU and two nodes with 16 GPUs; the training time and test-set accuracy are listed below.
Table 1 Training results comparison

| Resource type | Single node, single GPU | Two nodes, 16 GPUs |
| --- | --- | --- |
| Training time | 60 minutes | 20 minutes |
| Accuracy (%) | 80+ | 80+ |
Complete Code Example for Distributed Training
The following is the complete code of the distributed training adaptation (DDP) for the resnet18 classification task on the cifar10 dataset.
The content of the training boot file main.py is shown below (to run a single-node single-GPU training job instead, remove the distributed adaptation code):
import datetime
import inspect
import os
import pickle
import random
import logging
import argparse

import numpy as np
from sklearn.metrics import accuracy_score
import torch
from torch import nn, optim
import torch.distributed as dist
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

file_dir = os.path.dirname(inspect.getframeinfo(inspect.currentframe()).filename)


def load_pickle_data(path):
    with open(path, 'rb') as file:
        data = pickle.load(file, encoding='bytes')
    return data


def _load_data(file_path):
    raw_data = load_pickle_data(file_path)
    labels = raw_data[b'labels']
    data = raw_data[b'data']
    filenames = raw_data[b'filenames']

    data = data.reshape(10000, 3, 32, 32) / 255
    return data, labels, filenames


def load_cifar_data(root_path):
    train_root_path = os.path.join(root_path, 'cifar-10-batches-py/data_batch_')
    train_data_record = []
    train_labels = []
    train_filenames = []
    for i in range(1, 6):
        train_file_path = train_root_path + str(i)
        data, labels, filenames = _load_data(train_file_path)
        train_data_record.append(data)
        train_labels += labels
        train_filenames += filenames
    train_data = np.concatenate(train_data_record, axis=0)
    train_labels = np.array(train_labels)

    val_file_path = os.path.join(root_path, 'cifar-10-batches-py/test_batch')
    val_data, val_labels, val_filenames = _load_data(val_file_path)
    val_labels = np.array(val_labels)

    tr_data = torch.from_numpy(train_data).float()
    tr_labels = torch.from_numpy(train_labels).long()
    val_data = torch.from_numpy(val_data).float()
    val_labels = torch.from_numpy(val_labels).long()
    return tr_data, tr_labels, val_data, val_labels


def get_data(root_path, custom_data=False):
    if custom_data:
        # Random CIFAR-like data with binary labels; no download required
        train_samples, test_samples, img_size = 5000, 1000, 32
        tr_label = [1] * int(train_samples / 2) + [0] * int(train_samples / 2)
        val_label = [1] * int(test_samples / 2) + [0] * int(test_samples / 2)
        random.seed(2021)
        random.shuffle(tr_label)
        random.shuffle(val_label)
        tr_data, tr_labels = torch.randn((train_samples, 3, img_size, img_size)).float(), torch.tensor(tr_label).long()
        val_data, val_labels = torch.randn((test_samples, 3, img_size, img_size)).float(), torch.tensor(val_label).long()
        tr_set = TensorDataset(tr_data, tr_labels)
        val_set = TensorDataset(val_data, val_labels)
        return tr_set, val_set
    elif os.path.exists(os.path.join(root_path, 'cifar-10-batches-py')):
        # Load the manually downloaded and extracted cifar-10 batches
        tr_data, tr_labels, val_data, val_labels = load_cifar_data(root_path)
        tr_set = TensorDataset(tr_data, tr_labels)
        val_set = TensorDataset(val_data, val_labels)
        return tr_set, val_set
    else:
        # Fall back to downloading the dataset with torchvision
        try:
            import torchvision
            from torchvision import transforms
            # use ToTensor() so the PIL images are converted to float tensors
            tr_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                  download=True, transform=transforms.ToTensor())
            val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                                   download=True, transform=transforms.ToTensor())
            return tr_set, val_set
        except Exception as e:
            raise Exception(
                f"{e}, you can download and unzip cifar-10 dataset manually, "
                "the data url is http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.residual_function(x) + self.shortcut(x)
        return nn.ReLU(inplace=True)(out)


class ResNet(nn.Module):
    def __init__(self, block, num_classes=10):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.conv2 = self.make_layer(block, 64, 64, 2, 1)
        self.conv3 = self.make_layer(block, 64, 128, 2, 2)
        self.conv4 = self.make_layer(block, 128, 256, 2, 2)
        self.conv5 = self.make_layer(block, 256, 512, 2, 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dense_layer = nn.Linear(512, num_classes)

    def make_layer(self, block, in_channels, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(in_channels, out_channels, stride))
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.dense_layer(out)
        return out


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def obs_transfer(src_path, dst_path):
    import moxing as mox
    mox.file.copy_parallel(src_path, dst_path)
    logging.info(f"end copy data from {src_path} to {dst_path}")


def main():
    seed = datetime.datetime.now().year
    setup_seed(seed)

    parser = argparse.ArgumentParser(description='Pytorch distribute training',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--enable_gpu', default='true')
    parser.add_argument('--lr', default='0.01', help='learning rate')
    parser.add_argument('--epochs', default='100', help='training iteration')

    parser.add_argument('--init_method', default=None, help='tcp_port')
    parser.add_argument('--rank', type=int, default=0, help='index of current task')
    parser.add_argument('--world_size', type=int, default=1, help='total number of tasks')

    parser.add_argument('--custom_data', default='false')
    parser.add_argument('--data_url', type=str, default=os.path.join(file_dir, 'input_dir'))
    parser.add_argument('--output_dir', type=str, default=os.path.join(file_dir, 'output_dir'))
    args, unknown = parser.parse_known_args()

    args.enable_gpu = args.enable_gpu == 'true'
    args.custom_data = args.custom_data == 'true'
    args.lr = float(args.lr)
    args.epochs = int(args.epochs)

    if args.custom_data:
        logging.warning('you are training on custom random dataset, '
                        'validation accuracy may range from 0.4 to 0.6.')

    ### Distributed adaptation: initialize the DDP process group; init_method, rank and world_size are injected automatically by the platform ###
    # note: the "nccl" backend requires GPUs; a CPU-only run would need a backend such as "gloo"
    dist.init_process_group(init_method=args.init_method, backend="nccl",
                            world_size=args.world_size, rank=args.rank)
    ### Distributed adaptation: initialize the DDP process group; init_method, rank and world_size are injected automatically by the platform ###

    tr_set, val_set = get_data(args.data_url, custom_data=args.custom_data)

    batch_per_gpu = 128
    gpus_per_node = torch.cuda.device_count() if args.enable_gpu else 1
    batch = batch_per_gpu * gpus_per_node

    tr_loader = DataLoader(tr_set, batch_size=batch, shuffle=False)

    ### Distributed adaptation: build the DDP DistributedSampler so that different processes load different parts of the data ###
    tr_sampler = DistributedSampler(tr_set, num_replicas=args.world_size, rank=args.rank)
    tr_loader = DataLoader(tr_set, batch_size=batch, sampler=tr_sampler, shuffle=False, drop_last=True)
    ### Distributed adaptation: build the DDP DistributedSampler so that different processes load different parts of the data ###

    val_loader = DataLoader(val_set, batch_size=batch, shuffle=False)

    lr = args.lr * gpus_per_node * args.world_size
    max_epoch = args.epochs
    model = ResNet(Block).cuda() if args.enable_gpu else ResNet(Block)

    ### Distributed adaptation: wrap the model with DistributedDataParallel ###
    model = nn.parallel.DistributedDataParallel(model)
    ### Distributed adaptation: wrap the model with DistributedDataParallel ###

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_func = torch.nn.CrossEntropyLoss()

    os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(1, max_epoch + 1):
        model.train()
        train_loss = 0

        ### Distributed adaptation: seed the DDP sampler with the current epoch to avoid loading duplicate data ###
        tr_sampler.set_epoch(epoch)
        ### Distributed adaptation: seed the DDP sampler with the current epoch to avoid loading duplicate data ###

        for step, (tr_x, tr_y) in enumerate(tr_loader):
            if args.enable_gpu:
                tr_x, tr_y = tr_x.cuda(), tr_y.cuda()
            out = model(tr_x)
            loss = loss_func(out, tr_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        print('train | epoch: %d | loss: %.4f' % (epoch, train_loss / len(tr_loader)))

        val_loss = 0
        pred_record = []
        real_record = []
        model.eval()
        with torch.no_grad():
            for step, (val_x, val_y) in enumerate(val_loader):
                if args.enable_gpu:
                    val_x, val_y = val_x.cuda(), val_y.cuda()
                out = model(val_x)
                pred_record += list(np.argmax(out.cpu().numpy(), axis=1))
                real_record += list(val_y.cpu().numpy())
                val_loss += loss_func(out, val_y).item()
        val_accu = accuracy_score(real_record, pred_record)
        print('val | epoch: %d | loss: %.4f | accuracy: %.4f' % (epoch, val_loss / len(val_loader), val_accu), '\n')

        if args.rank == 0:
            # save ckpt every epoch
            torch.save(model.state_dict(), os.path.join(args.output_dir, f'epoch_{epoch}.pth'))


if __name__ == '__main__':
    main()
FAQ
- How do I use a different dataset with the example code?
- To use the cifar10 dataset with the code above, download and extract the dataset, then upload it to the OBS bucket with the following directory structure:
DDP
|--- main.py
|--- input_dir
|------ cifar-10-batches-py
|-------- data_batch_1
|-------- data_batch_2
|-------- ...
Here, "DDP" is the code directory specified when creating the training job, "main.py" is the code example above (that is, the boot file of the training job), and "cifar-10-batches-py" is the extracted dataset folder (placed under the input_dir folder).
- To use custom random data instead, change the "custom_data" parameter in the code example to "true", as follows:
parser.add_argument('--custom_data', default='true')
Then run the code example "main.py" directly; the other parameters for creating the training job stay the same as described above.
- Why does DDP not require the IP address of the master node?
The value of the init_method parameter defined by "parser.add_argument('--init_method', default=None, help='tcp_port')" already contains the IP address and port of the master node and is injected automatically by the platform, so users do not need to provide them.
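For reference, an init_method value has the form "tcp://<master_ip>:<port>". A minimal sketch of initializing the process group with such a value outside the platform (the address, port, world_size and rank below are illustrative only):

import torch.distributed as dist

# Illustrative values; in a ModelArts training job they arrive via --init_method, --rank and --world_size
dist.init_process_group(backend="nccl",
                        init_method="tcp://192.168.0.1:29500",
                        world_size=2,
                        rank=0)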