python - 在 Tensorflow 中设置交集

标签 python tensorflow intersection set-intersection

我想检查稀疏张量中是否包含一组给定值。稀疏张量称为“标签”,只有一个维度,其中包含 id 列表。

最后这似乎是一个简单的集合交集问题,所以我尝试了这个。

sparse_ids = load_ids_as_sparse_tensor()
wanted_ids = tf.constant([34, 56, 12])
intersection = tf.sets.set_intersection(
    wanted_ids,
    tf.cast(sparse_ids.values, tf.int32)
)
contains_any_wanted_ids = tf.not_equal(tf.size(intersection), 0)

但是,我收到此错误:

ValueError:形状必须至少为 2 级,但对于输入形状为“DenseToDenseSetOperation”(操作:“DenseToDenseSetOperation”)的等级为 1:[3]、[?]。

有什么想法吗?

最佳答案

以下代码有效。不过,我不确定结果是否是你想要的。

import tensorflow as tf
a = tf.constant([34, 56, 12])
b = tf.constant([56])
intersection = tf.sets.set_intersection(a[None,:],b[None,:])
sess=tf.Session()
sess.run(intersection)

输出:

SparseTensorValue(indices=array([[0, 0]], dtype=int64), values=array([56]), dense_shape=array([1, 1], dtype=int64))

关于python - 在 Tensorflow 中设置交集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53346195/

相关文章:

java - 在两个 Sprite 碰撞之前获得最后一个双倍

python - 查找列表中的每个第 n 个元素

python - 是否有与 Ruby 的字符串插值等效的 Python?

python - 我正在学习 Python,但我不明白这种表示法

tensorflow - 当不再需要时如何从内存中释放张量?

java - 为什么这条线交点代码不起作用?

python - Google Analytics API - 我可以使用自己的 Google 帐户向其他人显示报告吗?

tensorflow - 使用tensorflow实现自动编码器

python - 输出层有四个节点,但我想使用其中一个节点的输出,如何修复?

c++ - 表达式: Sequence not ordered