tensorflow - 如何在每个纪元后重置 tensorflow 中 GRU 的状态

标签 tensorflow machine-learning deep-learning recurrent-neural-network gated-recurrent-unit

我正在使用 tensorflow GRU 单元来实现 RNN。我将上述内容用于最长 5 分钟的视频。因此,由于下一个状态会自动输入到 GRU 中,因此如何在每个时期之后手动重置 RNN 的状态。换句话说,我希望训练开始时的初始状态始终为 0。这是我的代码片段:

with tf.variable_scope('GRU'):
    latent_var = tf.reshape(latent_var, shape=[batch_size, time_steps, latent_dim])

    cell = tf.nn.rnn_cell.GRUCell(cell_size)   
    H, C = tf.nn.dynamic_rnn(cell, latent_var, dtype=tf.float32)  
    H = tf.reshape(H, [batch_size, cell_size]) 
....

非常感谢任何帮助!

最佳答案

使用tf.nn.dynamic_rnninitial_state参数:

initial_state: (optional) An initial state for the RNN. If cell.state_size is an integer, this must be a Tensor of appropriate type and shape [batch_size, cell.state_size]. If cell.state_size is a tuple, this should be a tuple of tensors having shapes [batch_size, s] for s in cell.state_size.

文档中的改编示例:

# create a GRUCell
cell = tf.nn.rnn_cell.GRUCell(cell_size)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

# defining initial state
initial_state = cell.zero_state(batch_size, dtype=tf.float32)

# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
<小时/>

另请注意,尽管 initial_state 不是占位符,您也可以向其提供值。因此,如果希望保留一个纪元内的状态,但在纪元开始时从零开始,您可以这样做:

# Compute the zero state array of the right shape once
zero_state = sess.run(initial_state)

# Start with a zero vector and update it 
cur_state = zero_state
for batch in get_batches():
  cur_state, _ = sess.run([state, ...], feed_dict={initial_state=cur_state, ...})

关于tensorflow - 如何在每个纪元后重置 tensorflow 中 GRU 的状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48523923/

相关文章:

javascript - 为什么我的神经网络训练方法没有被调用? (ML5.JS)

arrays - 大规模数据集的核方法

python - TensorFlow:如果 tf.train.batch 已经并行出队示例,并行排队示例是否会加快批量创建速度?

python - 在 Python 中处理/显示极大的值

amazon-web-services - 咖啡 |检查失败 : error == cudaSuccess (2 vs. 0) 内存不足

python - 实现多对多回归任务

python - 将 TensorFlow 张量转换为 Numpy 数组

python - 错误 :tensorflow:Couldn't match files for checkpoint

python - 如何从 C++ 使用 TensorFlow Estimator?

r - R 中的朴素贝叶斯分类器 (e1071) 的行为不符合预期(简单示例)