python - 从具有索引列表的多维数组中进行选择

标签 python python-3.x numpy indexing numpy-slicing

假设我有一个大小为 batch x max_len x output_size 的数组,其中 batch, max_lenoutput_size都对应于正自然数。我有一个索引列表,对应于维度 1 中的各个项目(即 max_len)。如何从给定这些索引的数组中进行选择?

作为一个具体的例子,假设我有以下内容:

>>> l = np.random.randn(4,5,6)
>>> l.shape
(4, 5, 6)
>>> idx = [0,0,2,3]

当我选择给定idxl时,我得到:

>>> l[:,idx,:].shape
(4, 4, 6)
>>>

我也尝试了np.take但达到了相同的结果:

>>> np.take(l,idx,axis=1).shape
(4, 4, 6)
>>> 

但是,我正在寻找的输出是(4,1,6),因为我试图仅让一个项目查看批处理中的每个元素(即第一维)。如何产生具有正确形状的输出?

最佳答案

使用np.take_along_axis扩展 idx 使其具有与 l 相同的 ndim 后 -

np.take_along_axis(l,np.asarray(idx)[:,None,None],axis=1)

具有显式整数数组索引 -

l[np.arange(len(idx)),idx][:,None] # skip [:,None] for (4,6) shaped o/p

关于python - 从具有索引列表的多维数组中进行选择,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59123991/

相关文章:

python - 使用 openpyxl 获取空单元格坐标

python - 将箱线图排列为带有 seaborn `FacetGrid` 的网格

python - 重置 Airflow DAG 执行时间

python - 如何向多索引 Pandas 数据帧添加新行和新列?

python - 填充数据框列Python中的缺失值

python通过偏移轮廓/缩小多边形来分离圆形粒子

python - 为什么Python中向量的维数是(N,)而不是(N,1)?

python - 对于具有排序、浮点索引和列的 DataFrame,根据 DataFrame 值使用线性插值计算值

python - 列表索引超出范围(使用 json)

python -/usr/bin/python3 : No module named pytest