我正在使用 dynamic_rnn 处理 MNIST 数据:
# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
forget_bias=1.0,
initializer=tf.random_normal)
# Initial state
istate = lstm.zero_state(batch_size, "float")
# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)
# Output at last time point T
output_at_T = output[:, 27, :]
完整代码:http://pastebin.com/bhf9MgMe
lstm的输入是(batch_size, sequence_length, input_size)
因此 output_at_T
的维度为 (batch_size, sequence_length, num_units)
其中 num_units=200
。
我需要沿 sequence_length
维度获取最后一个输出。在上面的代码中,这被硬编码为 27
。但是,我事先不知道 sequence_length
,因为它可以在我的应用程序中逐批更改。
我试过了:
output_at_T = output[:, -1, :]
但它表示尚未实现负索引,我尝试使用占位符变量和常量(理想情况下,我可以将特定批处理的 sequence_length
输入其中);都没有用。
有什么方法可以在 tensorflow atm 中实现类似的东西?
最佳答案
您是否注意到 dynamic_rnn 有两个输出?
- 输出 1,我们称之为 h,在每个时间步都有所有输出(即 h_1、h_2 等),
- 输出 2,final_state,有两个元素:cell_state,以及批处理中每个元素的最后一个输出(只要您将序列长度输入到 dynamic_rnn)。
所以来自:
h, final_state= tf.dynamic_rnn( ..., sequence_length=[batch_size_vector], ... )
批处理中每个元素的最后状态是:
final_state.h
请注意,这包括批处理的每个元素的序列长度不同的情况,因为我们使用的是 sequence_length 参数。
关于python - 获取 tensorflow 中dynamic_rnn的最后输出?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36817596/