python - 沿着张量的第二维收集元素

标签 python tensorflow

假设和张量T都具有形状(N,K)。现在,如果我们从矩阵的角度考虑它们,我希望对于 T 的每一行都获取与 values 具有最大值的索引相对应的行元素。我可以使用

轻松找到这些索引
max_indicies = tf.argmax(T, 1)

返回形状为(N)的张量。现在,我如何从 T 收集这些索引,以便获得形状 N 的东西?我试过了

result = tf.gather(T,max_indices)

但它没有做正确的事情 - 它返回形状 (N,K) 的东西,这意味着它没有收集任何东西。

最佳答案

您可以使用tf.gather_nd .

例如,

import tensorflow as tf

sess = tf.InteractiveSession()

values = tf.constant([[0, 0, 0, 1],
                      [0, 1, 0, 0],
                      [0, 0, 1, 0]])

T = tf.constant([[0, 1, 2 ,  3],
                 [4, 5, 6 ,  7],
                 [8, 9, 10, 11]])

max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0], 
                                            dtype=max_indices.dtype),
                                   max_indices),
                                  axis=1))

print(result.eval())

但是当valuesT的行列较高时,tf.gather_nd的使用就会有点尴尬。我将当前的解决方案发布在this question上。对于高维T,可能有更好的解决方案。

关于python - 沿着张量的第二维收集元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42175635/

相关文章:

python - SymPy 和复数的平方根

tensorflow - 使用 Tensorflow 图形转换工具

python - 凯拉斯 - ValueError : The first layer in a Sequential model must get an `input_shape` or `batch_input_shape` argument

tensorflow - 如何使用 TensorFlow 打乱整个数据集?

python - 当 LSTM 层的输入数量大于或小于该层中 LSTM 单元的数量时,Keras 会做什么?

python - "seek"在新创建的二进制文件上使用时是否写入零字节?

python - Python 中的简单 URL GET/POST 函数

tensorflow - 带有 TFRecord 训练/测试文件的 mnist 和 cifar10 示例

python - 在 Mint 上为 Python3 安装 PyAudio 时遇到问题

python - 当找到或未找到正则表达式模式时记录消息(详细 re.sub)