嗨, tensorflow 初学者,
我想删除实现中的任何 numpy 代码,只使用 tensorflow 函数。目前我正在尝试过滤掉背景边界框和置信度较低的框。为此,我想要一个名为 keep 的索引,我可以用它来跟踪要保留的框:
# Filter out background boxes
keep = np.where(class_ids > 0)[0]
# Filter out low confidence boxes
if config.DETECTION_MIN_CONFIDENCE:
keep = np.intersect1d(
keep, np.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[0])
class_ids 是一个形状为 (1000) 的张量,其中每个条目都是 0 到 80 之间的数字,具体取决于类别(总共 81 个类别)。
class_scores 是形状为 (1000) 的张量,其中每个条目都是相应边界框类别的概率。
我知道 np.where() 很容易更改为 tf.where 但如何使用 tensorflow 获得与 np.intersect1d() 相同的功能?
感谢您的帮助。
最佳答案
这似乎重复了 numpy.intersect1d 示例。
import tensorflow as tf
a = tf.constant([3, 1, 2, 1])
b = tf.constant([1, 3, 4, 3])
# This set appears to be sorted, but that is not documented behavior.
s = tf.sets.set_intersection(a[None,:], b[None, :])
fsort = tf.contrib.framework.sort(s.values)
with tf.Session() as sess:
print(sess.run(s).values)
print(sess.run(fsort))
此输出
[1 3]
[1 3]
通过一些测试示例,set 函数似乎给出了有序结果,但我无法验证它是否总是会这样做。因此,您可能需要使用 contrib 函数来确定。
关于python - 求两个张量的交集。返回两个输入张量中已排序的唯一值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48303354/