温馨提示×

Linux环境下PyTorch模型保存与加载

小樊
36
2025-10-10 02:17:14
栏目: 智能运维

在Linux环境下,使用PyTorch保存和加载模型非常简单。以下是一个简单的例子来说明如何保存和加载一个PyTorch模型。

首先,我们需要导入所需的库并定义一个简单的模型:

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

接下来,我们可以使用torch.save()函数将模型保存到文件中:

torch.save(model, 'model.pth')

现在,模型已经被保存到了名为model.pth的文件中。要加载模型,我们可以使用torch.load()函数:

loaded_model = torch.load('model.pth')

加载模型后,我们可以像使用原始模型一样使用它:

input_data = torch.randn(1, 10)
output = loaded_model(input_data)

注意:在加载模型时,确保你的环境中已经安装了与保存模型时相同的PyTorch版本。否则,可能会出现兼容性问题。

0