python - 有没有类似于 tf.cond 但用于向量谓词的东西?

标签 python python-3.x tensorflow

假设我有一个 5D 张量 x 和一个 1D bool 掩码 m,其中 m.shape[0] == x.shape[0],我想根据 m 相应的 bool 条目来决定应在 x 内的每个 4D 样本上应用哪个子网络。

据我所知,tf.cond 仅接受标量预测器。虽然 tf.boolean_mask 可能有助于根据需要将批处理内的样本分成两个子集,但我不确定如何将输出重新打包回一个 5D 张量,不会打乱原始样本顺序。有什么提示吗?

最佳答案

最简单的事情是评估两个模型上的数据,然后使用 tf.where 来选择最终输出。

import tensorflow as tf

def model1(x):
    return 2 * x

def model2(x):
    return -3 * x

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, [None, None])  # 2D for simplicity
    m = tf.placeholder(tf.bool, [None])
    y1 = model1(x)
    y2 = model2(x)
    y = tf.where(m, y1, y2)
    print(sess.run(y, feed_dict={x: [[1, 2], [3, 4], [5, 6]], m: [True, False, True]}))
    # [[  2.   4.]
    #  [ -9. -12.]
    #  [ 10.  12.]]

如果您确实想避免这种情况,可以使用 tf.boolean_mask 来拆分数据,然后使用 tf.scatter_nd 重新组合。这是一种可能的方法。

import tensorflow as tf

def model1(x):
    return 2 * x

def model2(x):
    return -3 * x

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, [None, None])
    m = tf.placeholder(tf.bool, [None])
    n = tf.size(m)
    i = tf.range(n)
    x1 = tf.boolean_mask(x, m)
    i1 = tf.boolean_mask(i, m)
    y1 = model1(x1)
    m_neg = ~m
    x2 = tf.boolean_mask(x, m_neg)
    i2 = tf.boolean_mask(i, m_neg)
    y2 = model2(x2)
    y = tf.scatter_nd(tf.expand_dims(tf.concat([i1, i2], axis=0), 1),
                      tf.concat([y1, y2], axis=0),
                      tf.concat([[n], tf.shape(y1)[1:]], axis=0))
    print(sess.run(y, feed_dict={x: [[1, 2], [3, 4], [5, 6]], m: [True, False, True]}))
    # [[  2.   4.]
    #  [ -9. -12.]
    #  [ 10.  12.]]

关于python - 有没有类似于 tf.cond 但用于向量谓词的东西?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57575484/

相关文章:

python - 在 pytest 中,如何暂时禁用类方法中的捕获?

python - Tensorflow:LSTM 和 GradientDescentOptimizer 的梯度为零

python - 使用 tensorflow 求解函数方程

python - 在 MongoEngine 查询中获取引用对象的数据,而不仅仅是 id

python - Microsoft Azure 数据仓库和 SqlAlchemy

python - Matplotlib:图形左边缘和 y 轴之间的固定间距

python - 为什么在 Linux 上导入 numpy 会增加 1 GB 的虚拟内存?

python - 如何在python中模拟const变量

django - 禁用分页器时如何防止将行计数发送到数据库

python - 如何构建具有多个输入的 Tensorflow 模型?