假设值
和张量T
都具有形状(N,K)
。现在,如果我们从矩阵的角度考虑它们,我希望对于 T
的每一行都获取与 values
具有最大值的索引相对应的行元素。我可以使用
max_indicies = tf.argmax(T, 1)
返回形状为(N)
的张量。现在,我如何从 T
收集这些索引,以便获得形状 N
的东西?我试过了
result = tf.gather(T,max_indices)
但它没有做正确的事情 - 它返回形状 (N,K)
的东西,这意味着它没有收集任何东西。
最佳答案
您可以使用tf.gather_nd .
例如,
import tensorflow as tf
sess = tf.InteractiveSession()
values = tf.constant([[0, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0]])
T = tf.constant([[0, 1, 2 , 3],
[4, 5, 6 , 7],
[8, 9, 10, 11]])
max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0],
dtype=max_indices.dtype),
max_indices),
axis=1))
print(result.eval())
但是当values
和T
的行列较高时,tf.gather_nd
的使用就会有点尴尬。我将当前的解决方案发布在this question上。对于高维值
和T
,可能有更好的解决方案。
关于python - 沿着张量的第二维收集元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42175635/