Ubuntu下PyTorch调试技巧汇总
print语句,输出变量值或执行流程信息(如print(f"Input shape: {input.shape}, Loss: {loss.item()}")),快速定位异常位置。import pdb; pdb.set_trace(),程序运行到此处会暂停,可通过n(下一步)、s(进入函数)、c(继续)、p 变量名(打印变量)等命令逐步排查问题;ipdb提供语法高亮和自动补全,pdb++支持更多功能(如变量历史查看)。print的灵活日志工具,支持不同日志级别(DEBUG/INFO/WARNING/ERROR)。配置logging.basicConfig(level=logging.DEBUG)后,可通过logging.debug(f"Debug info: {variable}")记录信息,便于后续过滤和分析。launch.json文件(指定Python路径和脚本),在代码中设置断点,点击“Run and Debug”启动,通过调试控制台查看输出,变量面板实时显示变量状态。torch.autograd.set_detect_anomaly(True)开启梯度异常检测,若模型训练中出现梯度爆炸/消失(如NaN损失),会抛出详细错误信息(包括计算图路径和异常原因),帮助快速定位梯度问题。torch.autograd.profiler分析模型性能瓶颈。例如:with profiler.profile(record_shapes=True) as prof:
output = model(input_data)
print(prof.key_averages().table(sort_by="cuda_time_total"))
输出显示各操作的时间消耗(CPU/GPU)、调用次数,便于优化慢操作(如不必要的张量拷贝)。torch.utils.tensorboard记录训练指标(损失、准确率、模型图)。示例代码:from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/experiment-1')
for epoch in range(num_epochs):
writer.add_scalar('Loss/train', train_loss, epoch)
writer.add_scalar('Accuracy/train', train_acc, epoch)
writer.close()
终端运行tensorboard --logdir=runs,通过浏览器访问http://localhost:6006查看实时曲线和模型结构。assert input.shape == (batch_size, channels, height, width), "Input shape mismatch"
assert torch.all(loss >= 0), "Loss should be non-negative"
条件不满足时抛出AssertionError,提前暴露逻辑错误。unittest或pytest框架编写测试用例,验证模型组件(如层、函数、数据加载器)的正确性。例如:import unittest
class TestModel(unittest.TestCase):
def test_linear_layer(self):
layer = torch.nn.Linear(10, 1)
input = torch.randn(5, 10)
output = layer(input)
self.assertEqual(output.shape, (5, 1))
if __name__ == '__main__':
unittest.main()
运行测试用例可快速发现组件问题,避免问题累积。print(input.shape, target.shape))、数值范围(print(input.min(), input.max())),验证是否符合模型输入要求(如通道顺序、尺寸)。torch.autograd.gradcheck验证自定义层的梯度计算是否正确。例如:from torch.autograd import gradcheck
test_input = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
assert gradcheck(MyCustomLayer(), (test_input,), eps=1e-6, atol=1e-4)
若梯度检查失败,需检查自定义层的前向/反向传播逻辑。forward方法中打印各层的输入/输出形状和值(如print(f"Layer1 input: {x.shape}, output: {out.shape}")),确保数据流正确(如维度匹配、无意外截断)。