更新时间:2024-05-06 GMT+08:00
分布式训练完整代码示例
以下对resnet18在cifar10数据集上的分类任务,给出了分布式训练改造(DDP)的完整代码示例。
训练启动文件main.py内容如下(如果需要执行单机单卡训练任务,则将分布式改造的代码删除):
"""Distributed (DDP) training of ResNet-18 on the CIFAR-10 classification task.

The sections fenced by ``### Distributed adaptation ... ###`` comment pairs are
the only DDP-specific parts of the script; deleting them leaves a
single-node/single-device training script.
"""
import datetime
import inspect
import os
import pickle
import random
import logging
import argparse

import numpy as np
import torch
from torch import nn, optim
import torch.distributed as dist
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

try:
    from sklearn.metrics import accuracy_score
except ImportError:
    # scikit-learn is only needed for this one metric; fall back to NumPy so
    # the script still runs in images that do not ship scikit-learn.
    def accuracy_score(y_true, y_pred):
        """Return the fraction of positions where y_true equals y_pred."""
        return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

# Directory containing this script; used for the default data/output paths.
file_dir = os.path.dirname(inspect.getframeinfo(inspect.currentframe()).filename)


def load_pickle_data(path):
    """Load and return the object pickled at *path* (keys decoded as bytes).

    NOTE(security): ``pickle.load`` can execute arbitrary code. Only use this
    on dataset files obtained from a trusted source.
    """
    with open(path, 'rb') as file:
        data = pickle.load(file, encoding='bytes')
    return data


def _load_data(file_path):
    """Read one CIFAR batch file.

    Returns:
        (data, labels, filenames) where data is reshaped to (N, 3, 32, 32)
        and divided by 255.
    """
    raw_data = load_pickle_data(file_path)
    labels = raw_data[b'labels']
    data = raw_data[b'data']
    filenames = raw_data[b'filenames']
    # Infer the number of images instead of hard-coding 10000 so that batch
    # files of any size can be loaded.
    data = data.reshape(-1, 3, 32, 32) / 255
    return data, labels, filenames


def load_cifar_data(root_path):
    """Load the five training batches and the test batch below *root_path*.

    Expects ``<root_path>/cifar-10-batches-py/data_batch_{1..5}`` and
    ``<root_path>/cifar-10-batches-py/test_batch``.

    Returns:
        (tr_data, tr_labels, val_data, val_labels) as float / long tensors.
    """
    train_root_path = os.path.join(root_path, 'cifar-10-batches-py/data_batch_')
    train_data_record = []
    train_labels = []
    train_filenames = []
    for i in range(1, 6):
        train_file_path = train_root_path + str(i)
        data, labels, filenames = _load_data(train_file_path)
        train_data_record.append(data)
        train_labels += labels
        train_filenames += filenames
    train_data = np.concatenate(train_data_record, axis=0)
    train_labels = np.array(train_labels)

    val_file_path = os.path.join(root_path, 'cifar-10-batches-py/test_batch')
    val_data, val_labels, val_filenames = _load_data(val_file_path)
    val_labels = np.array(val_labels)

    tr_data = torch.from_numpy(train_data).float()
    tr_labels = torch.from_numpy(train_labels).long()
    val_data = torch.from_numpy(val_data).float()
    val_labels = torch.from_numpy(val_labels).long()
    return tr_data, tr_labels, val_data, val_labels


