如何在不需要在某处定义模型类的情况下保存 PyTorch 模型?
免责声明 :
在 Best way to save a trained model in PyTorch? ,还有否 无需访问模型类代码即可保存模型的解决方案(或工作解决方案)。
最佳答案
如果您打算使用可用的 Pytorch 库(即 Python、C++ 或它支持的其他平台中的 Pytorch)进行推理,那么最好的方法是通过 TorchScript .
我觉得最简单的就是用trace = torch.jit.trace(model, typical_input)
然后 torch.jit.save(trace, path)
.然后,您可以使用 torch.jit.load(path)
加载跟踪模型。 .
这是一个非常简单的例子。我们制作两个文件:train.py
:
import torch
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x):
x = torch.relu(self.linear(x))
return x
model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
print(model(x))
traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, "model.pth")
infer.py
:import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
print(loaded_trace(x))
按顺序运行这些会得到结果:
python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
[0.0000, 0.5272, 0.3481, 0.1743]])
python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
[0.0000, 0.5272, 0.3481, 0.1743]])
结果是一样的,所以我们很好。 (请注意,由于 nn.Linear 层初始化的随机性,这里每次的结果都会不同)。
TorchScript 提供了将更复杂的架构和图形定义(包括 if 语句、while 循环等)保存在单个文件中,而无需在推理时重新定义图形。有关更高级的可能性,请参阅文档(上面链接)。
关于python - 在无法访问模型类代码的情况下保存 PyTorch 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59287728/