python - 数据帧的排序子数组的滚动平均值

标签 python pandas performance

给定一个大的多列 Pandas 数据框,我想在 N 的窗口上计算滚动“k-mean”元素尽快。
这里“k-mean”定义为N-2k的均值N 的元素排除 k最大和k最小的元素。
例子
给定数据框:

df = pandas.DataFrame(
{'A': [34, 78, -2, -96, 58, -34, 44, -50, 42],
 'B': [-82, 28, 96, 46, 36, -34, -20, 10, -40]})

    A   B
0  34 -82
1  78  28
2  -2  96
3 -96  46
4  58  36
5 -34 -34
6  44 -20
7 -50  10
8  42 -40
N=6k=1预期输出是:
      A     B
0   NaN   NaN
1   NaN   NaN
2   NaN   NaN
3   NaN   NaN
4   NaN   NaN
5  14.0  19.0
6  16.5  22.5
7 -10.5  18.0
8   0.5  -2.0
试图
我的代码似乎符合要求:
def k_mean(s: pandas.Series, trim: int) -> float:
    assert trim >= 0, f"Trim must not be negative, {trim} provided."
    if trim == 0:
        return s.mean()
    return s.sort_values()[trim:-trim].mean()

df.rolling(window=6, axis=0).apply(k_mean, kwargs={'trim': 1})
我的问题 :我的代码是否正确,如果正确,是否有更快的方法来实现相同的结果,尤其是考虑到大型多列数据帧?
也许有一个巧妙的数学技巧可以提供帮助?
如果它有助于加快性能,我不太关心启动期的处理,或者可以是 NaN 直到 N或者可以增长到N一次 2k+1元素在窗口中。

最佳答案

您可以使用 Numba JIT 显着加快代码速度。
主要思想是将每一列转换为 Numpy 数组 然后使用滑动窗口迭代它们。

import pandas
import numpy
import numba

# Note:
# You can declare the Numba function parameters types to reduce compilation time:
# @numba.njit('float64[::1](int64[::1], int64, int64)')

@numba.njit
def col_k_mean(arr: numpy.array, window: int, trim: int):
    out = numpy.full(len(arr), numpy.nan)
    if trim == 0:
        localSum = arr[0:window].sum()
        windowInv = 1.0 / window
        for i in range(window-1, len(arr)-1):
            out[i] = localSum * windowInv
            localSum += arr[i+1] - arr[i-window+1]
        if window-1 <= len(arr)-1:
            out[len(arr)-1] = localSum * windowInv
    else:
        for i in range(window-1, len(arr)):
            out[i] = numpy.sort(arr[i-window+1:i+1])[trim:-trim].mean()
    return out

def apply_k_mean(df: pandas.DataFrame, window: int, trim: int) -> pandas.DataFrame:
    assert trim >= 0, f"Trim must not be negative, {trim} provided."
    return pandas.DataFrame({col: col_k_mean(df[col].to_numpy(), window, trim) for col in df})

apply_k_mean(df, window=6, trim=1)
请注意,此方法仅在窗口不大时才有效。对于巨大的窗口,最好使用更高级的排序策略,例如基于优先级队列(使用堆)或更一般的增量排序的排序策略。或者,如果 trim很小而且window是巨大的,那么可以使用 2 个分区而不是完整的排序。
在我的机器上,随机数据帧大小为 (2, 10000)并与 window=10以及 trim=2 ,上面的代码是快 300 倍 比引用实现(不包括 JIT 编译时间)!与 trim=0 ,是快 5800 倍 !
使用 在巨大数据帧上的计算速度会更快并行性 (在 Numba 中同时使用 parallel=Trueprange 支持)。

关于python - 数据帧的排序子数组的滚动平均值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67618680/

相关文章:

c++ - 为什么这些矩阵转置时间如此违反直觉?

mysql - 多列索引结合唯一索引效率

python - 打开 CSV IOError : [Errno 13] Permission Denied

python - 这个mergesort实现的时间复杂度是O(nlogn)吗?

python - 如何将多列重新排列成具有相同索引的一列

python - Pandas 字符串标记化太慢

python - 无循环的3x3矩阵转换(RGB颜色转换)

python - 按键分组并创建相应值的列表

python - 在字典中绘制嵌套列表 - python

performance - WebSphere SIB与MQ?从WebSphere上运行的J2EE应用程序中进行异步消息传递的最佳选择是什么?