python - 使用 np.argpartition 对多维数组中的值进行索引

标签 python numpy

我有一个这样的数组:

>>> a = np.arange(60).reshape([3,4,5])
>>> a
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

我想检索其中一个维度的前 k 个值。例如,我将选择 k=2 并沿着中间维度。

我试过使用 argpartition 并且它似乎做了正确的事情,但是我在使用它的输出从原始数组中检索值时遇到了问题。这是我使用 argpartition 的方式:

>>> indices = np.argpartition(a, 2, axis=1)
>>> indices
array([[[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

>>> indices[:,-2:,:]
array([[[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

但我无法通过使用这些索引进行切片来获取值。

>>> a[:,indices[:,-2:,:],:].shape
(3, 3, 2, 5, 5)

我期待看到形状为 (3,2,5) 的数组(因为我正在寻找沿中轴的前 2 个),我想它看起来像这样:

>>> magic_output
array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

如何使用 argpartition 中的索引访问值?

最佳答案

np.argpartition 得到最小的 k 索引。因此,要获得前 k 索引,我们需要沿所需轴使用否定输入数组。然后,我们需要使用这些索引通过 NumPy 的高级索引 索引到该轴并获得所需的输出。

因此,实现将是 -

k = 2
m,n = a.shape[0], a.shape[2]
idx = np.argpartition(-a,k,axis=1)[:,k-1::-1]
out = a[np.arange(m)[:,None,None], idx, np.arange(n)]

sample 运行-

1)输入:

In [180]: a
Out[180]: 
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

2) 提议的代码:

In [206]: k = 2
     ...: m,n = a.shape[0], a.shape[2]
     ...: idx = np.argpartition(-a,k,axis=1)[:,k-1::-1]
     ...: out = a[np.arange(m)[:,None,None], idx, np.arange(n)]
     ...: 

3) 检查中间结果和输出:

In [207]: idx
Out[207]: 
array([[[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

In [208]: out
Out[208]: 
array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

关于python - 使用 np.argpartition 对多维数组中的值进行索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42832711/

相关文章:

python 3, headless RaspPi,python-evdev 不可能使用语言环境 de_DE

python:找到一组的所有拉丁方(或列数较少的部分方)

python - BeautifulSoup 查找找到的标签后的下一个特定标签

python - 在 DateTime 系列中使用 np.select - Pandas Python

python - 使用Python改变numpy数组的维度

python - 在 Fipy 中求解多个偏微分方程

Python 分组两个列表

python - 如何将浮点列表的 2d np.array 转换为浮点的 2d np.array,将列表值堆叠到行

python - 处理图像时如何标准化 scipy 的 convolve2d?

python - 为什么 dill 比 numpy 数组的 pickle 更快更磁盘高效