在CentOS系统上,使用PyTorch保存和加载模型主要涉及到以下几个步骤:
训练模型: 在训练过程中,你可以定期保存模型的状态字典(state_dict)。
import torch
import torch.nn as nn
# 假设你有一个模型类 MyModel
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 定义模型层
def forward(self, x):
# 定义前向传播
return x
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练循环
for epoch in range(num_epochs):
# 训练代码...
# ...
# 每隔一定epoch保存模型
if (epoch + 1) % save_interval == 0:
torch.save(model.state_dict(), f'model_epoch_{epoch + 1}.pth')
保存整个模型:
如果你想保存整个模型(包括模型架构和状态字典),可以使用torch.save直接保存模型对象。
torch.save(model, 'model.pth')
加载模型状态字典:
当你需要加载之前保存的模型状态字典时,可以使用load_state_dict方法。
model = MyModel() # 创建一个新的模型实例
model.load_state_dict(torch.load('model_epoch_10.pth'))
model.eval() # 设置模型为评估模式
加载整个模型: 如果你之前保存了整个模型,可以直接加载。
model = torch.load('model.pth')
model.eval() # 设置模型为评估模式
设备兼容性:如果你在GPU上训练模型,但在CPU上加载模型,需要将模型移动到CPU上。
model = torch.load('model.pth', map_location=torch.device('cpu'))
版本兼容性:确保保存和加载模型的PyTorch版本一致,否则可能会出现兼容性问题。
安全性:从不可信来源加载模型时要小心,因为这可能会导致安全问题。
通过以上步骤,你可以在CentOS系统上轻松地保存和加载PyTorch模型。