Ubuntu下PyTorch内存不足的排查与优化
一、先快速定位是GPU显存还是系统内存
二、GPU显存不足时的高效做法
三、系统内存不足时的处理
四、可直接套用的代码片段
import torch, gc
from torch.cuda.amp import autocast, GradScaler
model.train()
optimizer.zero_grad()
scaler = GradScaler()
for data, target in train_loader:
with autocast():
output = model(data)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# 释放本轮临时张量
del output, loss, data, target
gc.collect()
torch.cuda.empty_cache()
import torch, psutil, numpy as np
def find_max_batch(dataset, start=4, max_attempts=5):
bs = start
for _ in range(max_attempts):
try:
_ = torch.stack([dataset[i][0] for i in range(bs)]) # 仅测试张量分配
bs *= 2
except RuntimeError:
bs = max(1, bs // 2)
return bs
def monitor(interval=1):
while True:
m = psutil.virtual_memory()
print(f"Mem: {m.total/1e9:.1f}GB used: {m.percent}% avail: {m.available/1e9:.1f}GB")
time.sleep(interval)
上述模板覆盖了AMP、显存清理与动态批大小,配合系统监控更易定位瓶颈。
五、实用命令清单