我正在研究编码器-解码器设置。我希望能够运行一次编码器,然后运行多个解码器。我想出的解决方案是为解码器提供一个 TF 条件节点(使用 tf.where),它包含编码器的最终隐藏状态(在这种情况下,当我请求解码器输出时,TF 将运行编码器),或带有编码器存储结果的占位符(在这种情况下,理论上 TF 不需要运行编码器)。
这是代码的相关部分:
encoder_state = tf.where(gen_math_ops.greater_equal(branching_points, 0), encoder_state,
rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])
因为我没有从这种方法中获得加速,所以我很确定它不起作用并且 tf.where 的两个分支每次都由 TF 运行,即使它只需要从占位符读取。
有没有什么方法可以使用 tf.where 使其不运行编码器?我查看了该方法的描述,但我不确定是否始终计算两个分支,我在这个问题上看到了相互矛盾的信息。
谢谢!
最佳答案
tf.cond()
当您想要推迟执行其中一个分支直到对谓词求值时,可以使用函数。
encoder_state = tf.cond(
tf.greater_equal(branching_points, 0),
lambda: encoder_state,
lambda: tf.nn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])
关于Tensorflow:可以阻止 tf.where 的一个分支执行吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44569219/