PyTorch是一个开源的深度学习框架,它提供了在Linux上进行并行计算的能力。以下是在Linux上使用PyTorch进行并行计算的几种方法:
数据并行是指将数据分割成多个小批次,然后在多个GPU上并行处理这些小批次。
import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
model = nn.Sequential(
nn.Linear(10, 50),
nn.ReLU(),
nn.Linear(50, 10)
)
# 将模型复制到多个GPU
model = nn.DataParallel(model)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 假设我们有一些数据
inputs = torch.randn(64, 10) # 输入数据
labels = torch.randint(0, 10, (64,)) # 标签
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
模型并行是指将模型的不同部分放在不同的GPU上。
import torch
import torch.nn as nn
class ModelParallelModel(nn.Module):
def __init__(self):
super(ModelParallelModel, self).__init__()
self.part1 = nn.Linear(10, 50).to('cuda:0')
self.part2 = nn.Linear(50, 10).to('cuda:1')
def forward(self, x):
x = x.to('cuda:0')
x = self.part1(x)
x = x.to('cuda:1')
x = self.part2(x)
return x
model = ModelParallelModel()
分布式并行是指在多个节点上进行并行计算,每个节点可以有多个GPU。
torch.distributed包import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# 初始化分布式环境
dist.init_process_group(backend='nccl')
# 定义模型
model = nn.Sequential(
nn.Linear(10, 50),
nn.ReLU(),
nn.Linear(50, 10)
)
# 将模型放到正确的设备上
model.to(torch.device(f'cuda:{dist.get_rank()}'))
# 包装模型为DDP模型
model = DDP(model)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 假设我们有一些数据
inputs = torch.randn(64, 10).to(torch.device(f'cuda:{dist.get_rank()}'))
labels = torch.randint(0, 10, (64,)).to(torch.device(f'cuda:{dist.get_rank()}'))
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
你可以使用torch.distributed.launch或accelerate库来启动分布式训练。
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE YOUR_TRAINING_SCRIPT.py
或者使用accelerate库:
accelerate launch YOUR_TRAINING_SCRIPT.py
torch.utils.data.DataLoader并结合torch.utils.data.distributed.DistributedSampler来确保每个进程处理不同的数据子集。通过这些方法,你可以在Linux上高效地进行PyTorch的并行计算。