python - 从 np.argpartition 索引更快地索引 3D NumPy 数组

标签 python arrays numpy

我有一个大型 3D NumPy 数组:

x = np.random.rand(1_000_000_000).reshape(500, 1000, 2000)
对于 500 个二维数组中的每一个,我只需要在每个二维数组的每一列中保留最大的 800 个元素。为了避免昂贵的排序,我决定使用 np.argpartition :
k = 800
idx = np.argpartition(x, -k, axis=1)[:, -k:]
result = x[np.arange(x.shape[0])[:, None, None], idx, np.arange(x.shape[2])]
虽然 np.argpartition相当快,使用 idx索引回 x真的很慢。是否有更快(且内存高效)的方法来执行此索引?
请注意,结果不需要按升序排序。他们只需要成为前 800 名

最佳答案

将大小减少 10 以适应我的内存,以下是各个步骤的时间:
创作:

In [65]: timeit x = np.random.rand(1_000_000_00).reshape(500, 1000, 200)
1.89 s ± 82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [66]: x = np.random.rand(1_000_000_00).reshape(500, 1000, 200)
In [67]: k=800
种类:
In [68]: idx = np.argpartition(x, -k, axis=1)[:, -k:]
In [69]: timeit idx = np.argpartition(x, -k, axis=1)[:, -k:]

2.52 s ± 292 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 
索引:
In [70]: result = x[np.arange(x.shape[0])[:, None, None], idx, np.arange(x.shape[2])]
In [71]: timeit result = x[np.arange(x.shape[0])[:, None, None], idx, np.arange(x.shape[2])]
The slowest run took 4.11 times longer than the fastest. This could mean that an intermediate result is being cached.
2.6 s ± 1.87 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
所有三个步骤大约花费相同的时间。我没有看到上次索引有什么异常之处。这 0.8 GB。
一个简单的副本,没有索引是将近 1 秒。
In [75]: timeit x.copy()
980 ms ± 231 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
和带有高级索引的完整副本:
In [77]: timeit x[np.arange(x.shape[0])[:, None, None], np.arange(x.shape[1])[:,
    ...: None], np.arange(x.shape[2])]
1.47 s ± 37.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
尝试 idx再次:
In [78]: timeit result = x[np.arange(x.shape[0])[:, None, None], idx, np.arange(x.shape[2])]
1.71 s ± 42.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
请记住,当操作开始使用几乎所有内存时,和/或开始需要对操作系统进行交换和特殊内存请求时,时间可能真的很糟糕。
编辑
您不需要两步过程。只需使用 partition :
out = np.partition(x, -k,axis=1)[:, -k:]
这与 result 相同,并且与 idx 花费的时间相同步。

关于python - 从 np.argpartition 索引更快地索引 3D NumPy 数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69684875/

相关文章:

python - 设置字符串、循环之前和循环之后的不同行为

python - 在 tkinter python 的输入框旁边打包标签

C++如何在不复制数据的情况下添加和到数组?

copy_to_user 一个包含数组(指针)的结构

python - 编译和分发 Cython 扩展

python - 我可以装饰一个显式函数调用吗,比如 np.sqrt()

python - 如何在 Python 中对对象进行排序

python - 多对多关系 禁止直接分配到多对多集合的前向端

javascript - array.split ("stop here") array to array of arrays in javascript 数组

python - 当我尝试在一个图上同时使用一个(线)图和一个条形图制作一个图时,我得到一个奇怪的图