python - 如何根据 tensorflow 中另一个矩阵获得的最大值和次要值以及索引来获取矩阵中每一行的值?

标签 python tensorflow

如何根据 tensorflow 中另一个矩阵获得的最大和次要值索引获取矩阵中每一行的值?例如。我有一个矩阵张量 A 为

    [[1,2,3],
     [6,5,4],
     [7,9,8]], 

和矩阵张量 B

    [[10,11,12],
     [13,14,15],
     [16,17,18]]. 

然后我得到矩阵A的最大值和次最大值索引向量

    [[2,1],
     [0,2],
     [1,2]] 

通过使用 tf.nn_topk。然后我想从这些索引中获取矩阵 B 的确切值,即

    [[12,11],
     [13,15],
     [17,18]]. 

我该怎么办?似乎 tf.gather_nd 可以完成这项工作,但我不知道如何为其生成二维索引。

最佳答案

因此,对于这种特定情况,此代码返回值。

  1. 它没有矢量化
  2. 它不是通用的

它只是为 gather_nd 创建一个模板,如下所示。

 [[0 1]
  [0 2]
  [1 2]
  [1 0]
  [2 2]
  [2 1]]

其他人可能有更紧凑的想法。

import tensorflow as tf

A = tf.Variable([[10,11,12],
     [13,14,15],
     [16,17,18]], )

B = tf.Variable([[2,1],
     [0,2],
     [1,2]] )


sess = tf.Session()
sess.run(tf.global_variables_initializer())

indices =  sess.run(B)

incre = tf.Variable(0)
template = tf.Variable(tf.zeros([6,2],tf.int32))

sess.run(tf.global_variables_initializer())

#There are 3 rows in the indices array
row = tf.gather( indices , [0,1,2])


for i in range(0, row.get_shape()[0] ) :

    newrow = tf.gather(row, i)

    exprow1 = tf.concat([tf.constant([i]), newrow[1:]], axis=0)
    exprow2 = tf.concat([tf.constant([i]), newrow[:1]], axis=0)

    template = tf.scatter_update(template, incre, exprow1)
    template = tf.scatter_update(template, incre + 1, exprow2)

    #Dataflow execution dependency is enforced.
    with tf.control_dependencies([template]): 
        incre = tf.assign(incre,incre + 2)

print(sess.run(tf.gather_nd(A,template)))

输出是这样的。

[11 12 15 13 18 17]

关于python - 如何根据 tensorflow 中另一个矩阵获得的最大值和次要值以及索引来获取矩阵中每一行的值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51393213/

相关文章:

python - 将 3 堆数字从左到右逐行排序

tensorflow - 如何设置用于二元分类的神经网络架构

python - tf.function 输入参数

java - tensorflow java模型推理将获取的张量转换为字符串?

python - 如何在 Windows 上的 Python 2.7 上安装 Tensorflow?

python - Keras 意外的内核正则化器错误

python - 从S3向Elasticsearch发送日志文件的Python 2.7与3.8 Lambda

python - 为什么我收到 Typeerror : 'int' object not iterable

python - 如何使用 pandas 将整个列字符串转换为数据框中的 float ?

python - 替换 N-d numpy 数组中的字符串