给定一个大的多列 Pandas 数据框,我想在 N
的窗口上计算滚动“k-mean”元素尽快。
这里“k-mean”定义为N-2k
的均值N
的元素排除 k
最大和k
最小的元素。
例子
给定数据框:
df = pandas.DataFrame(
{'A': [34, 78, -2, -96, 58, -34, 44, -50, 42],
'B': [-82, 28, 96, 46, 36, -34, -20, 10, -40]})
A B
0 34 -82
1 78 28
2 -2 96
3 -96 46
4 58 36
5 -34 -34
6 44 -20
7 -50 10
8 42 -40
与 N=6
和 k=1
预期输出是: A B
0 NaN NaN
1 NaN NaN
2 NaN NaN
3 NaN NaN
4 NaN NaN
5 14.0 19.0
6 16.5 22.5
7 -10.5 18.0
8 0.5 -2.0
试图我的代码似乎符合要求:
def k_mean(s: pandas.Series, trim: int) -> float:
assert trim >= 0, f"Trim must not be negative, {trim} provided."
if trim == 0:
return s.mean()
return s.sort_values()[trim:-trim].mean()
df.rolling(window=6, axis=0).apply(k_mean, kwargs={'trim': 1})
我的问题 :我的代码是否正确,如果正确,是否有更快的方法来实现相同的结果,尤其是考虑到大型多列数据帧?也许有一个巧妙的数学技巧可以提供帮助?
如果它有助于加快性能,我不太关心启动期的处理,或者可以是 NaN 直到
N
或者可以增长到N
一次 2k+1
元素在窗口中。
最佳答案
您可以使用 Numba JIT 显着加快代码速度。
主要思想是将每一列转换为 Numpy 数组 然后使用滑动窗口迭代它们。
import pandas
import numpy
import numba
# Note:
# You can declare the Numba function parameters types to reduce compilation time:
# @numba.njit('float64[::1](int64[::1], int64, int64)')
@numba.njit
def col_k_mean(arr: numpy.array, window: int, trim: int):
out = numpy.full(len(arr), numpy.nan)
if trim == 0:
localSum = arr[0:window].sum()
windowInv = 1.0 / window
for i in range(window-1, len(arr)-1):
out[i] = localSum * windowInv
localSum += arr[i+1] - arr[i-window+1]
if window-1 <= len(arr)-1:
out[len(arr)-1] = localSum * windowInv
else:
for i in range(window-1, len(arr)):
out[i] = numpy.sort(arr[i-window+1:i+1])[trim:-trim].mean()
return out
def apply_k_mean(df: pandas.DataFrame, window: int, trim: int) -> pandas.DataFrame:
assert trim >= 0, f"Trim must not be negative, {trim} provided."
return pandas.DataFrame({col: col_k_mean(df[col].to_numpy(), window, trim) for col in df})
apply_k_mean(df, window=6, trim=1)
请注意,此方法仅在窗口不大时才有效。对于巨大的窗口,最好使用更高级的排序策略,例如基于优先级队列(使用堆)或更一般的增量排序的排序策略。或者,如果 trim
很小而且window
是巨大的,那么可以使用 2 个分区而不是完整的排序。在我的机器上,随机数据帧大小为
(2, 10000)
并与 window=10
以及 trim=2
,上面的代码是快 300 倍 比引用实现(不包括 JIT 编译时间)!与 trim=0
,是快 5800 倍 !使用 在巨大数据帧上的计算速度会更快并行性 (在 Numba 中同时使用
parallel=True
和 prange
支持)。
关于python - 数据帧的排序子数组的滚动平均值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67618680/