python - 有没有办法通过字符串加载 torchvision 模型?

标签 python deep-learning computer-vision pytorch torchvision

<分区>

目前,我使用以下代码加载预训练的 torchvision 模型:

import torchvision
torchvision.models.resnet101(pretrained=True)

但是,我喜欢将模型名称作为字符串参数,然后使用该字符串加载预训练模型。这样做的伪代码类似于:

model_name = 'resnet101'
torchvision.models.get(model_name)(pretrained=True)

有没有办法以一种相当简单的方式实现这一点?

最佳答案

您可以使用 torch.hub :

model_str = 'resnet50'
model = torch.hub.load('pytorch/vision', model_str, pretrained=True)

可以通过以下方式找到所有可用的字符串模型:

torch.hub.list('pytorch/vision', force_reload=True)

输出:

['alexnet',
 'deeplabv3_mobilenet_v3_large',
 'deeplabv3_resnet101',
 'deeplabv3_resnet50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'fcn_resnet101',
 'fcn_resnet50',
 'googlenet',
 'inception_v3',
 'lraspp_mobilenet_v3_large',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext50_32x4d',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'squeezenet1_0',
 'squeezenet1_1',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'wide_resnet101_2',
 'wide_resnet50_2']

关于python - 有没有办法通过字符串加载 torchvision 模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67454884/

相关文章:

c++ - 使用 cuda 而不是 cudnn 测试 caffe 时的未知池方法

python - ValueError : Input 0 of layer sequential is incompatible with the layer: : expected min_ndim=4, 发现 ndim=3。收到的完整形状 : [8, 28, 28]

python - Tensorflow:从图像中预测一个点,用点标签训练模型

图像处理如何检测图像中的特定自定义形状

python - 限制文件浏览器小型读取

python - 加载序列化 json 对象时出现问题

opencv - 设置图像捕获和图像处理的场景

opencv - 在盘子里的一堆水果中检测出香蕉或苹果,成功率 > 90%。 (见图)

python - python 的 ORM 库 peewee 中与foreignkeyfield对象一起使用的 "related_name"属性是什么?

python - 创建一个看起来像 Tkinter 的表