A Practical Guide to PyTorch Memory Optimization on Ubuntu
1. Basic GPU Memory Optimization
2. Advanced Memory-Saving Techniques
3. Monitoring and Diagnostic Tools
4. OOM Emergency Response and Troubleshooting
5. A Minimal, Ready-to-Use Optimization Template
# --- Training loop: mixed precision (AMP) + gradient accumulation ---
import torch, torch.nn as nn, torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

device = torch.device('cuda')
model = nn.Linear(1024, 1024).to(device)        # stand-in model; replace with your own
optimizer = optim.Adam(model.parameters())
scaler = GradScaler()                           # scales the loss so fp16 gradients do not underflow
accumulation_steps = 4                          # effective batch size = batch_size * accumulation_steps

# criterion and dataloader are assumed to be defined elsewhere in your script
for i, (x, y) in enumerate(dataloader):
    x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
    with autocast():                            # run the forward pass in mixed precision
        loss = criterion(model(x), y) / accumulation_steps
    scaler.scale(loss).backward()               # gradients accumulate across the small batches
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)   # set_to_none frees gradient memory instead of zeroing it

# Optional: clean up once this phase is done
# del x, y, loss
# torch.cuda.empty_cache()
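The template itself does not show activation (gradient) checkpointing, the kind of technique the "Advanced Memory-Saving Techniques" heading points to: checkpointed blocks do not store their intermediate activations and recompute them during the backward pass instead, trading extra compute for lower memory. Below is a minimal sketch using torch.utils.checkpoint; the SmallNet module and its layer sizes are illustrative only (not part of the original template), it reuses the torch.nn import from the code above, and use_reentrant=False needs a reasonably recent PyTorch (drop the argument on older versions).

from torch.utils.checkpoint import checkpoint

class SmallNet(nn.Module):                       # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
        self.head = nn.Linear(1024, 10)

    def forward(self, x):
        # Checkpointed blocks do not keep their activations in the forward pass;
        # they are recomputed during backward, saving activation memory.
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return self.head(x)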
# --- Inference: disable autograd so no computation graph is kept ---
model.eval()
with torch.no_grad():
    for x in dataloader:                        # an inference loader yielding only inputs is assumed here
        x = x.to(device, non_blocking=True)
        out = model(x)

# Optional: clean up
# del x, out
# torch.cuda.empty_cache()
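For the OOM emergency workflow named in the outline (section 4), a common first-aid pattern is to catch the out-of-memory error, release cached blocks, and retry on a smaller batch before giving up. A minimal sketch follows: torch.cuda.OutOfMemoryError needs PyTorch 1.13 or newer (on older versions catch RuntimeError and check for "out of memory" in its message), and run_step is a placeholder for your own forward/backward call on a tensor batch.

def try_step(batch, run_step, max_retries=1):
    # run_step(batch) is assumed to perform one forward/backward pass and return the loss
    for attempt in range(max_retries + 1):
        try:
            return run_step(batch)
        except torch.cuda.OutOfMemoryError:               # PyTorch >= 1.13
            torch.cuda.empty_cache()                      # hand cached blocks back to the driver
            if attempt == max_retries:
                raise                                     # give up and surface the original error
            batch = batch[: max(1, batch.shape[0] // 2)]  # halve the batch and retry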
def log_mem(msg=""):
a = torch.cuda.memory_allocated() / 1024**2
r = torch.cuda.memory_reserved() / 1024**2
print(f"{msg} Allocated={a:.1f}MB Reserved={r:.1f}MB")
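Beyond these instantaneous allocated/reserved numbers, the peak usage since the last reset and the free memory left on the device are usually what you need when sizing batches. A small extension using the same torch.cuda APIs (torch.cuda.mem_get_info needs a reasonably recent PyTorch); this is a sketch, not part of the original template:

def log_peak_and_free(msg=""):
    peak = torch.cuda.max_memory_allocated() / 1024**2    # high-water mark since the last reset
    free, total = torch.cuda.mem_get_info()                # free/total bytes reported by the driver
    print(f"{msg} Peak={peak:.1f}MB Free={free / 1024**2:.1f}MB / Total={total / 1024**2:.1f}MB")
    torch.cuda.reset_peak_memory_stats()                   # start a fresh measurement window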
# --- Data loading (dataset and bs are assumed to exist in your script) ---
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=bs, num_workers=4,
                    pin_memory=True, prefetch_factor=2)   # pinned host memory enables non_blocking copies
# Allocator tuning: set in the shell before launching the training process
export PYTORCH_CUDA_ALLOC_CONF="garbage_collection_threshold:0.8,max_split_size_mb:128"
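The same allocator setting can also be applied from inside the script, as long as it is in place before the first CUDA allocation; the safest ordering is to set it before importing torch, as in this sketch:

import os
# Must be set before the first CUDA allocation; setting it before importing torch is safest.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "garbage_collection_threshold:0.8,max_split_size_mb:128"
import torch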
The template above covers the key points: mixed precision, gradient accumulation, no_grad, cache cleanup, and memory monitoring. It can be dropped into an existing training script as-is, with the parameters tuned to fit the GPU memory you have available.