python - 如何从 ONNX 模型获取输入数据?

标签 python pytorch onnx

我已将我的 PyTorch 模型导出到 ONNX。现在,有没有办法让我从那个 ONNX 模型中获取输入层?

将 PyTorch 模型导出到 ONNX

import torch.onnx
checkpoint = torch.load("./saved_pytorch_model.pth")
model.load_state_dict(checkpoint['state_dict'])
input = torch.tensor(df_X.values).float()
torch.onnx.export(model, input, "onnx_model.onnx")

正在加载 ONNX 模型

onnx_model = onnx.load('onnx_model.onnx')

我希望能够以某种方式从 onnx_model 获取输入层。这可能吗?

最佳答案

ONNX 模型是一个 protobuf 结构,定义如下 ( https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto )。您可以使用为 python 生成的标准 protobuf 方法来处理它(请参阅:https://developers.google.com/protocol-buffers/docs/reference/python-generated)。我不明白你到底想提取什么。但是您可以遍历构成图形的节点 (model.graph.node)。图中的第一个节点可能对应也可能不对应您可能认为的第一层(这取决于翻译的完成方式)。您还可以获得模型的输入 (model.graph.input)。

关于python - 如何从 ONNX 模型获取输入数据?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56795995/

相关文章:

java - 使用 Apache Arrow 读取 Parquet 文件

python - 使用列表作为列中的数据类型(SQLAlchemy)

python - Python 3.5和pip的本地安装

neural-network - 二进制分类PyTorch的损失函数及其输入

python - .contigious() 在 PyTorch 中做什么?

azure - 无法在 azure 机器学习服务工作区中注册 ONNX 模型

c++ - 如何使用 C++ 在 ArmNN for Linux 中加载 base onnx 模型

python - 如何避免Python中的循环导入?

python - 如何根据 DCGAN 论文在扁平表示向量之上训练 L2-SVM 分类器

python - 混淆我的 DL 模型和 Python 的最佳方法?