python - 使用 tensorflow 中的另一个索引列表访问张量的元素

标签 python tensorflow

我需要使用我拥有的另一个索引列表来访问张量的元素,但目前使用简单的语法似乎是不可能的。我不确定这是否是一个错误,所以我将其发布在这里以希望修复我的语法。我的代码是:

import tensorflow as tf
import numpy as np

sess = tf.Session()
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx_list = np.array([0,2])
output = input[:, idx_list]

print(sess.run(output))

但我收到错误:

ValueError: Shapes must be equal rank, but are 0 and 1 From merging shape 0 with other shapes. for 'strided_slice/stack_1' (op: 'Pack') with input shapes: [], [2].

我安装的tensorflow版本是tensorflow-1.1.0-cp35(pip安装)。

更新:

我通过 tf.fn_map 执行此操作,但我真的怀疑这是进行索引的正确方法:

output = tf.transpose(tf.map_fn(lambda x: input[:,x], idx_list),perm=[1,0])

更新:

有一个特定的issue registered为此,最新评论中的一个不错的片段可能会有所帮助。同时这个操作并不像numpy那么容易......

最佳答案

您可以使用tf.gathertf.transpose来完成此操作,如下所示:

import tensorflow as tf
import numpy as np

sess = tf.Session()
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx_list = np.array([0,2])
output = tf.transpose(tf.gather(tf.transpose(input),idx_list))
output.eval(session=sess)

这会打印

array([[1, 3],
       [4, 6],
       [7, 9]])

关于python - 使用 tensorflow 中的另一个索引列表访问张量的元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43735158/

相关文章:

python - 从 DataFrame 行元素生成元组

python - numpy切片奇怪的行为

python - 比较每一行的数据框列中的元素 - Python

python - 找不到 WSGI 应用程序 - 404

python - 如何垂直堆叠具有不同列名的 pandas 数据框

tensorflow - 在 tensorflow 2 中,什么成为图形的一部分,什么不是?

TensorFlow 可与 Slurm Interactive Session 配合使用,但不适用于 Slurm Job

python - 如何在 Tensorflow 中使用 SWA 实现 Batch Norm?

tensorflow - 检查目标 : expected dense_Dense2 to have shape x, 时出错,但得到形状为 y 的数组

python - 使用 OpenCV 图像提供 Inception