def get_data(root_path, custom_data=False):
    """Return ``(train_set, val_set)`` datasets.

    Source priority:
      1. ``custom_data=True``: random images with balanced binary labels.
      2. A local ``cifar-10-batches-py`` folder under *root_path*.
      3. Download CIFAR-10 through torchvision into ``./data``.

    Raises:
        Exception: if the torchvision download path fails; the message tells
            the user where to fetch the dataset manually.
    """
    if custom_data:
        train_samples, test_samples, img_size = 5000, 1000, 32
        tr_label = [1] * int(train_samples / 2) + [0] * int(train_samples / 2)
        val_label = [1] * int(test_samples / 2) + [0] * int(test_samples / 2)
        random.seed(2021)
        random.shuffle(tr_label)
        random.shuffle(val_label)
        tr_data = torch.randn((train_samples, 3, img_size, img_size)).float()
        tr_labels = torch.tensor(tr_label).long()
        val_data = torch.randn((test_samples, 3, img_size, img_size)).float()
        val_labels = torch.tensor(val_label).long()
        tr_set = TensorDataset(tr_data, tr_labels)
        val_set = TensorDataset(val_data, val_labels)
        return tr_set, val_set
    elif os.path.exists(os.path.join(root_path, 'cifar-10-batches-py')):
        tr_data, tr_labels, val_data, val_labels = load_cifar_data(root_path)
        tr_set = TensorDataset(tr_data, tr_labels)
        val_set = TensorDataset(val_data, val_labels)
        return tr_set, val_set
    else:
        try:
            import torchvision
            from torchvision import transforms
            # `transform` must be a callable; passing the `transforms` module
            # itself fails as soon as a sample is read. ToTensor converts the
            # PIL image to a float tensor in [0, 1].
            to_tensor = transforms.ToTensor()
            tr_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                  download=True, transform=to_tensor)
            val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                                   download=True, transform=to_tensor)
            return tr_set, val_set
        except Exception as e:
            raise Exception(
                f"{e}, you can download and unzip cifar-10 dataset manually, "
                "the data url is http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz") from e


