python - 在 Tensorflow 2.0 中用另一个张量索引张量的第 k 维

标签 python numpy tensorflow indexing tensorflow2.0

我有一个张量 probs形状为(None, None, 110)代表(batch_size, sequence_length, 110)在 LSTM 中。 我有另一个张量indices形状为(None, None) ,其中包含要从 probs 第三维中选择的元素的索引.

我想使用indices索引张量 probs

Numpy 等效项:

k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]

shape[0]shape[1]probs未知,tf.meshgrid()不是一个选择。 我发现tf.gather , tf.gather_ndtf.batch_gather ,但他们似乎都没有做我想做的事。

有人知道怎么做吗?

最佳答案

您可以使用 tf.gather_nd 来做到这一点像这样:

indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)

顺便说一下,在 NumPy 中你可以使用 np.take_along_axis做同样的事情:

indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]

关于python - 在 Tensorflow 2.0 中用另一个张量索引张量的第 k 维,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62191509/

相关文章:

python - 如何使用Python在图像上生成随机正弦条纹?

python - 为什么 Numpy 比这个 cython 例程好 3 倍

python - Pillow 和 Numpy,获取像素值

python - 如何将给定行 `i` 或列 `j` 与标量相乘?

tensorflow - 为什么在 Tensorflow 文档中没有提到 contrib.layers.linear?

python - pyserial - 如何连续读取和解析

python - 如何修复我的井字游戏中检查完整棋盘的功能

python - 在 pyparsing 中使用 escChar 和 escQuote

tensorflow - 我的训练数据集对于我的神经网络来说是否太复杂?

TensorFlow 没有属性 "with_dependencies"