温馨提示×

PyTorch中如何使用预训练的模型

小樊
93
2024-03-05 18:36:02
栏目: 编程语言

在PyTorch中使用预训练的模型可以通过torchvision库中的models模块实现。该模块包含了一些常用的预训练模型,如ResNet、VGG、AlexNet等。以下是一个使用预训练的ResNet模型的示例:

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# 加载预训练的ResNet模型
model = models.resnet18(pretrained=True)
model.eval()

# 加载一张图片进行推理
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open('image.jpg')
img = transform(img)
img = img.unsqueeze(0)  # 添加一个维度作为batch

# 进行推理
output = model(img)

在上面的示例中,我们首先加载了预训练的ResNet模型,并设置为evaluation模式。然后,我们加载了一张图片,并对其进行预处理,最后通过模型进行推理得到输出。需要注意的是,我们在推理之前还需要调用model.eval()来将模型设置为evaluation模式。

0