python - Input shape for stacked RNNs

Tags: python python-3.x deep-learning keras

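I am trying to stack several RNNs in series using Keras with the TensorFlow backend. I can build a model with a single SimpleRNN layer, but when I add a second SimpleRNN layer I cannot work out the correct input shape for it.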

from keras import models
from keras.layers.recurrent import SimpleRNN
from keras.layers import Activation


model = models.Sequential()

hidden_units = 256
skeleton_dimensions = 3 * 16  # 3 dimensions for 16 joints
input_temporal_length = 7

in_shape = (input_temporal_length, skeleton_dimensions,)

# three hidden layers of 256 each
model.add(SimpleRNN(hidden_units, input_shape=in_shape,
                    activation='relu', use_bias=True,))
# what input shape is this supposed to have?
model.add(SimpleRNN(hidden_units, input_shape=(1, skeleton_dimensions,),
                    activation='relu', use_bias=True,))

What input shape should my second SimpleRNN use?

The documentation for Recurrent Layers seems to imply:

Output shape

  • if return_sequences: 3D tensor with shape (batch_size, timesteps, units).
  • else, 2D tensor with shape (batch_size, units).

Since return_sequences defaults to False, I tried to set the input_shape of the next layer accordingly, but I get this error:

Using TensorFlow backend.
Traceback (most recent call last):
  File "rnn_agony.py", line 19, in <module>
    activation='relu', use_bias=True,))
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 455, in add
    output_tensor = layer(self.outputs[0])
  File "/usr/local/lib/python3.5/dist-packages/keras/layers/recurrent.py", line 252, in __call__
    return super(Recurrent, self).__call__(inputs, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 511, in __call__
    self.assert_input_compatibility(inputs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 413, in assert_input_compatibility
    str(K.ndim(x)))
ValueError: Input 0 is incompatible with layer simple_rnn_2: expected ndim=3, found ndim=2
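The shape rule quoted from the documentation explains this traceback: with return_sequences left at False, the first SimpleRNN emits a 2D tensor of shape (batch_size, units), while the next recurrent layer checks for a 3D input. The minimal pure-Python sketch below only models that rule; `recurrent_output_shape` is a hypothetical helper, not a Keras function, and its error message merely imitates the one above.

```python
def recurrent_output_shape(input_shape, units, return_sequences=False):
    """Output shape of a Keras-style recurrent layer, per the documented rule.

    input_shape includes the batch dimension, e.g. (None, timesteps, features).
    Hypothetical helper for illustration only; not part of Keras.
    """
    if len(input_shape) != 3:
        # Imitates the ndim check that raised "expected ndim=3, found ndim=2"
        raise ValueError(f"expected ndim=3, found ndim={len(input_shape)}")
    batch_size, timesteps, _features = input_shape
    if return_sequences:
        return (batch_size, timesteps, units)  # one output per timestep
    return (batch_size, units)  # only the last output


# First layer as in the question: 7 timesteps, 48 features (3 * 16 joints)
first = recurrent_output_shape((None, 7, 48), units=256)
print(first)  # (None, 256)

# Feeding that 2D output into a second recurrent layer fails the ndim check
try:
    recurrent_output_shape(first, units=256)
except ValueError as err:
    print(err)  # expected ndim=3, found ndim=2
```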

Best answer

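If you are stacking RNNs, you need to set return_sequences=True on each recurrent layer that feeds into another recurrent layer (here, the first SimpleRNN), and you no longer need to set input_shape on the layers after the first one. This makes intuitive sense: an RNN expects a sequence as its input, so the layer before it must return the whole sequence rather than only its last output.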
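To see why return_sequences=True fixes the stack, here is a rough per-sample forward pass of two SimpleRNN-style layers in plain Python. It assumes the recurrence h_t = relu(x_t · W + h_{t-1} · U + b) with h_0 = 0, which is the form SimpleRNN computes with activation='relu'. The helper names, the small unit count, and the constant weights are hypothetical and chosen only to keep the example short; the batch dimension is omitted.

```python
def relu(value):
    return max(0.0, value)


def simple_rnn(sequence, kernel, recurrent_kernel, bias, return_sequences=False):
    """Run one SimpleRNN-style layer over a single sample (no batch dimension).

    sequence: list of timesteps, each a list of input features.
    kernel: input_dim x units; recurrent_kernel: units x units; bias: units.
    """
    units = len(bias)
    state = [0.0] * units  # h_0 = 0
    outputs = []
    for x in sequence:
        # h_t = relu(x_t . kernel + h_{t-1} . recurrent_kernel + bias)
        state = [
            relu(
                sum(x[i] * kernel[i][j] for i in range(len(x)))
                + sum(state[k] * recurrent_kernel[k][j] for k in range(units))
                + bias[j]
            )
            for j in range(units)
        ]
        outputs.append(state)
    # return_sequences=True  -> one state per timestep (timesteps x units)
    # return_sequences=False -> only the last state (units,)
    return outputs if return_sequences else state


def constant_matrix(rows, cols, value):
    """Hypothetical helper: a rows x cols matrix filled with one value."""
    return [[value] * cols for _ in range(rows)]


timesteps, features, units = 7, 48, 4  # 48 = 3 * 16 joints; 4 units keeps it small
sample = [[0.1] * features for _ in range(timesteps)]

# Layer 1 returns the full sequence, so the next RNN has timesteps to iterate over
sequence_out = simple_rnn(sample, constant_matrix(features, units, 0.01),
                          constant_matrix(units, units, 0.0), [0.0] * units,
                          return_sequences=True)
# Layer 2 is the last recurrent layer and returns only its final state
final_state = simple_rnn(sequence_out, constant_matrix(units, units, 0.01),
                         constant_matrix(units, units, 0.0), [0.0] * units)

print(len(sequence_out), len(sequence_out[0]))  # 7 4
print(len(final_state))  # 4
```

In Keras terms this corresponds to passing return_sequences=True to the first SimpleRNN and leaving it at its default (False) on the last one, so the stacked model ends with an output of shape (batch_size, units).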

A similar question, "python - Input shape for stacked RNNs", can be found on Stack Overflow: https://stackoverflow.com/questions/43084541/
