python - 当 input_shape 指定为 3-d 时,Keras SimpleRNN 出错

标签 python neural-network theano keras recurrent-neural-network

我正在尝试在 Keras 上的 SimpleRNN 上根据文本进行训练。

在 Keras 中,我为 SimpleRNN 指定了一个非常简单的参数,如下所示:

model = Sequential()
model.add(SimpleRNN(output_dim=1, input_shape=(1,1,1))

我理解 input_shape 应该是 (nb_samples, timesteps, input_dim),和我的 train_x.shape 一样

所以我很惊讶我收到了以下错误。

Traceback (most recent call last):
  File "C:/Users/xxx/xxxx/xxx/xxx.py", line 262, in <module>
    model.add(SimpleRNN(output_dim=vocab_size, input_shape=train_x.shape))
  File "C:\Anaconda3\envs\py34\lib\site-packages\keras\models.py", line 275, in add
    layer.create_input_layer(batch_input_shape, input_dtype)
  File "C:\Anaconda3\envs\py34\lib\site-packages\keras\engine\topology.py", line 367, in create_input_layer
    self(x)
  File "C:\Anaconda3\envs\py34\lib\site-packages\keras\engine\topology.py", line 467, in __call__
    self.assert_input_compatibility(x)
  File "C:\Anaconda3\envs\py34\lib\site-packages\keras\engine\topology.py", line 408, in assert_input_compatibility
    str(K.ndim(x)))
Exception: Input 0 is incompatible with layer simplernn_1: expected ndim=3, found ndim=4

不确定为什么 keras 在仅指定 3 时“找到 ndim=4”!

为了清楚起见,我的

train_x.shape = (73, 84, 400)

vocab_size=400

.只要输入 3d 及以上的 input_shape,我意识到会导致错误。

任何帮助将不胜感激!!! :))

最佳答案

您不应在模型的输入形状中包含 n_samples。因此,您必须为层的输入形状指定大小为 2 的元组(或将形状的第一个元素设置为 None)。这里 Keras 会自动将 None 添加到您的输入形状中,从而导致 ndim=4。 有关这方面的更多信息,请参见 here .

此外,您的 input_dim=400(假设您使用词汇表中单词的单热编码表示)和您的训练数据包含 73 文本(漂亮small) 每个的长度都是 84。所以你应该设置 input_shape=(84,400).

关于python - 当 input_shape 指定为 3-d 时,Keras SimpleRNN 出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39171170/

相关文章:

c# - 具有 6 个特征的人工神经网络训练

machine-learning - CNN 能否学会比其他特征 channel 更多地权衡某些特征 channel ?

machine-learning - 缩放 LSTM 权重有意义吗?

python - 有没有办法将 libgpuarray 与 Intel GPU 一起使用?

python - 如何对 url 中发布的字典的键进行排序?

Python Beautiful Soup 'NoneType' 对象错误

php - 两个脚本之间的通信

python - 叠瓦循环和 Python 语法

NumPy 与 Theano?

python - 警告 (theano.sandbox.cuda) : CUDA is installed, 但设备 gpu 不可用(错误:cuda 不可用)