python - 如何加载和使用 PyTorch (.pth.tar) 模型

标签 python python-3.x neural-network pytorch torch

我对 Torch 不是很熟悉,我主要使用 Tensorflow。但是,我需要使用在 Torch 中重新训练的重新训练的初始模型。由于为我的特定应用程序重新训练初始模型需要大量计算资源,我想使用已经重新训练的模型。
此模型保存为 .pth.tar文件。
我希望能够首先加载这个模型。到目前为止,我已经能够弄清楚我必须使用以下内容:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')
这似乎有效,因为 print(model)打印出大量数字和其他值,我认为这些值是权重和偏差的值。
在此之后,我需要能够用它对图像进行分类。我一直无法弄清楚这一点。我必须如何格式化图像?图像是否应该转换为数组?在此之后,我必须如何将输入数据传递给网络?

最佳答案

你基本上需要做与 tensorflow 相同的事情。也就是说,当您存储网络时,只会存储参数(即网络中的可训练对象),而不是“胶水”,这是使用训练模型所需的全部逻辑。
因此,如果您有 .pth.tar文件,您可以加载它,从而覆盖已定义模型的参数值。

这意味着保存/加载模型的一般过程如下:

  • 编写您的网络定义(即您的 nn.Module 对象)
  • 以您想要的方式训练或以其他方式更改网络参数
  • 使用 torch.save 保存参数
  • 当您想使用该网络时,请使用 nn.Module 的相同定义对象首先实例化 pytorch 网络
  • 然后使用 torch.load 覆盖网络参数的值

  • 这是有关如何执行此操作的一些引用资料的讨论:pytorch forums

    这是一个超短的 mwe:
    # to store
    torch.save({
        'state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
    }, 'filename.pth.tar')
    
    # to load
    checkpoint = torch.load('filename.pth.tar')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    

    关于python - 如何加载和使用 PyTorch (.pth.tar) 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51857274/

    相关文章:

    python - 从枚举类获取值,其中枚举成员名称在 Python 中运行时已知

    python - 当需要包含非 python 文件时,使用 PyInstaller 创建单个 exe?

    c++ - c/c++中神经网络的实现方式是什么?

    python - Keras 模型给出的测试精度为 1.0

    python - 从元组列表中获取 "NaN"的元组索引

    python - 从 Firebird 数据库表中获取列名列表

    python-3.x - ffmpeg-python 从 x11grab 获取输入

    matlab - 在 MATLAB 中从头开始编写基本神经网络

    python - 如何将包含列表的列表列表转换为 DataFrame

    python - 如何执行编译后的python代码