温馨提示×

如何在Linux中使用PyTorch进行分布式训练

小樊
45
2025-12-08 15:05:47
栏目: 智能运维

在Linux中使用PyTorch进行分布式训练,可以遵循以下步骤:

1. 环境准备

  • 安装PyTorch:确保你已经安装了PyTorch。可以使用pip或conda进行安装。
    pip install torch torchvision
    
  • 安装NCCL:NCCL(NVIDIA Collective Communications Library)是用于多GPU和多节点通信的库。
    conda install -c nvidia nccl
    

2. 配置分布式环境

  • 设置环境变量
    export MASTER_ADDR='<master_ip>'
    export MASTER_PORT='<master_port>'
    export WORLD_SIZE='<number_of_gpus_or_nodes>'
    export RANK='<rank_id>'
    
    • MASTER_ADDR:主节点的IP地址。
    • MASTER_PORT:主节点的端口号。
    • WORLD_SIZE:总的GPU或节点数。
    • RANK:当前节点的排名。

3. 编写分布式训练代码

使用PyTorch的torch.distributed模块来编写分布式训练代码。以下是一个简单的示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

def main(rank, world_size):
    # 初始化分布式环境
    dist.init_process_group(backend='nccl', init_method=f'tcp://{master_addr}:{master_port}', world_size=world_size, rank=rank)

    # 创建模型并移动到GPU
    model = nn.Linear(10, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # 创建数据加载器
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

    # 创建优化器
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

    # 训练模型
    for epoch in range(5):
        sampler.set_epoch(epoch)
        for data, target in dataloader:
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = nn.functional.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch}, Loss: {loss.item()}')

    # 清理分布式环境
    dist.destroy_process_group()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=2)
    parser.add_argument('--rank', type=int, default=0)
    args = parser.parse_args()

    main(args.rank, args.world_size)

4. 运行分布式训练

使用torch.distributed.launchmpirun来启动分布式训练。

使用torch.distributed.launch

python -m torch.distributed.launch --nproc_per_node=<number_of_gpus> --nnodes=<number_of_nodes> --node_rank=<node_rank> --master_addr=<master_ip> --master_port=<master_port> your_script.py

使用mpirun

mpirun -np <total_number_of_gpus> -host <master_ip> python your_script.py --world_size <number_of_gpus> --rank <rank>

5. 注意事项

  • 网络配置:确保所有节点之间的网络连接正常。
  • 防火墙设置:关闭防火墙或配置防火墙以允许分布式训练所需的端口通信。
  • 资源分配:确保每个节点有足够的GPU和内存资源。

通过以上步骤,你可以在Linux环境中使用PyTorch进行分布式训练。

0