Ubuntu下PyTorch内存不足的排查与优化
一 快速定位问题
二 训练阶段的高效优化
三 推理阶段与常见OOM场景
四 系统与CUDA层面的调优
五 实用代码片段
import torch
from torch.cuda.amp import autocast, GradScaler

# Mixed-precision training loop: autocast runs the forward pass in half
# precision where it is numerically safe, and GradScaler scales the loss
# so that small gradients do not underflow in float16.
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for data, target in dataloader:
    # Fix: the model lives on the GPU, so the batch must be moved to the
    # same device — otherwise model(data) raises a device-mismatch error.
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    # Scale the loss before backward; step() unscales and skips the
    # update if inf/nan gradients were produced, update() adjusts the scale.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
def find_max_batch(model, input_shape, max_mem=8 * 1024 ** 3):
    """Probe for the largest power-of-two batch size that fits in GPU memory.

    Doubles the batch size until either peak allocated memory exceeds 90%
    of ``max_mem`` or the forward pass raises a CUDA out-of-memory error.

    Args:
        model: module already placed on the GPU.
        input_shape: shape of a full example batch; dimensions after the
            first are taken as the per-sample shape.
        max_mem: memory budget in bytes (default 8 GiB).

    Returns:
        int: largest batch size considered safe (always >= 1).
    """
    batch_size = 1
    # Per-sample shape; the leading dimension is replaced by batch_size.
    sample_shape = tuple(input_shape[1:])
    while True:
        try:
            # Reset the peak counter so each probe measures only itself,
            # not the high-water mark of earlier (smaller) probes.
            torch.cuda.reset_peak_memory_stats()
            # Fix: allocate exactly batch_size samples. The original
            # torch.randn(*input_shape)[:batch_size] always allocated the
            # full input_shape tensor and could never probe batches larger
            # than input_shape[0].
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
                _ = model(torch.randn((batch_size, *sample_shape), device='cuda'))
            used = torch.cuda.max_memory_allocated()
            if used > 0.9 * max_mem:
                return max(1, batch_size - 1)
            batch_size *= 2
        except RuntimeError:
            # Presumably CUDA OOM: release cached blocks from the failed
            # allocation, then back off to the last successful size.
            torch.cuda.empty_cache()
            return max(1, batch_size // 2)
import gc

import torch


def clear_cache():
    """Best-effort release of unused GPU memory back to the driver.

    Runs the Python garbage collector *before* emptying the CUDA cache:
    tensors kept alive only by collectable reference cycles must be freed
    first, otherwise ``empty_cache()`` cannot return their blocks.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for all streams to finish
    # Fix: GC first, so dead tensors are actually released before we ask
    # the caching allocator to hand its free blocks back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()