python - 检查numpy数组窗口中的元素是否有限的更快方法

标签 python arrays performance numpy

我有一个很长的 NumPy 数组 1_000_000_000元素,我想滑动一个 50数组中的元素窗口,并询问窗口内的所有元素是否都是有限的。如果 50 中的所有元素元素窗口都是有限的然后返回True (对于那个窗口),否则,如果 50 中有一个或多个元素元素窗口不是有限的然后返回 False (对于那个窗口)。继续此评估,直到评估所有窗口。一个很好的方法是:

import numpy as np

def rolling_window(a, window):
    a = np.asarray(a)
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)

    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

if __name__ == "__main__":
    a = np.random.rand(100_000_000)  # This is 10x shorter than my real data
    w = 50
    idx = np.random.randint(0, len(a), size=len(a)//10)  # Simulate having np.nan in my array
    a[idx] = np.nan
    print(np.all(rolling_window(np.isfinite(a), w), axis=1))
但是,当我的数组长度为 1_000_000_000 时,这很慢.有没有一种不需要大量内存的更快的方法来实现这一点?

最佳答案

方法#1:滥用 strided windows 直接进入 isfinite-mask对于任务 -

def strided_allfinite(a, w):
    m = np.isfinite(a)
    p = rolling_window(m, w)
    nmW = ~m[:w]
    if nmW.any():
        m[:np.flatnonzero(nmW).max()] = False
    p[~m[w-1:]] = False
    return m[:-w+1]
给定样本数据的时间:
In [323]: N = 100_000_000
     ...: w = 50
     ...: 
     ...: np.random.seed(0)
     ...: a = np.random.rand(N)  # This is 10x shorter than my real data
     ...: idx = np.random.randint(0, len(a), size=len(a)//10)  # Simulate...
     ...: a[idx] = np.nan

# Original soln
In [324]: %timeit np.all(rolling_window(np.isfinite(a), w), axis=1)
1.61 s ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [325]: %timeit strided_allfinite(a, w)
556 ms ± 87.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
方法#2
我们可以利用 convolution ——
np.convolve(np.isfinite(a), np.ones(w),'valid')==w
方法#3
binary-erosion ——
from scipy.ndimage.morphology import binary_erosion

m = np.isfinite(a)
out = binary_erosion(m, np.ones(w, dtype=bool))[w//2:len(a)-w+1+w//2]

关于python - 检查numpy数组窗口中的元素是否有限的更快方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64789755/

相关文章:

performance - 何时使用各种语言编译指示和优化?

python - 使用 threading.Lock 作为上下文管理器

python - 为什么 pandas 时间序列重新采样会引发不兼容的频率错误?

java - 在 JTextArea 中显示样式内容

java - 组合来自 2 个数组的唯一整数的最快方法

python - 抓取网站后发送带有附件的电子邮件

python - 用于在 jupyter 中进行内联绘图的 matplotlib 后端是什么

arrays - 如何在 Perl 中合并散列?

arrays - swift - corebluetooth 写入 2 个字节