温馨提示×

Linux下PyTorch模型保存与加载方法

小樊
50
2025-09-21 04:06:37
栏目: 智能运维

在Linux下,使用PyTorch保存和加载模型的主要方法是使用torch.save()torch.load()函数。以下是保存和加载模型的详细步骤:

  1. 保存模型:

首先,我们需要定义一个模型。这里以一个简单的多层感知器(MLP)为例:

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

input_size = 784
hidden_size = 128
output_size = 10

model = MLP(input_size, hidden_size, output_size)

接下来,我们可以使用torch.save()函数将模型保存到文件中:

torch.save(model, 'model.pth')

这将在当前目录下创建一个名为model.pth的文件,其中包含模型的权重和结构信息。

  1. 加载模型:

要加载模型,我们首先需要实例化相同的模型结构,然后使用torch.load()函数从文件中加载权重:

loaded_model = MLP(input_size, hidden_size, output_size)
loaded_model.load_state_dict(torch.load('model.pth'))

现在,loaded_model变量包含了从model.pth文件中加载的模型权重和结构。你可以像使用原始模型一样使用它:

input_data = torch.randn(1, input_size)
output = loaded_model(input_data)

注意:在加载模型时,请确保模型结构与保存时的结构相同。如果结构不同,你可能会遇到错误或意外的行为。

0