温馨提示×

Ubuntu系统中如何备份和恢复PyTorch模型

小樊
34
2025-12-20 18:04:23
栏目: 智能运维

在Ubuntu系统中备份和恢复PyTorch模型,可以按照以下步骤进行:

备份PyTorch模型

  1. 保存模型权重: 使用torch.save()函数将模型的权重保存到一个文件中。

    import torch
    import torchvision.models as models
    
    # 创建一个示例模型
    model = models.resnet18(pretrained=True)
    
    # 保存模型权重
    torch.save(model.state_dict(), 'model_weights.pth')
    
  2. 保存整个模型(可选): 如果你希望保存整个模型(包括架构和权重),可以使用torch.save()函数保存整个模型对象。

    # 保存整个模型
    torch.save(model, 'model.pth')
    

恢复PyTorch模型

  1. 加载模型权重: 使用torch.load()函数加载之前保存的模型权重,并将其加载到模型中。

    # 创建一个与之前相同的模型架构
    model = models.resnet18(pretrained=False)
    
    # 加载模型权重
    model.load_state_dict(torch.load('model_weights.pth'))
    
  2. 加载整个模型(可选): 如果你之前保存了整个模型,可以直接加载整个模型对象。

    # 加载整个模型
    model = torch.load('model.pth')
    

注意事项

  • 设备一致性:在加载模型权重时,确保模型和权重在同一设备上(CPU或GPU)。如果模型在GPU上训练,但在CPU上加载,需要将权重移动到CPU。

    # 如果模型在GPU上训练,但在CPU上加载
    model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))
    
  • 模型架构一致性:确保加载权重的模型架构与保存权重的模型架构一致。如果不一致,可能会导致加载失败或模型行为异常。

示例代码总结

import torch
import torchvision.models as models

# 创建一个示例模型
model = models.resnet18(pretrained=True)

# 保存模型权重
torch.save(model.state_dict(), 'model_weights.pth')

# 加载模型权重
model = models.resnet18(pretrained=False)
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))

通过以上步骤,你可以在Ubuntu系统中轻松备份和恢复PyTorch模型。

0