温馨提示×

如何在Linux上实现PyTorch分布式训练

小樊
53
2025-10-14 05:27:07
栏目: 智能运维

在Linux上实现PyTorch分布式训练,可以按照以下步骤进行:

1. 环境准备

确保你已经安装了Python和PyTorch。你可以使用Anaconda或直接通过pip安装。

# 安装Anaconda(如果还没有安装)
wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
bash Anaconda3-2023.07-2-Linux-x86_64.sh

# 创建一个新的conda环境
conda create -n pytorch_dist python=3.9
conda activate pytorch_dist

# 安装PyTorch
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

2. 启动分布式训练

PyTorch提供了多种启动分布式训练的方法,这里我们使用torch.distributed.launch工具。

示例代码

假设你有一个简单的训练脚本train.py,内容如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        return self.fc(x)

def train(rank, world_size):
    # 初始化进程组
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)

    # 创建模型、损失函数和优化器
    model = SimpleNet().to(rank)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # 加载数据集
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

    # 训练模型
    for epoch in range(5):
        sampler.set_epoch(epoch)
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data[0].to(rank), data[1].to(rank)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Rank {rank}, Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}')

    # 销毁进程组
    torch.distributed.destroy_process_group()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--world-size', type=int, default=4, help='number of distributed processes')
    parser.add_argument('--rank', type=int, default=0, help='rank of the process')
    args = parser.parse_args()

    train(args.rank, args.world_size)

启动分布式训练

使用torch.distributed.launch启动分布式训练:

python -m torch.distributed.launch --nproc_per_node=4 train.py --world-size 4 --rank 0

这里的--nproc_per_node参数指定了每个节点上的GPU数量,--world-size是总的进程数,--rank是当前进程的排名。

3. 验证分布式训练

确保所有进程都能正常运行并且输出结果一致。你可以通过查看日志文件或直接在终端中观察输出来验证。

4. 注意事项

  • 网络配置:确保所有节点之间可以互相通信。
  • GPU配置:确保每个节点都有可用的GPU,并且CUDA驱动和库已正确安装。
  • 环境变量:确保所有节点上的环境变量设置一致,特别是NCCL_DEBUGHOROVOD_TIMELINE等调试相关的环境变量。

通过以上步骤,你应该能够在Linux上成功实现PyTorch的分布式训练。

0