A Hands-On Guide to PyTorch Memory Management on Ubuntu
1. Core Mechanisms and Common Bottlenecks
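PyTorch does not return freed GPU memory to the driver right away. A caching allocator keeps the blocks for reuse, which is why nvidia-smi reports the reserved pool rather than the memory held by live tensors. The usual bottlenecks are tensors kept alive by lingering Python references (for example, accumulating a loss tensor instead of loss.item()), allocator fragmentation, and batch sizes chosen with no headroom. A minimal sketch for telling the two counters apart, assuming a CUDA-capable machine:

    import torch

    x = torch.randn(1024, 1024, device="cuda")  # ~4 MB of live tensor memory
    print(torch.cuda.memory_allocated())  # bytes held by live tensors
    print(torch.cuda.memory_reserved())   # bytes held by the caching allocator (what nvidia-smi reflects)
    del x
    print(torch.cuda.memory_allocated())  # drops: the tensor is gone
    print(torch.cuda.memory_reserved())   # unchanged: the block stays in PyTorch's cache

torch.cuda.memory_summary() prints a richer, human-readable report of the same counters.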
2. Code-Level GPU Memory Cleanup and Best Practices
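A typical training step that drops references to intermediate tensors as soon as the step finishes: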
    import gc

    import torch

    def train_step(model, data, optimizer, criterion):
        inputs, targets = data
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # frees .grad tensors instead of filling them with zeros
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()  # capture the scalar before dropping the graph reference
        del inputs, targets, outputs, loss  # let the allocator reuse these blocks
        gc.collect()
        torch.cuda.empty_cache()  # optional: call periodically, not every step
        return loss_val
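Note that torch.cuda.empty_cache() only returns cached, unreferenced blocks to the driver; it cannot free tensors that are still referenced, and calling it every step forces costly re-allocation on the next forward pass. Reserve it for points where the reserved pool actually matters, such as the boundary between training and evaluation.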
3. GPU Memory Optimization Strategies During Training
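Mixed precision roughly halves activation memory, and gradient accumulation trades extra steps for a larger effective batch size; the two combine naturally: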
    scaler = torch.cuda.amp.GradScaler()  # on PyTorch >= 2.4, torch.amp.GradScaler("cuda") is the preferred spelling
    accum_steps = 4  # effective batch size = loader batch size * accum_steps

    optimizer.zero_grad(set_to_none=True)
    for i, (inputs, targets) in enumerate(loader):
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
            outputs = model(inputs)
            loss = criterion(outputs, targets) / accum_steps  # average gradients over the window

        scaler.scale(loss).backward()  # gradients accumulate across iterations

        if (i + 1) % accum_steps == 0:  # note: a trailing partial window is dropped
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
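When activations dominate peak memory, gradient checkpointing trades compute for memory by recomputing intermediate activations during the backward pass. A minimal sketch using torch.utils.checkpoint; the two-block split and layer sizes here are illustrative, not a recommendation:

    import torch
    from torch.utils.checkpoint import checkpoint

    class CheckpointedNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.block1 = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())
            self.block2 = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())
            self.head = torch.nn.Linear(512, 10)

        def forward(self, x):
            # Activations inside each checkpointed block are freed after the
            # forward pass and recomputed during backward, cutting peak memory.
            x = checkpoint(self.block1, x, use_reentrant=False)
            x = checkpoint(self.block2, x, use_reentrant=False)
            return self.head(x)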
4. System- and CUDA-Level Optimization
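At the system level, the caching allocator can be tuned through the PYTORCH_CUDA_ALLOC_CONF environment variable; for example, expandable_segments (PyTorch 2.x) reduces fragmentation, and max_split_size_mb caps how large cached blocks may be split. A sketch, assuming the variable is set before the first CUDA allocation; the value shown is a starting point, not a tuned recommendation:

    import os

    # Must happen before torch initializes CUDA, e.g., at the top of the entry
    # script, or via `export PYTORCH_CUDA_ALLOC_CONF=...` in the shell.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    import torch  # CUDA picks up the allocator config on first use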
5. Diagnostic Tools and Troubleshooting Workflow
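torch.profiler records per-operator timing and memory traffic; the schedule below skips one step, warms up for one, and records three: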
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,   # include CPU ops so operator names and shapes appear
                    torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
        on_trace_ready=lambda prof: prof.export_chrome_trace("trace.json"),
        record_shapes=True,
        profile_memory=True,  # track tensor allocations and frees
    ) as prof:
        for _ in range(5):  # exactly wait + warmup + active steps
            train_step(model, next(data_iter), optimizer, criterion)
            prof.step()  # advance the profiler schedule
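Open trace.json in chrome://tracing or https://ui.perfetto.dev to inspect the timeline. For allocation-level debugging, recent PyTorch versions (2.1+) can also record allocator history; note that these are underscore-prefixed, semi-private APIs:

    import torch

    torch.cuda.memory._record_memory_history(max_entries=100_000)  # record stack traces per allocation
    # ... run the suspect workload ...
    torch.cuda.memory._dump_snapshot("snapshot.pickle")  # inspect at https://pytorch.org/memory_viz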