python - 在 TensorFlow 中,如何断言列表的值在某个集合中?

标签 python tensorflow

我有一个一维 tf.uint8 张量 x 并且想断言该张量内的所有值都在集合 s 中定义。 s 在图形定义时是固定的,因此它不是动态计算的张量。

在普通的 Python 中,我想做某事。像下面这样:

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}
assert all(el in s for el in x), "This should fail, as 5 is not in s"

我知道我可以使用 tf.Assert对于断言部分,但我正在努力定义条件部分(el in s)。执行此操作的最简单/最规范的方法是什么?

较旧的答案 Determining if A Value is in a Set in TensorFlow对我来说还不够:首先,写下来和理解起来很复杂,其次,它使用的是广播的tf.equal,这在计算方面比适当的基于集合的检查更昂贵.

最佳答案

一个简单的方法可能是这样的:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# Check every value in x against every value in s
xs_eq = tf.equal(x_t[:, tf.newaxis], s_t)
# Check every element in x is equal to at least one element in s
assert_op = tf.Assert(tf.reduce_all(tf.reduce_any(xs_eq, axis=1)), [x_t])
with tf.control_dependencies([assert_op]):
    # Use x_t...

这将创建一个大小为 (len(x), len(s)) 的中间张量。如果这是有问题的,您还可以将问题拆分为独立的张量,例如:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
# Count where each x matches each s
x_in_s = [tf.cast(tf.equal(x_t, si), tf.int32) for si in s]
# Add matches and check there is at least one match per x
assert_op = tf.Assert(tf.reduce_all(tf.add_n(x_in_s) > 0), [x_t])

编辑:

实际上,既然你说你的值是 tf.uint8,你可以使用 bool 数组让事情变得更好:

import tensorflow as tf

x = [1, 2, 3, 1, 11, 3, 5]
s = {1, 2, 3, 11, 12, 13}

x_t = tf.constant(x, dtype=tf.uint8)
s_t = tf.constant(list(s), dtype=tf.uint8)
# One-hot vectors of values included in x and s
x_bool = tf.scatter_nd(tf.cast(x_t[:, tf.newaxis], tf.int32),
                       tf.ones_like(x_t, dtype=tf.bool), [256])
s_bool = tf.scatter_nd(tf.cast(s_t[:, tf.newaxis], tf.int32),
                       tf.ones_like(s_t, dtype=tf.bool), [256])
# Check that all values in x are in s
assert_op = tf.Assert(tf.reduce_all(tf.equal(x_bool, x_bool & s_bool)), [x_t])

这需要线性时间和常量内存。

编辑 2:虽然最后一种方法理论上在这种情况下是最好的,但做几个快速基准测试时,我只能看到当我达到数十万个元素时性能上的显着差异,无论如何这三个使用 tf.uint8 仍然非常快。

关于python - 在 TensorFlow 中,如何断言列表的值在某个集合中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53061284/

相关文章:

python - 阈值函数 - 没有为任何变量提供梯度

python - POX l3_学习示例

python - Flask-SQLAlchemy:如何根据用户的角色和团队进行查询?

python - 正则表达式查找时间戳后的最后一个冒号

批量的 TensorFlow 图像操作

python - 通过 tensorflow 对整数数据进行分类

python - 查找矩阵大对角线下方的所有元素

python - 如何正确运行 2 个同时等待事物的线程?

python - Anaconda Tensorflow 中缺少 ImageNet (Inception v3) 模型?

linux - Udacity 深度学习的 TensorFlow 设置