python - 如何加载和使用预保留的 PyTorch InceptionV3 模型对图像进行分类

标签 python pytorch torch

我和How can I load and use a PyTorch (.pth.tar) model有同样的问题没有公认的答案,也没有我能弄清楚如何遵循给出的建议的答案。

我是 PyTorch 的新手。我正在尝试加载此处引用的预训练 PyTorch 模型:https://github.com/macaodha/inat_comp_2018

我很确定我漏掉了一些胶水。

# load the model
import torch
model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')

# try to get it to classify an image
imsize = 256
loader = transforms.Compose([transforms.Scale(imsize), transforms.ToTensor()])

def image_loader(image_name):
    """load image, returns cuda tensor"""
    image = Image.open(image_name)
    image = loader(image).float()
    image = Variable(image, requires_grad=True)
    image = image.unsqueeze(0)  
    return image.cpu()  #assumes that you're using CPU

image = image_loader("test-image.jpg")

产生错误:

in () ----> 1 model.predict(image)

AttributeError: 'dict' object has no attribute 'predict

最佳答案

问题

你的 model实际上不是模型。保存时,它不仅包含参数,还包含有关模型的其他信息,其形式有点类似于字典。

因此,torch.load("iNat_2018_InceptionV3.pth.tar")简单地返回 dict ,当然没有名为 predict 的属性.

model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
type(model)
# dict

解决方案

在这种情况下,在一般情况下,您首先需要做的是根据官方指南实例化您想要的模型类 "Load models" .

# First try
from torchvision.models import Inception3
v3 = Inception3()
v3.load_state_dict(model['state_dict']) # model that was imported in your code.

但是,直接输入model['state_dict']会引发一些关于 Inception3 形状不匹配的错误的参数。

了解 Inception3 的变化很重要在它的实例化之后。幸运的是,您可以在原作者的 train_inat.py 中找到它。 .

# What the author has done
model = inception_v3(pretrained=True)
model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
model.aux_logits = False

既然我们知道要更改什么,让我们对我们的第一次尝试做一些修改。

# Second try
from torchvision.models import Inception3
v3 = Inception3()
v3.fc = nn.Linear(2048, 8142)
v3.aux_logits = False
v3.load_state_dict(model['state_dict']) # model that was imported in your code.

成功加载模型!

关于python - 如何加载和使用预保留的 PyTorch InceptionV3 模型对图像进行分类,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53844826/

相关文章:

python-3.x - 用户警告 : Implicit dimension choice for log_softmax has been deprecated

python - Lua 的科学库?

python - 为什么在 pyTorch 强化学习示例的 nn.Module 中返回 self.head(x.view(x.size(0), -1))

python - 加载名称符合某些约定的包的所有子模块

Python3 Pandas 按列名称未知的列进行过滤

python - 信用卡验证练习中的 while 循环 - python2.7

torch - PyTorch 中的无量纲转置

python - 使用字符集编码 UTF-8 发送电子邮件 - Python + boto3

regression - Pytorch loss inf nan

python - 无法使用割炬创建张量