温馨提示×

Ubuntu中PyTorch内存不足如何解决

小樊
34
2025-12-20 17:58:26
栏目: 智能运维

Ubuntu下PyTorch内存不足的排查与优化

一、先快速定位是GPU显存还是系统内存

  • 查看GPU显存与进程:运行nvidia-smi,关注显存占用与进程PID;必要时用sudo kill -9 PID结束“僵尸”进程,释放显存。
  • 查看系统内存:运行free -hhtop,确认是否出现系统内存(RAM)吃满。
  • 实时监控:使用watch -n 1 nvidia-smi观察显存波动,定位峰值与泄漏趋势。
    以上命令能快速判断问题发生在GPU还是CPU/RAM侧,从而决定后续优化路径。

二、GPU显存不足时的高效做法

  • 降低Batch Size,必要时配合梯度累积(Gradient Accumulation)保持有效批量与收敛稳定性。
  • 开启混合精度训练 AMP:用torch.cuda.amp.autocastGradScaler降低显存占用并提速。
  • 验证/测试阶段使用with torch.no_grad(),避免保存中间激活与梯度。
  • 训练循环中及时清理:在合适位置执行del outputs, loss, inputs, labels; gc.collect(); torch.cuda.empty_cache(),释放无用张量与缓存。
  • 优化数据加载:合理设置num_workerspin_memory,避免数据预处理成为瓶颈或占用过多内存。
  • 若仍不足,考虑更轻量的模型架构升级显存更大的GPU
    这些手段在Ubuntu+PyTorch场景下成熟可靠,能覆盖大多数显存不足问题。

三、系统内存不足时的处理

  • 排查与释放:关闭无关进程,重启训练脚本/机器以缓解内存碎片;在Python端用**del + gc.collect()**及时回收。
  • 降低数据加载的内存压力:将DataLoader的num_workers调小或设为0pin_memory=False;必要时从源端降低预处理分辨率/精度。
  • 处理超大数据集:使用numpy.memmap按需读取,避免一次性将整个数据集载入内存。
  • 临时兜底:增加Swap交换空间(示例:创建8GB交换文件并启用),缓解OOM但会牺牲速度。
  • 长期方案:增加**物理内存(RAM)**或迁移到内存更充足的机器/云端实例。
    以上措施可有效应对“DefaultCPUAllocator: not enough memory”等系统内存不足报错。

四、可直接套用的代码片段

  • 混合精度训练模板(AMP)
import torch, gc
from torch.cuda.amp import autocast, GradScaler

model.train()
optimizer.zero_grad()
scaler = GradScaler()

for data, target in train_loader:
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # 释放本轮临时张量
    del output, loss, data, target
    gc.collect()
    torch.cuda.empty_cache()
  • 动态批大小与系统内存监控
import torch, psutil, numpy as np

def find_max_batch(dataset, start=4, max_attempts=5):
    bs = start
    for _ in range(max_attempts):
        try:
            _ = torch.stack([dataset[i][0] for i in range(bs)])  # 仅测试张量分配
            bs *= 2
        except RuntimeError:
            bs = max(1, bs // 2)
    return bs

def monitor(interval=1):
    while True:
        m = psutil.virtual_memory()
        print(f"Mem: {m.total/1e9:.1f}GB used: {m.percent}% avail: {m.available/1e9:.1f}GB")
        time.sleep(interval)

上述模板覆盖了AMP、显存清理与动态批大小,配合系统监控更易定位瓶颈。

五、实用命令清单

  • GPU监控与进程清理:nvidia-smiwatch -n 1 nvidia-smisudo kill -9 PID
  • 系统内存查看:free -hhtop
  • 可选的系统级缓存清理(仅在必要时、谨慎使用):sync && echo 3 > /proc/sys/vm/drop_caches
    这些命令覆盖了日常排查、清理与监控的高频操作。

0