python - TensorFlow 获取特定列的每一行的元素

标签 python numpy tensorflow

如果 A 是像这样的 TensorFlow 变量

A = tf.Variable([[1, 2], [3, 4]])

index是另一个变量

index = tf.Variable([0, 1])

我想使用这个索引来选择每行中的列。在这种情况下,第一行的项目 0 和第二行的项目 1。

如果 A 是一个 Numpy 数组,那么要获取索引中提到的相应行的列,我们可以这样做

x = A[np.arange(A.shape[0]), index]

结果是

[1, 4]

TensorFlow 的等效操作是什么?我知道 TensorFlow 不支持很多索引操作。如果不能直接完成,有什么解决办法?

最佳答案

您可以使用行索引扩展列索引,然后使用 gather_nd:

import tensorflow as tf

A = tf.constant([[1, 2], [3, 4]])
indices = tf.constant([1, 0])

# prepare row indices
row_indices = tf.range(tf.shape(indices)[0])

# zip row indices with column indices
full_indices = tf.stack([row_indices, indices], axis=1)

# retrieve values by indices
S = tf.gather_nd(A, full_indices)

session = tf.InteractiveSession()
session.run(S)

关于python - TensorFlow 获取特定列的每一行的元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39684415/

相关文章:

python - 如何初始化 2D numpy 数组

tensorflow - 使用tensorflow在keras中切片损失函数的输入

python - 如何定义获取第一个元素的 python lambda?

python - Django:一个 View 的基本身份验证(避免中间件)

python - 使用 SQLAlchemy 在关联表中插入数据时出现 IntegrityError

python - 如何使用 np.save 将文件保存在 python 的不同目录中?

python - 我在执行 multivariate_gauss pdf 时遇到什么问题?

python - InvalidArgumentError : Received a label value of 8825 which is outside the valid range of [0, 8825) SEQ2SEQ 模型

python - Tensorflow 返回相同的预测

python - 在Python中重命名目录中的多个文件