python - TensorFlow:获取非零张量中最小元素索引的有效方法?

标签 python tensorflow boolean tensor

我使用 TensorFlow 1.12。我有一个一维张量 tag_mask_sizes,它主要包含零,但也包含一些正整数。如何有效地获取非零的最小元素的索引?我尝试了以下方法:

tag_mask_sizes_suppressed = tf.map_fn(lambda x: x if tf.not_equal(x, tf.constant(0, dtype=tf.uint8)) else 9999999, tag_mask_sizes)
        smallest_mask_index = tf.argmin(tag_mask_sizes_suppressed)

但是,tf.not_equal() 会产生一个 boolean 张量,我无法在 lambda 内的 if-else 条件下对其进行有效计算。还有其他像这样的优雅解决方案吗?

虽然我通常急切地执行,但这个问题发生在我在 tf.Dataset.map() 中使用的函数中,该函数并未急切地执行。

最佳答案

其实你的代码等价于下面的代码

tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)

矢量化方法将明显快于 tf.map_fn()。此外,还有一些矢量化方法可以得到一维张量中不为零的最小元素的索引。一个例子:

import tensorflow as tf
# tf.enable_eager_execution()

tag_mask_sizes = tf.constant([2,0,1,3,1,32,0,0,0], dtype=tf.int32)

# approach 1, the disadvantage is that the maximum must be specified and only the first minimum can be found.
tag_mask_sizes_suppressed = tf.where(tf.not_equal(tag_mask_sizes, 0),tag_mask_sizes,tag_mask_sizes+9999999)
smallest_mask_index1 = tf.argmin(tag_mask_sizes_suppressed)

# approach 2, only the first minimum can be found.
tag_mask_sizes_nozeroidx = tf.where(tf.not_equal(tag_mask_sizes, 0))
tag_mask_sizes_suppressed = tf.gather_nd(tag_mask_sizes,tag_mask_sizes_nozeroidx)
smallest_mask_index2 = tag_mask_sizes_nozeroidx[tf.argmin(tag_mask_sizes_suppressed)]

# approach 3, find all minimum
tag_mask_sizes_suppressed = tf.boolean_mask(tag_mask_sizes,tf.not_equal(tag_mask_sizes, 0))
smallest_mask_index3 = tf.squeeze(tf.where(tf.equal(tag_mask_sizes,tf.reduce_min(tag_mask_sizes_suppressed))))

with tf.Session() as sess:
    print(sess.run(smallest_mask_index1))
    print(sess.run(smallest_mask_index2))
    print(sess.run(smallest_mask_index3))

# print
2
[2]
[2 4]

关于python - TensorFlow:获取非零张量中最小元素索引的有效方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56206571/

相关文章:

python - 为什么我们使用 numpy.argmax() 从 numpy 预测数组中返回索引?

java - 创建大小为 n 的 boolean 数组的所有可能方式?

java - 在 Java 中,在评估 emptyConstructor 条件时如何区别对待空类构造函数?

python - 有没有办法用 PIL 来回退丢失的字形?

python - 如何在 pandas 数据框中引用以数字作为名称的列?

python - Tensorflow Adam 优化器 vs Keras Adam 优化器

python - 将两个 boolean 列转换为 Pandas 中的类 ID

python - Spark 作业未结束 : Show of dataframe

python - 使用 Scrapy XPATH 获取属性名称

tensorflow - Keras 中填充输出的 F1 分数