我正在尝试使用 torch.load 加载预训练模型。
我收到以下错误:
ModuleNotFoundError: No module named 'utils'
我已通过从命令行打开它来检查我使用的路径是否正确。可能是什么原因造成的?
这是我的代码:
import torch
import sys
PATH = './gan.pth'
model = torch.load(PATH)
model.eval()
编辑: 整个错误堆栈:
Traceback (most recent call last):
File "load.py", line 6, in <module>
model = torch.load(PATH)
File "C:\Users\user\anaconda3\envs\pytorch-flask\lib\site-packages\torch\serialization.py", line 595, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "C:\Users\user\anaconda3\envs\pytorch-flask\lib\site-packages\torch\serialization.py", line 774, in _legacy_load
result = unpickler.load()
ModuleNotFoundError: No module named 'utils'
最佳答案
编辑这个答案没有提供问题的答案,但解决了给定代码中的另一个问题
.pth
文件只存储模型的参数,而不是模型本身。当您想要加载模型时,您将需要 .pt/-h
文件和模型类的 python 代码。然后你可以像这样加载它:
# your model
class YourModel(nn.Modules):
def __init__(self):
super(YourModel, self).__init__()
. . .
def forward(self, x):
. . .
# the pytorch save-file in which you stored your trained model
model_file = "<your path>"
model = Model()
model = model.load_state_dict(torch.load(model_file))
model.eval()
关于python - torch torch .load ModuleNotFoundError : No module named 'utils' ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65538179/