python - 如何按给定段收集张量元素?

标签 python tensorflow

我正在尝试实现一个“segment_collect”(非常类似于segment_max,但收集到一个张量而不是取最大值)。

t = tf.constant(["a", "b", "c", "d"])
s = tf.constant([0, 1, 1, 0])
r = tf.segment_collect(t, s)  # r == [["a", "d"], ["b", "c"]]

一个简单的实现是使用以下伪代码逐行构建结果:

r = []
for i in range(2):
    mask = tf.equal(s, i)
    values = tf.boolean_mask(t, mask)
    r.append(values)
# convert r into a tensor at last

但这不是很有效。

后续问题是:是否有一种通用方法可以对张量进行分组/聚合?除了 tensorflow 中的segment_{min/max/mean/prod/sum}之外,这还允许更多操作,例如segment_size、segment_median、segment_percentile。

最佳答案

您可能会发现 tf.gathertf.nn.topk有帮助:

tf.gather(t, tf.nn.top_k(-s, k=tf.shape(s)[0]).indices) 

这适用于 TF 1.x 和 TF 2.0。如果需要, reshape 结果:

tf.reshape(tf.gather(t, tf.nn.top_k(-s, k=tf.shape(s)[0]).indices), shape=(-1, 2)) 

当然, reshape 假设元素被分为大小相等的组(在本例中为两个组)。

<小时/>
tf.keras.backend.eval(tf.gather(t, tf.nn.top_k(-s, k=tf.shape(s)[0]).indices))                                                           
# array([b'a', b'd', b'b', b'c'], dtype=object)

关于python - 如何按给定段收集张量元素?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56796261/

相关文章:

python - 模型方法关注点分离

python - 在没有\的情况下在Python中分割一条长线

variables - 使用 np.zeros 和使用 tf.zeros 初始化 Tensorflow 变量之间的区别

python - 实际需要 global_variables_initializer() 时

tensorflow - 将 tf.dataset 写回 TFRecord

Python:关闭终端窗口后,带有 'print' 的 GUI 函数不再有效

python - 如何使用cx_Oracle在python中执行非sql命令

Python:.pyc 文件何时创建?

tensorflow - 为什么relu6中是6?

machine-learning - 如何理解 tensorflow 张量板直方图?