我在理解堆叠式 LSTM 网络中各层的输入-输出流时遇到了一些困难。假设我已经创建了一个如下所示的堆叠式 LSTM 网络:
# parameters
time_steps = 10
features = 2
input_shape = [time_steps, features]
batch_size = 32
# model
model = Sequential()
model.add(LSTM(64, input_shape=input_shape, return_sequences=True))
model.add(LSTM(32,input_shape=input_shape))
我们的堆叠 LSTM 网络由 2 个 LSTM 层组成,分别具有 64 和 32 个隐藏单元。在这种情况下,我们期望在每个时间步第一个 LSTM 层 -LSTM(64)- 将作为输入传递给第二个 LSTM 层 -LSTM(32)- 一个大小为 [batch_size, time-step , hidden_unit_length]
,表示第一个 LSTM 层在当前时间步的隐藏状态。让我困惑的是:
- 第二个 LSTM 层 -LSTM(32)- 是否接收第一层 -LSTM(64)- 的隐藏状态作为
X(t)
(作为输入),其大小为[batch_size, time-step, hidden_unit_length]
并通过它自己的隐藏网络传递它 - 在这种情况下由 32 个节点组成 -? - 如果第一个为真,为什么第一个 -LSTM(64)- 和第二个 -LSTM(32)- 的
input_shape
相同,而第二个只处理第一层?在我们的例子中,不应该将input_shape
设置为[32, 10, 64]
吗?
我发现下面的 LSTM 可视化非常有用(找到 here )但它不会在堆叠式 lstm 网络上扩展:
任何帮助将不胜感激。 谢谢!
最佳答案
只有第一层需要input_shape
。后续层将前一层的输出作为其输入(因此它们的 input_shape
参数值被忽略)
下面的模型
model = Sequential()
model.add(LSTM(64, return_sequences=True, input_shape=(5, 2)))
model.add(LSTM(32))
代表下面的架构
你可以从 model.summary()
验证它
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_26 (LSTM) (None, 5, 64) 17152
_________________________________________________________________
lstm_27 (LSTM) (None, 32) 12416
=================================================================
替换行
model.add(LSTM(32))
与
model.add(LSTM(32, input_shape=(1000000, 200000)))
仍然会为您提供相同的架构(使用 model.summary()
验证),因为 input_shape
被忽略,因为它将上一层的张量输出作为输入.
如果你需要一个序列来排序架构,如下所示
你应该使用代码:
model = Sequential()
model.add(LSTM(64, return_sequences=True, input_shape=(5, 2)))
model.add(LSTM(32, return_sequences=True))
应该返回一个模型
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_32 (LSTM) (None, 5, 64) 17152
_________________________________________________________________
lstm_33 (LSTM) (None, 5, 32) 12416
=================================================================
关于deep-learning - 堆叠式 LSTM 网络中每个 LSTM 层的输入是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55385906/