在CentOS上保存和加载PyTorch模型时,可以采用以下技巧:
使用torch.save()函数:
import torch
import torchvision.models as models
# 创建一个模型实例
model = models.resnet18(pretrained=True)
# 保存整个模型
torch.save(model, 'model.pth')
# 或者只保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')
保存额外的信息: 如果需要保存模型的架构、优化器状态等额外信息,可以创建一个字典来存储这些内容:
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
# 其他需要保存的信息
}
torch.save(checkpoint, 'checkpoint.pth')
加载整个模型:
model = torch.load('model.pth')
model.eval() # 设置模型为评估模式
加载模型的状态字典:
model = models.resnet18(pretrained=False) # 创建一个相同架构的新模型
model.load_state_dict(torch.load('model_state_dict.pth'))
model.eval() # 设置模型为评估模式
加载完整的检查点:
checkpoint = torch.load('checkpoint.pth')
model = models.resnet18(pretrained=False) # 创建一个相同架构的新模型
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# 其他需要加载的信息
设备兼容性: 如果在GPU上训练模型,保存的模型会包含GPU相关的信息。在CPU上加载时,需要将模型移动到CPU:
model = torch.load('model.pth', map_location=torch.device('cpu'))
版本兼容性: 确保保存和加载模型时使用的PyTorch版本一致,否则可能会出现不兼容的问题。
安全性: 避免加载来自不可信来源的模型文件,以防止潜在的安全风险。
通过以上技巧,你可以在CentOS上有效地保存和加载PyTorch模型。