我已将我的 PyTorch 模型导出到 ONNX。现在,有没有办法让我从那个 ONNX 模型中获取输入层?
将 PyTorch 模型导出到 ONNX
import torch.onnx
checkpoint = torch.load("./saved_pytorch_model.pth")
model.load_state_dict(checkpoint['state_dict'])
input = torch.tensor(df_X.values).float()
torch.onnx.export(model, input, "onnx_model.onnx")
正在加载 ONNX 模型
onnx_model = onnx.load('onnx_model.onnx')
我希望能够以某种方式从 onnx_model 获取输入层。这可能吗?
最佳答案
ONNX 模型是一个 protobuf 结构,定义如下 ( https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto )。您可以使用为 python 生成的标准 protobuf 方法来处理它(请参阅:https://developers.google.com/protocol-buffers/docs/reference/python-generated)。我不明白你到底想提取什么。但是您可以遍历构成图形的节点 (model.graph.node)。图中的第一个节点可能对应也可能不对应您可能认为的第一层(这取决于翻译的完成方式)。您还可以获得模型的输入 (model.graph.input)。
关于python - 如何从 ONNX 模型获取输入数据?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56795995/