在Linux下,使用PyTorch保存和加载模型的主要方法是使用torch.save()和torch.load()函数。以下是保存和加载模型的详细步骤:
首先,我们需要定义一个模型。这里以一个简单的多层感知器(MLP)为例:
import torch
import torch.nn as nn
class MLP(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
input_size = 784
hidden_size = 128
output_size = 10
model = MLP(input_size, hidden_size, output_size)
接下来,我们可以使用torch.save()函数将模型保存到文件中:
torch.save(model, 'model.pth')
这将在当前目录下创建一个名为model.pth的文件,其中包含模型的权重和结构信息。
要加载模型,我们首先需要实例化相同的模型结构,然后使用torch.load()函数从文件中加载权重:
loaded_model = MLP(input_size, hidden_size, output_size)
loaded_model.load_state_dict(torch.load('model.pth'))
现在,loaded_model变量包含了从model.pth文件中加载的模型权重和结构。你可以像使用原始模型一样使用它:
input_data = torch.randn(1, input_size)
output = loaded_model(input_data)
注意:在加载模型时,请确保模型结构与保存时的结构相同。如果结构不同,你可能会遇到错误或意外的行为。