温馨提示×

Linux PyTorch如何进行模型保存与加载

小樊
59
2025-08-07 13:16:42
栏目: 智能运维

在Linux环境下使用PyTorch进行模型的保存与加载是一个常见的操作。以下是详细的步骤和示例代码:

保存模型

  1. 保存整个模型: 使用torch.save()函数可以保存整个模型,包括模型的结构、参数和优化器状态。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    # 定义一个简单的模型
    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.fc = nn.Linear(10, 5)
    
        def forward(self, x):
            return self.fc(x)
    
    model = SimpleModel()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    
    # 保存模型
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'model.pth')
    
  2. 只保存模型参数: 如果只需要保存模型的参数,可以使用model.state_dict()

    torch.save(model.state_dict(), 'model_params.pth')
    

加载模型

  1. 加载整个模型: 使用torch.load()函数加载模型,并使用load_state_dict()方法将参数加载到模型中。

    # 加载模型
    model = SimpleModel()
    model.load_state_dict(torch.load('model.pth'))
    model.eval()  # 设置模型为评估模式
    
  2. 只加载模型参数: 如果之前只保存了模型参数,可以直接加载到模型中。

    model = SimpleModel()
    model.load_state_dict(torch.load('model_params.pth'))
    model.eval()  # 设置模型为评估模式
    

注意事项

  • 设备兼容性:如果模型是在GPU上训练的,而加载时在CPU上,需要将模型参数移动到CPU。

    model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
    
  • 版本兼容性:确保保存和加载模型的PyTorch版本一致,否则可能会出现兼容性问题。

  • 优化器状态:如果保存了优化器状态,可以在加载模型后继续训练。

    optimizer.load_state_dict(torch.load('model.pth')['optimizer_state_dict'])
    

通过以上步骤,你可以在Linux环境下轻松地进行PyTorch模型的保存与加载。

0