tensorflow - Initial state structure for stacked LSTMs

Tags: tensorflow keras lstm recurrent-neural-network

What structure does the initial state of a multi-layer (stacked) RNN need to have in TensorFlow (1.13.1) when using the tf.keras.layers.RNN API?

I tried the following:

lstm_cell_sizes = [256, 256, 256]
lstm_cells = [tf.keras.layers.LSTMCell(size) for size in lstm_cell_sizes]

state_init = [tf.placeholder(tf.float32, shape=[None] + cell.state_size) for cell in lstm_cells]

tf.keras.layers.RNN(lstm_cells, ...)(inputs, initial_state=state_init)

This results in:
ValueError: Could not pack sequence. Structure had 6 elements, but flat_sequence had 3 elements.  Structure: ([256, 256], [256, 256], [256, 256]), flat_sequence: [<tf.Tensor 'player/Placeholder:0' shape=(?, 256, 256) dtype=float32>, <tf.Tensor 'player/Placeholder_1:0' shape=(?, 256, 256) dtype=float32>, <tf.Tensor 'player/Placeholder_2:0' shape=(?, 256, 256) dtype=float32>].

If I instead make state_init a flat list of tensors of shape [None, 256], I get:
ValueError: An `initial_state` was passed that is not compatible with `cell.state_size`. Received `state_spec`=[InputSpec(shape=(None, 256), ndim=2), InputSpec(shape=(None, 256), ndim=2), InputSpec(shape=(None, 256), ndim=2)]; however `cell.state_size` is [[256, 256], [256, 256], [256, 256]]

The TensorFlow RNN docs are rather vague about this:

"You can specify the initial state of RNN layers symbolically by calling them with the keyword argument initial_state. The value of initial_state should be a tensor or list of tensors representing the initial state of the RNN layer."
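Both errors come down to how `state_size` is nested. A minimal pure-Python sketch, assuming (as the error messages indicate) that each `LSTMCell` reports `state_size == [units, units]`, one entry for `h` and one for `c`. The `flatten` helper is hypothetical and only mimics `tf.nest.flatten` for plain lists:

```python
# Assumption: each LSTMCell has state_size == [units, units] (h and c),
# matching the "[[256, 256], [256, 256], [256, 256]]" in the error message.
lstm_cell_sizes = [256, 256, 256]
state_sizes = [[size, size] for size in lstm_cell_sizes]

# First attempt: [None] + cell.state_size is list concatenation,
# so each placeholder became a single 3-D tensor instead of two 2-D ones.
first_attempt_shape = [None] + state_sizes[0]
print(first_attempt_shape)  # [None, 256, 256]


def flatten(structure):
    """Hypothetical helper: flatten nested lists into their leaves."""
    if isinstance(structure, list):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


# Second attempt supplied 3 tensors, but the flattened state_size has 6 leaves.
print(len(flatten(state_sizes)))  # 6

# What the layer needs: one (batch, units) tensor per leaf, i.e. an h and a c per cell.
# Shapes are tuples so that flatten() treats each one as a single leaf.
expected_shapes = [[(None, size) for size in cell] for cell in state_sizes]
```

So the first attempt built one 3-D tensor per cell, and the second supplied 3 tensors where the flattened `state_size` has 6 leaves; the layer needs a separate `(batch, 256)` tensor for each `h` and each `c`.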

Best answer

I believe this is how you would do it in TF2:

import numpy as np
import tensorflow.compat.v2 as tf  # If you have a newer version of TF1
# import tensorflow as tf          # If you have TF2

sentence_max_length = 5
batch_size = 3
n_hidden = 2
x = tf.constant(np.reshape(np.arange(30),(batch_size,sentence_max_length, n_hidden)), dtype = tf.float32)

stacked_lstm = tf.keras.layers.StackedRNNCells([tf.keras.layers.LSTMCell(128) for _ in range(2)])

lstm_layer = tf.keras.layers.RNN(stacked_lstm,return_state=False,return_sequences=False)

result = lstm_layer(x)
print(result)
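The answer above builds the stack but never passes an `initial_state`. A minimal sketch of doing so, assuming TensorFlow 2.x running eagerly: the state is nested like the stacked cell's `state_size`, one `[h, c]` pair per cell, each tensor of shape `(batch_size, units)`. The `make_state` helper is hypothetical. This nesting matches what the layer generates for itself when `initial_state` is omitted, which is why the sketch uses it; whether a flat list of `2 * n` tensors is also accepted may depend on the TensorFlow/Keras version. In TF 1.x graph mode the same nesting would presumably be built from `tf.placeholder` tensors of shape `[None, units]`, but that is not exercised here.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

batch_size = 3
sentence_max_length = 5
n_hidden = 2
cell_sizes = [128, 128]  # units of each stacked LSTMCell

x = tf.constant(
    np.arange(30, dtype=np.float32).reshape(batch_size, sentence_max_length, n_hidden)
)

cells = [tf.keras.layers.LSTMCell(units) for units in cell_sizes]
lstm_layer = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells))


def make_state(fill):
    """Hypothetical helper: nested state [[h1, c1], [h2, c2]], each (batch_size, units)."""
    return [[fill((batch_size, units)), fill((batch_size, units))] for units in cell_sizes]


default_out = lstm_layer(x)                                   # layer builds a zero state itself
zero_out = lstm_layer(x, initial_state=make_state(tf.zeros))  # explicit all-zeros state
ones_out = lstm_layer(x, initial_state=make_state(tf.ones))   # non-zero state

print(default_out.shape)  # (batch_size, units of the last cell)
```

All three calls reuse the same weights: the explicit all-zeros state should reproduce the default output, while the all-ones state should change it, which shows the nested `initial_state` is actually being consumed.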

A similar question about tensorflow - initial state structure for stacked LSTMs can be found on Stack Overflow: https://stackoverflow.com/questions/55765447/
