python - 如何有效地迭代 pandas DataFrame 并在这些值上递增 NumPy 数组?

标签 python python-3.x pandas numpy

我的pandas/numpy生疏了,感觉自己写的代码效率低下。

我正在 Python3.x 中初始化一个 numpy 零数组,长度为 1000。为了我的目的,这些只是整数:

import numpy as np
array_of_zeros =  np.zeros((1000, ), )

我还有下面的DataFrame(比我的实际数据小很多)

import pandas as pd
dict1 = {'start' : [100, 200, 300], 'end':[400, 500, 600]}
df = pd.DataFrame(dict1)
print(df)
##
##    start     end
## 0    100     400
## 1    200     500
## 2    300     600

DataFrame 有两列,startend。这些值表示一个值的范围,即 start 将始终是一个小于 end 的整数。在上面,我们看到第一行的范围是 100-400,接下来是 200-500,然后是 300-600

我的目标是逐行遍历 pandas DataFrame,并根据这些索引位置递增 numpy 数组 array_of_zeros。因此,如果 1020 的数据帧中有一行,我想将索引 10-20 的零增加 +1。

这是我想要的代码:

import numpy as np
array_of_zeros =  np.zeros((1000, ), )

import pandas as pd
dict1 = {'start' : [100, 200, 300], 'end':[400, 500, 600]}
df = pd.DataFrame(dict1)
print(df)

for idx, row in df.iterrows():
    for i in range(int(row.start), int(row.end)+1):
        array_of_zeros[i]+=1

而且有效!

print(array_of_zeros[15])
## output: 0.0
print(array_of_zeros[600])
## output: 1.0
print(array_of_zeros[400])
## output: 3.0
print(array_of_zeros[100])
## output: 1.0
print(array_of_zeros[200])
## output: 2.0

我的问题:这是非常笨拙的代码!我不应该对 numpy 数组使用这么多 for 循环!如果输入数据帧非常大,此解决方案将非常低效

是否有更有效(即更基于 numpy)的方法来避免这种 for 循环?

for i in range(int(row.start), int(row.end)+1):
    array_of_zeros[i]+=1

也许有一个面向 pandas 的解决方案?

最佳答案

您可以使用 NumPy 数组索引来避免内部循环,即 res[np.arange(A[i][0], A[i][1]+1)] += 1,但这效率不高,因为它涉及创建新数组和使用高级索引。

相反,您可以使用 numba1 来优化您的算法,完全按照原样进行。下面的示例展示了通过将性能关键型逻辑移至 JIT 编译代码后性能的巨大提升。

from numba import jit

@jit(nopython=True)
def jpp(A):
    res = np.zeros(1000)
    for i in range(A.shape[0]):
        for j in range(A[i][0], A[i][1]+1):
            res[j] += 1
    return res

一些基准测试结果:

# Python 3.6.0, NumPy 1.11.3

# check result the same
assert (jpp(df[['start', 'end']].values) == original(df)).all()
assert (pir(df) == original(df)).all()
assert (pir2(df) == original(df)).all()

# time results
df = pd.concat([df]*10000)

%timeit jpp(df[['start', 'end']].values)  # 64.6 µs per loop
%timeit original(df)                      # 8.25 s per loop
%timeit pir(df)                           # 208 ms per loop
%timeit pir2(df)                          # 1.43 s per loop

用于基准测试的代码:

def original(df):
    array_of_zeros = np.zeros(1000)
    for idx, row in df.iterrows():
        for i in range(int(row.start), int(row.end)+1):
            array_of_zeros[i]+=1   
    return array_of_zeros

def pir(df):
    return np.bincount(np.concatenate([np.arange(a, b + 1) for a, b in \
                       zip(df.start, df.end)]), minlength=1000)

def pir2(df):
    a = np.zeros((1000,), np.int64)
    for b, c in zip(df.start, df.end):
        np.add.at(a, np.arange(b, c + 1), 1)
    return a

1 为了后代,我加入了@piRSquared 关于为什么 numba 在这里有帮助的精彩评论:

numba's advantage is looping very efficiently. Though it can understand much of NumPy's API, it is often better to avoid creating NumPy objects within a loop. My code is creating a NumPy array for every row in the dataframe. Then concatenating them prior to using bincount. @jpp's numba code creates very little extra objects and utilizes much of what is already there. The difference between my NumPy solution and @jpp's numba solution is about 4-5 times. Both are linear and should be pretty quick.

关于python - 如何有效地迭代 pandas DataFrame 并在这些值上递增 NumPy 数组?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52101378/

相关文章:

PYTHON - 不可哈希列表

python - 为多个 Python 文件编写单元测试

python - sum(1 for c in sentence if c.isupper())) 在非编程术语中是什么意思

python - 为什么转置数据以获得多索引数据帧?

python - 在 matplotlib 3d 散点图中更改数据点的颜色并通过按键将其删除

Python、 `let`、 `with`、局部作用域、调试打印和临时变量

python - Postgresql插入多列格式错误的数组文字

python - 仍然找到安装了 pyenv 的系统 python - 安装模块

python - 重命名非常大的 CSV 数据文件的列

python - 使用 pandas 获取每个唯一行项目的时间差