python - 获取 tensorflow 中dynamic_rnn的最后输出?

标签 python tensorflow

我正在使用 dynamic_rnn 处理 MNIST 数据:

# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
                         forget_bias=1.0,
                         initializer=tf.random_normal)

# Initial state
istate = lstm.zero_state(batch_size, "float")

# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)

# Output at last time point T
output_at_T = output[:, 27, :]

完整代码:http://pastebin.com/bhf9MgMe

lstm的输入是(batch_size, sequence_length, input_size)

因此 output_at_T 的维度为 (batch_size, sequence_length, num_units) 其中 num_units=200

我需要沿 sequence_length 维度获取最后一个输出。在上面的代码中,这被硬编码为 27。但是,我事先不知道 sequence_length,因为它可以在我的应用程序中逐批更改。

我试过了:

output_at_T = output[:, -1, :]

但它表示尚未实现负索引,我尝试使用占位符变量和常量(理想情况下,我可以将特定批处理的 sequence_length 输入其中);都没有用。

有什么方法可以在 tensorflow atm 中实现类似的东西?

最佳答案

您是否注意到 dynamic_rnn 有两个输出?

  1. 输出 1,我们称之为 h,在每个时间步都有所有输出(即 h_1、h_2 等),
  2. 输出 2,final_state,有两个元素:cell_state,以及批处理中每个元素的最后一个输出(只要您将序列长度输入到 dynamic_rnn)。

所以来自:

h, final_state= tf.dynamic_rnn( ..., sequence_length=[batch_size_vector], ... )

批处理中每个元素的最后状态是:

final_state.h

请注意,这包括批处理的每个元素的序列长度不同的情况,因为我们使用的是 sequence_length 参数。

关于python - 获取 tensorflow 中dynamic_rnn的最后输出?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36817596/

相关文章:

python - Keras - TypeError : Output tensors to a Model must be Keras tensors - while modelling multiple input , 多输出网络

python - YOLO在YOLO 9000中如何计算P(Object)

python - 人工神经网络 - 编译错误

python - 如何使用两个损失函数训练模型?

tensorflow - 为什么当我在 tensorflow 中更改测试批量大小时,结果不同

python - 提取访问多个(纬度,经度)的 ID 的最大距离

python - Ubuntu - 尝试安装 Python Couchbase lib - "libcouchbase/couchbase.h: No such file or directory"

python - 如何在无服务器 Lambda (Python) 中下载 S3 文件

python - pip install tensorflow 找不到名为 client_load_reporting_filter.h 的文件

python - 直接从 UI 文件加载 QDialog?