python - Trouble converting an LSTM PyTorch model to ONNX

Tags: python deep-learning pytorch lstm onnx

I am trying to export my LSTM anomaly-detection PyTorch model to ONNX, but I am running into an error. Please take a look at my code below.

Note: my data has the shape [2685, 5, 6], i.e. (batch, sequence length, features).
Here is where I define the model:

import torch
from torch import nn

class Model(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # Zero-initialized hidden and cell states, created fresh for each
        # forward pass, so they carry no gradient history.
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = self.fc1(out)
        out = self.fc2(out)
        return out

input_dim = 6
hidden_dim = 3
layer_dim = 2
model = Model(input_dim, hidden_dim, layer_dim)
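For reference, a quick shape sanity check with the model above (a minimal sketch; the random input just mirrors the stated data shape). Because fc2 maps hidden_dim back to input_dim at every timestep, the output has the same shape as the input:

x = torch.randn(2685, 5, 6)      # (batch, seq_len, features)
with torch.no_grad():
    y = model(x)
print(y.shape)                   # torch.Size([2685, 5, 6])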

I can train it and run inference with it just fine. The problem comes when exporting:

model.eval()
import torch.onnx

torch_out = torch.onnx.export(model,
                              torch.randn(2685, 5, 6),
                              "onnx_model.onnx",
                              export_params=True)

But I get the following error:
LSTM(6, 3, num_layers=2, batch_first=True)
Linear(in_features=3, out_features=3, bias=True)
Linear(in_features=3, out_features=6, bias=True)
['input_1', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear']

/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/symbolic.py:173: UserWarning: ONNX export failed on RNN/GRU/LSTM because batch_first not supported
  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-264-28c6c55537ab> in <module>()
     10                          torch.randn(2685, 5, 6),
     11                          "onnx_model.onnx",
---> 12                          export_params = True
     13                         )

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/__init__.py in export(*args, **kwargs)
     23 def export(*args, **kwargs):
     24     from torch.onnx import utils
---> 25     return utils.export(*args, **kwargs)
     26 
     27 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
    129             operator_export_type=operator_export_type, opset_version=opset_version,
    130             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
--> 131             strip_doc_string=strip_doc_string)
    132 
    133 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
    367         if export_params:
    368             proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type,
--> 369                                                    strip_doc_string)
    370         else:
    371             proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type, strip_doc_string)

RuntimeError: ONNX export failed: Couldn't export operator aten::lstm

Defined at:
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(522): forward_impl
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(539): forward_tensor
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(559): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(481): _slow_forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(491): __call__
<ipython-input-255-468cef410a2c>(14): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(481): _slow_forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(491): __call__
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(294): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(493): __call__
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(231): get_trace_graph
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(225): _trace_and_get_graph_from_model
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(266): _model_to_graph
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(363): _export
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(131): export
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/__init__.py(25): export
<ipython-input-264-28c6c55537ab>(12): <module>
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2963): run_code
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2903): run_ast_nodes
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2785): _run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2662): run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/zmqshell.py(537): run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/ipkernel.py(208): do_execute
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(399): execute_request
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(233): dispatch_shell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(283): dispatcher
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(432): _run_callback
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(480): _handle_recv
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(450): _handle_events
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/platform/asyncio.py(117): _handle_events
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/events.py(145): _run
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/base_events.py(1432): _run_once
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/base_events.py(422): run_forever
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/platform/asyncio.py(127): start
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelapp.py(486): start
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/traitlets/config/application.py(658): launch_instance
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/__main__.py(3): <module>
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/runpy.py(85): _run_code
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/runpy.py(193): _run_module_as_main


Graph we tried to export:
graph(%input.1 : Float(2685, 5, 6),
      %lstm.weight_ih_l0 : Float(12, 6),
      %lstm.weight_hh_l0 : Float(12, 3),
      %lstm.bias_ih_l0 : Float(12),
      %lstm.bias_hh_l0 : Float(12),
      %lstm.weight_ih_l1 : Float(12, 3),
      %lstm.weight_hh_l1 : Float(12, 3),
      %lstm.bias_ih_l1 : Float(12),
      %lstm.bias_hh_l1 : Float(12),
      %fc1.weight : Float(3, 3),
      %fc1.bias : Float(3),
      %fc2.weight : Float(6, 3),
      %fc2.bias : Float(6)):
  %13 : Long() = onnx::Constant[value={0}](), scope: Model
  %14 : Tensor = onnx::Shape(%input.1), scope: Model
  %15 : Long() = onnx::Gather[axis=0](%14, %13), scope: Model
  %16 : Long() = onnx::Constant[value={2}](), scope: Model
  %17 : Long() = onnx::Constant[value={3}](), scope: Model
  %18 : Tensor = onnx::Unsqueeze[axes=[0]](%16)
  %19 : Tensor = onnx::Unsqueeze[axes=[0]](%15)
  %20 : Tensor = onnx::Unsqueeze[axes=[0]](%17)
  %21 : Tensor = onnx::Concat[axis=0](%18, %19, %20)
  %22 : Float(2, 2685, 3) = onnx::ConstantOfShape[value={0}](%21), scope: Model
  %23 : Long() = onnx::Constant[value={0}](), scope: Model
  %24 : Tensor = onnx::Shape(%input.1), scope: Model
  %25 : Long() = onnx::Gather[axis=0](%24, %23), scope: Model
  %26 : Long() = onnx::Constant[value={2}](), scope: Model
  %27 : Long() = onnx::Constant[value={3}](), scope: Model
  %28 : Tensor = onnx::Unsqueeze[axes=[0]](%26)
  %29 : Tensor = onnx::Unsqueeze[axes=[0]](%25)
  %30 : Tensor = onnx::Unsqueeze[axes=[0]](%27)
  %31 : Tensor = onnx::Concat[axis=0](%28, %29, %30)
  %32 : Float(2, 2685, 3) = onnx::ConstantOfShape[value={0}](%31), scope: Model
  %33 : Long() = onnx::Constant[value={1}](), scope: Model/LSTM[lstm]
  %34 : Long() = onnx::Constant[value={2}](), scope: Model/LSTM[lstm]
  %35 : Double() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
  %36 : Long() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
  %37 : Long() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
  %38 : Long() = onnx::Constant[value={1}](), scope: Model/LSTM[lstm]
  %input.2 : Float(2685!, 5!, 3), %40 : Float(2, 2685, 3), %41 : Float(2, 2685, 3) = aten::lstm(%input.1, %22, %32, %lstm.weight_ih_l0, %lstm.weight_hh_l0, %lstm.bias_ih_l0, %lstm.bias_hh_l0, %lstm.weight_ih_l1, %lstm.weight_hh_l1, %lstm.bias_ih_l1, %lstm.bias_hh_l1, %33, %34, %35, %36, %37, %38), scope: Model/LSTM[lstm]
  %42 : Float(3!, 3!) = onnx::Transpose[perm=[1, 0]](%fc1.weight), scope: Model/Linear[fc1]
  %43 : Float(2685, 5, 3) = onnx::MatMul(%input.2, %42), scope: Model/Linear[fc1]
  %44 : Float(2685, 5, 3) = onnx::Add(%43, %fc1.bias), scope: Model/Linear[fc1]
  %45 : Float(3!, 6!) = onnx::Transpose[perm=[1, 0]](%fc2.weight), scope: Model/Linear[fc2]
  %46 : Float(2685, 5, 6) = onnx::MatMul(%44, %45), scope: Model/Linear[fc2]
  %47 : Float(2685, 5, 6) = onnx::Add(%46, %fc2.bias), scope: Model/Linear[fc2]
  return (%47)

What does this mean, and what should I do to export the model correctly?

Best Answer

If you came here from Google: the previous answers are no longer up to date. ONNX now supports the LSTM operator. Note that unless you use the dynamic_axes argument, an export from PyTorch will by default fix the input sequence length to that of the example input.
Below is a minimal LSTM export example, adapted from the torch.onnx FAQ:

import torch
import onnx
from torch import nn
import numpy as np
import onnxruntime.backend as backend

torch.manual_seed(0)

layer_count = 4

model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
model.eval()

with torch.no_grad():
    input = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
    h0 = torch.randn(layer_count * 2, 3, 20)
    c0 = torch.randn(layer_count * 2, 3, 20)
    output, (hn, cn) = model(input, (h0, c0))

    # default export
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx')
    onnx_model = onnx.load('lstm.onnx')
    # input shape [5, 3, 10]
    print(onnx_model.graph.input[0])

    # export with `dynamic_axes`
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
                      input_names=['input', 'h0', 'c0'],
                      output_names=['output', 'hn', 'cn'],
                      dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
    onnx_model = onnx.load('lstm.onnx')
    # input shape ['sequence', 3, 10]

# Check export
y, (hn, cn) = model(input, (h0, c0))
y_onnx, hn_onnx, cn_onnx = backend.run(
    onnx_model, 
    [input.numpy(), h0.numpy(), c0.numpy()],
    device='CPU'
)

np.testing.assert_almost_equal(y_onnx, y.detach(), decimal=5)
np.testing.assert_almost_equal(hn_onnx, hn.detach(), decimal=5)
np.testing.assert_almost_equal(cn_onnx, cn.detach(), decimal=5)
I have tested this example with:
torch==1.4.0,
onnx==1.7.0
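The same dynamic_axes mechanism applies to the batch_first model from the question. A minimal sketch (assuming the Model class defined in the question and a recent torch/onnx as above; the input/output names and the dynamic batch axis are choices, not requirements):

import torch
import onnx

model = Model(input_dim=6, hidden_dim=3, layer_dim=2)
model.eval()
dummy = torch.randn(2685, 5, 6)  # (batch, seq_len, features), as in the question

torch.onnx.export(model, dummy, "onnx_model.onnx",
                  export_params=True,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})

# Optional structural check (raises if the exported graph is malformed).
onnx.checker.check_model(onnx.load("onnx_model.onnx"))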

Regarding "python - Trouble converting an LSTM PyTorch model to ONNX", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/57299674/