class Block(nn.Module):
    """Residual block: two 3x3 conv + BN layers plus a shortcut connection."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        # Identity shortcut unless the shape changes, in which case a 1x1
        # conv projects the input to the output shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        # Created once here rather than on every forward call; ReLU has no
        # parameters, so state_dict keys are unaffected.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.residual_function(x) + self.shortcut(x)
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet-18 style network: a 3x3 stem and four stages of two blocks."""

    def __init__(self, block, num_classes=10):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.conv2 = self.make_layer(block, 64, 64, 2, 1)
        self.conv3 = self.make_layer(block, 64, 128, 2, 2)
        self.conv4 = self.make_layer(block, 128, 256, 2, 2)
        self.conv5 = self.make_layer(block, 256, 512, 2, 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dense_layer = nn.Linear(512, num_classes)

    def make_layer(self, block, in_channels, out_channels, num_blocks, stride):
        """Stack *num_blocks* blocks; only the first one uses *stride*."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(in_channels, out_channels, block_stride))
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.dense_layer(out)
        return out


def setup_seed(seed):
    """Seed torch (CPU and CUDA), NumPy and `random`; make cuDNN deterministic."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def obs_transfer(src_path, dst_path):
    """Copy *src_path* to *dst_path* with moxing's parallel copy."""
    # Imported lazily: moxing is only needed when this helper is called.
    import moxing as mox
    mox.file.copy_parallel(src_path, dst_path)
    logging.info("end copy data from %s to %s", src_path, dst_path)


def main():
    """Parse the command line, then train and validate for the given epochs."""
    # The current year is used as the seed, so runs within a year share it.
    seed = datetime.datetime.now().year
    setup_seed(seed)

    parser = argparse.ArgumentParser(description='Pytorch distribute training',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--enable_gpu', default='true')
    parser.add_argument('--lr', default='0.01', help='learning rate')
    parser.add_argument('--epochs', default='100', help='training iteration')

    parser.add_argument('--init_method', default=None, help='tcp_port')
    parser.add_argument('--rank', type=int, default=0, help='index of current task')
    parser.add_argument('--world_size', type=int, default=1, help='total number of tasks')

    parser.add_argument('--custom_data', default='false')
    parser.add_argument('--data_url', type=str, default=os.path.join(file_dir, 'input_dir'))
    parser.add_argument('--output_dir', type=str, default=os.path.join(file_dir, 'output_dir'))
    # Unknown arguments are tolerated rather than treated as errors.
    args, unknown = parser.parse_known_args()

    args.enable_gpu = args.enable_gpu == 'true'
    args.custom_data = args.custom_data == 'true'
    args.lr = float(args.lr)
    args.epochs = int(args.epochs)

    if args.custom_data:
        logging.warning('you are training on custom random dataset, '
                        'validation accuracy may range from 0.4 to 0.6.')

    ### Distributed adaptation: initialize the DDP process group; init_method, rank and world_size are passed in automatically by the platform ###
    # NCCL only works with CUDA tensors, so fall back to gloo on CPU.
    backend = "nccl" if args.enable_gpu else "gloo"
    dist.init_process_group(init_method=args.init_method, backend=backend,
                            world_size=args.world_size, rank=args.rank)
    ### Distributed adaptation: initialize the DDP process group; init_method, rank and world_size are passed in automatically by the platform ###

    tr_set, val_set = get_data(args.data_url, custom_data=args.custom_data)

    batch_per_gpu = 128
    gpus_per_node = torch.cuda.device_count() if args.enable_gpu else 1
    batch = batch_per_gpu * gpus_per_node

    # Single-device loader. It is replaced by the sampler-based loader below
    # and is what remains once the distributed sections are deleted.
    tr_loader = DataLoader(tr_set, batch_size=batch, shuffle=False)

    ### Distributed adaptation: build the DDP distributed sampler so that each process loads different data ###
    tr_sampler = DistributedSampler(tr_set, num_replicas=args.world_size, rank=args.rank)
    tr_loader = DataLoader(tr_set, batch_size=batch, sampler=tr_sampler, shuffle=False, drop_last=True)
    ### Distributed adaptation: build the DDP distributed sampler so that each process loads different data ###

    val_loader = DataLoader(val_set, batch_size=batch, shuffle=False)

    lr = args.lr * gpus_per_node * args.world_size
    max_epoch = args.epochs
    model = ResNet(Block).cuda() if args.enable_gpu else ResNet(Block)

    ### Distributed adaptation: build the DDP distributed model ###
    model = nn.parallel.DistributedDataParallel(model)
    ### Distributed adaptation: build the DDP distributed model ###

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_func = torch.nn.CrossEntropyLoss()

    os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(1, max_epoch + 1):
        model.train()
        train_loss = 0

        ### Distributed adaptation: DDP sampler, seed it with the current epoch to avoid loading duplicate data ###
        tr_sampler.set_epoch(epoch)
        ### Distributed adaptation: DDP sampler, seed it with the current epoch to avoid loading duplicate data ###

        for step, (tr_x, tr_y) in enumerate(tr_loader):
            if args.enable_gpu:
                tr_x, tr_y = tr_x.cuda(), tr_y.cuda()
            out = model(tr_x)
            loss = loss_func(out, tr_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        print('train | epoch: %d | loss: %.4f' % (epoch, train_loss / len(tr_loader)))

        val_loss = 0
        pred_record = []
        real_record = []
        model.eval()
        with torch.no_grad():
            for step, (val_x, val_y) in enumerate(val_loader):
                if args.enable_gpu:
                    val_x, val_y = val_x.cuda(), val_y.cuda()
                out = model(val_x)
                pred_record += list(np.argmax(out.cpu().numpy(), axis=1))
                real_record += list(val_y.cpu().numpy())
                val_loss += loss_func(out, val_y).item()
        val_accu = accuracy_score(real_record, pred_record)
        print('val | epoch: %d | loss: %.4f | accuracy: %.4f' % (epoch, val_loss / len(val_loader), val_accu), '\n')

        if args.rank == 0:
            # save ckpt every epoch
            # NOTE(review): while the model is wrapped in DDP the saved keys
            # presumably carry a "module." prefix - confirm before loading
            # the checkpoint into an unwrapped ResNet.
            torch.save(model.state_dict(), os.path.join(args.output_dir, f'epoch_{epoch}.pth'))

    ### Distributed adaptation: release the DDP process group once training has finished ###
    dist.destroy_process_group()
    ### Distributed adaptation: release the DDP process group once training has finished ###


if __name__ == '__main__':
    main()
常见问题
1、示例代码中如何使用不同的数据集?
- 上述代码如果使用cifar10数据集,则将数据集下载并解压后,上传至OBS桶中,文件目录结构如下:
DDP
|--- main.py
|--- input_dir
|------ cifar-10-batches-py
|-------- data_batch_1
|-------- data_batch_2
|-------- ...
其中“DDP”为创建训练作业时的“代码目录”,“main.py”为上文代码示例(即创建训练作业时的“启动文件”),“cifar-10-batches-py”为解压后的数据集文件夹(放在input_dir文件夹下)。
- 如果使用自定义的随机数据,则将代码示例中的参数“custom_data”改为“true”,修改后内容如下:
parser.add_argument('--custom_data', default='true')
然后直接运行代码示例“main.py”即可,创建训练作业的参数与上图相同。
2、为什么DDP可以不输入主节点ip?
“parser.add_argument('--init_method', default=None, help='tcp_port')”中的init_method参数值已包含主节点的ip和端口,并由平台自动入参,因此不需要用户手动输入主节点的ip和端口。
父主题: 分布式训练