pdb是Python标准库中的交互式调试工具,适合快速定位代码问题。在PyTorch代码中插入import pdb; pdb.set_trace(),程序执行到该行时会暂停,进入交互式调试模式。常用命令包括:
n(next):执行下一行代码;s(step):进入函数内部;c(continue):继续执行至下一个断点;p <变量名>:打印变量值;q(quit):退出调试模式。PyCharm、VSCode等IDE提供直观的调试界面,无需手动插入断点命令。以PyCharm为例:
TorchSnooper是专为PyTorch设计的调试工具,可自动输出函数中每行代码的张量形状、数据类型、设备(CPU/GPU)、是否需要梯度等信息,无需手动添加打印语句。
pip install torchsnooper;@torchsnooper.snoop()装饰器,运行脚本后会自动打印详细日志。例如:import torch
import torchsnooper
@torchsnooper.snoop()
def myfunc(mask, x):
y = torch.zeros(6)
y.masked_scatter_(mask, x)
return y
日志会显示y的形状、类型等信息,帮助快速定位张量维度不匹配等问题。PyTorch Profiler可分析模型的计算时间、内存占用、GPU利用率等性能指标,支持生成可视化报告(如TensorBoard)。
with torch.profiler.profile(
on_trace_ready=torch.profiler.tensorboard_trace_handler("trace_pt"),
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
on_trace_ready=lambda prof: prof.export_chrome_trace("trace.json")
) as prof:
for step, data in enumerate(trainloader):
inputs, labels = data[0].to(device), data[1].to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
prof.step()
tensorboard --logdir=trace_pt启动TensorBoard,在“Profile”标签页查看性能分析结果,识别耗时操作(如矩阵乘法、梯度计算)。assert语句用于验证代码中的关键条件(如张量维度、数值范围),条件不满足时抛出AssertionError,帮助快速定位逻辑错误。例如:
assert x.shape == (batch_size, input_dim), f"Expected shape {(batch_size, input_dim)}, got {x.shape}"
assert torch.allclose(loss, expected_loss, atol=1e-6), "Loss value is incorrect"
assert语句应放在可能出现问题的代码段(如数据预处理、模型输出后),避免影响正常运行。
使用Python的logging模块记录程序运行时的变量值、执行流程,比print语句更灵活(可设置日志级别、输出到文件)。例如:
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logging.debug(f"Input tensor: {x}, Shape: {x.shape}, Device: {x.device}")
日志级别说明:DEBUG(详细信息)、INFO(一般信息)、WARNING(警告)、ERROR(错误)、CRITICAL(严重错误)。
PyTorch的torch.autograd.set_detect_anomaly(True)可检测梯度计算中的异常(如NaN、无穷大),帮助定位梯度爆炸或消失问题。使用时需注意:
try-except捕获异常,定位具体出错位置。例如:torch.autograd.set_detect_anomaly(True)
try:
loss.backward()
except RuntimeError as e:
print(f"Gradient anomaly detected: {e}")
ipdb是pdb的增强版,支持语法高亮、代码补全,提升调试体验。使用方法与pdb类似:
pip install ipdb;import ipdb; ipdb.set_trace();n、s、c等)。以上方法可根据调试需求组合使用(如用TorchSnooper查看张量信息+Profiler分析性能+assert检查逻辑),提高调试效率。