我有一个张量 probs
形状为(None, None, 110)
代表(batch_size, sequence_length, 110)
在 LSTM 中。
我有另一个张量indices
形状为(None, None)
,其中包含要从 probs
第三维中选择的元素的索引.
我想使用indices
索引张量 probs
。
Numpy 等效项:
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]
自 shape[0]
和shape[1]
的probs
未知,tf.meshgrid()
不是一个选择。
我发现tf.gather
, tf.gather_nd
和tf.batch_gather
,但他们似乎都没有做我想做的事。
有人知道怎么做吗?
最佳答案
您可以使用 tf.gather_nd
来做到这一点像这样:
indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
顺便说一下,在 NumPy 中你可以使用 np.take_along_axis
做同样的事情:
indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]
关于python - 在 Tensorflow 2.0 中用另一个张量索引张量的第 k 维,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62191509/