python - 4D LSTM : Trouble with I/O Shapes

标签 python keras lstm recurrent-neural-network

我正在尝试让 4D TimeDistributed(LSTM(...)) 在 Keras 中工作,但我在输入/输出形状方面遇到问题。

batch_size = 1

model = Sequential()

model.add(TimeDistributed(LSTM(7, batch_input_shape=(batch_size,
    look_back,dataset.shape[1], dataset.shape[2]), stateful=True,
    return_sequences=True), batch_input_shape=(batch_size,
    look_back, dataset.shape[1], dataset.shape[2])))

model.add(TimeDistributed(LSTM(7, batch_input_shape= (batch_size,
    look_back,dataset.shape[1],dataset.shape[2]),
    stateful=True), batch_input_shape=(batch_size, look_back,
    dataset.shape[1], dataset.shape[2])))

model.add(TimeDistributed(Dense(7, input_shape = (batch_size,
   1,look_back, dataset.shape[1],dataset.shape[2]))))

model.compile(loss = 'mean_squared_error', optimizer='adam')

for i in range(10):
    model.fit(trainX, trainY, epochs=1, batch_size=batch_size,
        verbose=2, shuffle=False)
    model.reset_states()

trainX、trainY 和 dataset 的输入形状如下:

trainX.shape = (63, 3, 34607, 7)
trainY.shape = (63, 34607, 7)
dataset.shape = (100, 34607, 7)

我收到的错误如下:

Error when checking target: expected time_distributed_59 to have shape (1, 3, 7) but got array with shape (63, 34607, 7)

上面提到的层是关于最后一个TimeDistributed Dense Layer。

这是我打印出每层的输入和输出形状时的输出:

(1, 3, 34607, 7) layer[0] - Input
(1, 3, 34607, 7) layer[0] - Output
(1, 3, 34607, 7) layer[1] - Input
(1, 3, 7) layer[1] - Output
(1, 3, 7) layer[2] - Input
(1, 3, 7) layer[2] - Output

但是,最终输出层应该是形状为 (1, 1, 34067, 7) 或形状 (1, 34067, 7) 的预测

谢谢您的建议!

最佳答案

您没有在第二个时间分布式 LSTM 层上设置返回序列 = True;默认为 false。这可以解释您得到的 (1,3,7) 输出形状。

关于python - 4D LSTM : Trouble with I/O Shapes,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45107851/

相关文章:

lstm - 类型错误 : view() takes at most 2 arguments (3 given)

Python 小写​​列表理解测试失败时似乎是正确的?

python - 在 .so 文件中使用 C 模块时出现段错误

tensorflow - Google Colab TPU 比 GPU 花费更多时间

python - keras model.fit 函数打印的准确率与验证集还是训练集有关?

keras - 如何使用 LSTM 实现端到端关系提取

python - Keras:如何实现LSTM的目标复制?

python - 是否可以在从 C 调用的 ctypes 回调中引发 Python 异常?

javascript - 如何使用Django将表中的数据保存到数据库?

machine-learning - 如何学习语言模型?