如果 A
是像这样的 TensorFlow 变量
A = tf.Variable([[1, 2], [3, 4]])
和index
是另一个变量
index = tf.Variable([0, 1])
我想使用这个索引来选择每行中的列。在这种情况下,第一行的项目 0 和第二行的项目 1。
如果 A 是一个 Numpy 数组,那么要获取索引中提到的相应行的列,我们可以这样做
x = A[np.arange(A.shape[0]), index]
结果是
[1, 4]
TensorFlow 的等效操作是什么?我知道 TensorFlow 不支持很多索引操作。如果不能直接完成,有什么解决办法?
最佳答案
您可以使用行索引扩展列索引,然后使用 gather_nd:
import tensorflow as tf
A = tf.constant([[1, 2], [3, 4]])
indices = tf.constant([1, 0])
# prepare row indices
row_indices = tf.range(tf.shape(indices)[0])
# zip row indices with column indices
full_indices = tf.stack([row_indices, indices], axis=1)
# retrieve values by indices
S = tf.gather_nd(A, full_indices)
session = tf.InteractiveSession()
session.run(S)
关于python - TensorFlow 获取特定列的每一行的元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39684415/