温馨提示×

Ubuntu上PyTorch模型如何保存

小樊
39
2025-12-10 20:31:58
栏目: 智能运维

在Ubuntu上保存PyTorch模型通常涉及以下几个步骤:

  1. 训练模型:首先,你需要训练你的PyTorch模型。这通常涉及到定义模型架构、损失函数、优化器以及训练循环。

  2. 保存模型:训练完成后,你可以使用torch.save()函数来保存整个模型或者模型的状态字典。

下面是一个简单的例子,展示了如何定义一个模型、训练它,并保存模型的状态字典:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)  # 假设输入特征有10个,输出有5个

    def forward(self, x):
        return self.fc(x)

# 创建模型实例
model = SimpleModel()

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 假设我们有一些数据
inputs = torch.randn(100, 10)
targets = torch.randn(100, 5)

# 训练模型(这里只是示例,没有实际的训练循环)
for epoch in range(10):  # 假设我们训练10个周期
    optimizer.zero_grad()   # 清空梯度
    outputs = model(inputs) # 前向传播
    loss = criterion(outputs, targets) # 计算损失
    loss.backward()         # 反向传播
    optimizer.step()        # 更新权重

# 保存模型的状态字典
torch.save(model.state_dict(), 'model.pth')

在上面的代码中,model.state_dict()包含了模型的所有权重和偏置等参数。使用torch.save()函数将这些参数保存到文件model.pth中。

  1. 加载模型:当你需要在另一个程序中使用这个模型时,你可以加载保存的状态字典,并将其应用到一个新的模型实例上:
# 创建一个新的模型实例
model_new = SimpleModel()

# 加载保存的状态字典
model_new.load_state_dict(torch.load('model.pth'))

# 现在你可以使用这个模型进行预测或者继续训练

请注意,当你加载模型时,确保新的模型架构与保存时的架构完全一致,否则可能会出现错误。如果你更改了模型架构,你需要确保新的架构包含了所有必要的层和参数。

此外,如果你想要保存整个模型(包括模型的架构和类定义),你可以直接保存模型实例,而不是状态字典:

# 保存整个模型
torch.save(model, 'model_full.pth')

# 加载整个模型
model_loaded = torch.load('model_full.pth')

在这种情况下,你需要确保在加载模型的环境中可以访问到原始模型的类定义。

0