python - 如何在 TensorFlow 中使用索引数组?

标签 python numpy tensorflow

如果给定一个形状为(5,3)的矩阵a和形状为(5,)<的索引数组b,我们可以很容易的得到对应的向量c,通过,

c = a[np.arange(5), b]

但是,我不能用 tensorflow 做同样的事情,

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# this line throws error
c = a[tf.range(5), b]

Traceback (most recent call last): File "", line 1, in File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 513, in _SliceHelper name=name)

File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 671, in strided_slice shrink_axis_mask=shrink_axis_mask) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3688, in strided_slice shrink_axis_mask=shrink_axis_mask, name=name) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op op_def=op_def) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2397, in create_op set_shapes_for_outputs(ret) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1757, in set_shapes_for_outputs shapes = shape_func(op) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1707, in call_with_requiring return call_cpp_shape_fn(op, require_shape_fn=True) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn debug_python_shape_fn, require_shape_fn) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl raise ValueError(err.message) ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_14' (op: 'StridedSlice') with input shapes: [5,3], [2,5], [2,5], [2].

我的问题是,如果我使用上述方法无法在 tensorflow 中产生与在 numpy 中一样的预期结果,我该怎么办?

最佳答案

TensorFlow 目前未实现此功能。 GitHub issue #4638正在跟踪 NumPy 风格的“高级”索引的实现。但是,您可以使用 tf.gather_nd()运营商实现您的计划:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))

row_indices = tf.range(5)

# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])

c = tf.gather_nd(a, indices)

关于python - 如何在 TensorFlow 中使用索引数组?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42622815/

相关文章:

python - Ray:如何在一个 GPU 上运行多个 actor?

python - 如何从 tensorflow 中的 RNN 模型中提取细胞状态和隐藏状态?

python - Django 管理 ListView

python - 对多个 NumPy 数组进行排序

python - 为 keras reshape 数组

python - 如何在数据框的数组列中选择一个元素?

python - 用于编译网络 (CNN) 的 Keras 自定义损失函数中出现错误

python - 为什么欧氏距离在这里不起作用?

python - cx_Oracle : how do I get the ORA-xxxxx error number?

python - "error: (-215) ssize.area() > 0 in function cv::resize"的大图像上的 OpenCV 调整大小失败