python - pandas apply function rowwise 花费太长时间下面的代码有其他选择吗

标签 python pandas

我有一个数据框和如下所示的大函数,我想将norm_group函数应用于数据框列,但它使用apply命令花费了太多时间。有什么办法可以减少这段代码的时间吗?目前每个循环需要 24.4 秒。

import pandas as pd
import numpy as np

np.random.seed(1234)
n = 1500000

df = pd.DataFrame()
df['group'] = np.random.randint(1700, size=n)
df['ID'] = np.random.randint(5, size=n)
df['s_count'] = np.random.randint(5, size=n)
df['p_count'] = np.random.randint(5, size=n)
df['d_count'] = np.random.randint(5, size=n)
df['Total'] = np.random.randint(400, size=n)
df['Normalized_total'] = df.groupby('group')['Total'].apply(lambda x: (x-x.min())/(x.max()- x.min()))
df['Normalized_total'] = df['Normalized_total'].apply(lambda x:round(x,2))

def norm_group(a,b,c,d,e):
if a >= 0.7 and b >=1000 and c >2:
    return "Both High "
elif a >= 0.7 and b >=1000 and c < 2:
    return "High and C Low"
elif a >= 0.4 and b >=500 and d > 2:
    return "Medium and D High"
elif a >= 0.4 and b >=500 and d < 2:
    return "Medium and D Low"
elif a >= 0.4 and b >=500 and e > 2:
    return "Medium and E High"
elif a >= 0.4 and b >=500 and e < 2:
    return "Medium and E Low"
else:
    return "Low"

%timeit df['Categery'] = df.apply(lambda x:norm_group(a=x['Normalized_total'],b=x['group']), axis=1)

每次循环 24.4 秒 ± 551 毫秒(7 次运行的平均值 ± 标准差,每次 1 次循环)

我的原始数据框中有多个文本列,并且想要应用类似的函数,与此相比,该函数需要更多的时间。

谢谢

最佳答案

您可以使用np.select进行矢量化:

df['Category'] = np.select((df['Normalized_total'].ge(0.7) & df['group'].ge(1000),
                            df['Normalized_total'].ge(0.4) & df['group'].ge(500)),
                           ('High', 'Medium'), default='Low'
                          )

性能:

255 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

关于python - pandas apply function rowwise 花费太长时间下面的代码有其他选择吗,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58824407/

相关文章:

python - 使用 Python 确定 SSL 证书是否是自签名的

python - 将 groupby 数据框 reshape 为固定尺寸

python - Pandas 具有多索引的一阶差分面板数据

pandas - 使用非索引变量排序 Seaborn 热图

python - 如何以表格形式显示值,因为现在只显示表格标题

python - 在独特的函数调用下重新缩放 Matplotlib imshow 中的 Axis

python - 无法从 OpenCV Gstreamer 接收 gstreamer UDP 流

python - pyMySQL 设置连接字符集

python - 构造的相同 MultiIndex DataFrame 不会聚合(意味着)

python - 如何使用波浪号运算符将 R 代码重写为 Python Pandas?