假设我有一个大小为 batch
x max_len
x output_size
的数组,其中 batch
, max_len
和output_size
都对应于正自然数。我有一个索引列表,对应于维度 1 中的各个项目(即 max_len
)。如何从给定这些索引的数组中进行选择?
作为一个具体的例子,假设我有以下内容:
>>> l = np.random.randn(4,5,6)
>>> l.shape
(4, 5, 6)
>>> idx = [0,0,2,3]
当我选择给定idx
的l
时,我得到:
>>> l[:,idx,:].shape
(4, 4, 6)
>>>
我也尝试了np.take
但达到了相同的结果:
>>> np.take(l,idx,axis=1).shape
(4, 4, 6)
>>>
但是,我正在寻找的输出是(4,1,6)
,因为我试图仅让一个项目查看批处理
中的每个元素(即第一维)。如何产生具有正确形状的输出?
最佳答案
使用np.take_along_axis
扩展 idx
使其具有与 l
相同的 ndim 后 -
np.take_along_axis(l,np.asarray(idx)[:,None,None],axis=1)
具有显式整数数组索引 -
l[np.arange(len(idx)),idx][:,None] # skip [:,None] for (4,6) shaped o/p
关于python - 从具有索引列表的多维数组中进行选择,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59123991/