Tensorflow:可以阻止 tf.where 的一个分支执行吗?

标签 tensorflow

我正在研究编码器-解码器设置。我希望能够运行一次编码器,然后运行多个解码器。我想出的解决方案是为解码器提供一个 TF 条件节点(使用 tf.where),它包含编码器的最终隐藏状态(在这种情况下,当我请求解码器输出时,TF 将运行编码器),或带有编码器存储结果的占位符(在这种情况下,理论上 TF 不需要运行编码器)。

这是代码的相关部分:

encoder_state = tf.where(gen_math_ops.greater_equal(branching_points, 0), encoder_state,
                         rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

因为我没有从这种方法中获得加速,所以我很确定它不起作用并且 tf.where 的两个分支每次都由 TF 运行,即使它只需要从占位符读取。

有没有什么方法可以使用 tf.where 使其不运行编码器?我查看了该方法的描述,但我不确定是否始终计算两个分支,我在这个问题上看到了相互矛盾的信息。

谢谢!

最佳答案

tf.cond()当您想要推迟执行其中一个分支直到对谓词求值时,可以使用函数。

encoder_state = tf.cond(
    tf.greater_equal(branching_points, 0),
    lambda: encoder_state,
    lambda: tf.nn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

关于Tensorflow:可以阻止 tf.where 的一个分支执行吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44569219/

相关文章:

tensorflow - gcloud 作业提交预测 'can' t 解码 json' 与 --data-format=TF_RECORD

Python 如何将 pyaudio 字节转换为虚拟文件?

python - Tensorflow机器之间损失的主要差异

python - 检查目标 : expected dense_2 to have shape (9, 时出错,但得到形状为 (30,) 的数组

python - 如何为分类值列表的列创建嵌入

python - 将 Numpy 数组转换为张量

python - 当模型具有注意力层时,无法从 Model.get_config() 加载 keras 中的模型

python - Tensorflow:x - reduce_mean(x) 的梯度为 0

tensorflow - 'numpy.dtype' 对象在 keras 中没有属性 'base_dtype'

python - 未找到 :/lib/python2. 7/site-packages/tensorflow/tensorboard/lib/svg/summary-icon.sv