我面临性能问题,因为一段 Python 代码应该执行以下操作:
我有 2 个带有未排序值的数组 A 和 B,我想构造一个新数组 C,其中每个索引包含以下内容:
C[i]= sum(flag*B[k] for k so that flag = A[k]<=A[i])
我通过两种方式做到了:
1)非常直接的方式:
M = len(A)
C = np.zeros(M)
for i in xrange(M):
value = A[i]
flag = A <= value
C[i] = np.sum(flag * B)
2)尝试使用 numpy 排序函数:
indices_sorted = np.argsort(A)
C_sort = np.zeros(M)
for i in xrange(M):
index = np.where(indices_sorted==i)
for k in xrange(index[0][0]+1):
C_sort[i] += B[indices_sorted[k]]
结果是,对于 5000 个元素的数组,第一个要快得多(因子 40-50)。
我没想到第二次会那么糟糕,第一次尝试也不够快......
你们能给我一个更好的方法吗?
提前致谢。
最佳答案
假设A
和B
要成为相同形状的一维数组,您可以使用 broadcasting
通过扩展A
到二维数组,然后进行比较,从而基本上以矢量化方式将每个元素与其他每个元素进行比较。然后,执行元素乘法 B
,又在哪里broadcasting
发挥作用。最后沿着第二个轴求和得到最终输出。实现看起来像这样 -
C = ((A <= A[:,None])*B).sum(1)
您可以模拟 elementwise multiplication and summing
的相同行为与 matrix-multiplication
使用 np.dot
一个更有效的解决方案,就像这样 -
C = (A <= A[:,None]).dot(B)
这是另一种基于 np.take
索引的方法并用 np.bincount
进行计数-
row,col = np.nonzero(A <= A[:,None])
C = np.bincount(row,np.take(B,col))
对于巨大的数据量,创建2D
的内存开销面具(A <= A[:,None]
可能会抵消性能。因此,作为对现有循环代码的优化,您可以引入 matrix-multiplication
代替元素乘法和求和。因此,np.sum(flag * B)
可以替换为 flag.dot(B)
。引入一些其他优化技巧,您将得到像这样的修改版本 -
M = len(A)
C = np.empty(M)
for i in xrange(M):
C[i] = (A <= A[i]).dot(B)
最后!这是获胜者 np.cumsum
-
idx = A.argsort()
C = B[idx].cumsum()[idx.argsort()]
以下是其工作方式和原因的快速说明:
您正在执行逐元素比较,然后根据比较结果对 B 中的元素求和。现在,如果 A
是一个排序数组,然后输出 C
本质上是 cumsum
B
的版本。因此,对于一般的未排序情况,您需要排序 B
通过 A
的 argsort ,执行cumsum
其上,最后根据原始未排序的顺序重新排列元素。
运行时测试
定义方法 -
def org_app(A,B):
M = len(A)
C = np.zeros(M)
for i in range(M):
value = A[i]
flag = A <= value
C[i] = np.sum(flag * B)
return C
def sum_based(A,B):
return ((A <= A[:,None])*B).sum(1)
def dot_based(A,B):
return (A <= A[:,None]).dot(B)
def bincount_based(A,B):
row,col = np.nonzero(A <= A[:,None])
return np.bincount(row,np.take(B,col))
def org_app_modified(A,B):
M = len(A)
C = np.empty(M)
for i in xrange(M):
C[i] = (A <= A[i]).dot(B)
return C
def cumsum_trick(A,B):
idx = A.argsort()
return B[idx].cumsum()[idx.argsort()]
设置输入和计时 -
In [212]: # Inputs
...: N = 5000
...: A = np.random.rand(N)
...: B = np.random.rand(N)
...:
In [213]: %timeit org_app(A,B)
...: %timeit sum_based(A,B)
...: %timeit dot_based(A,B)
...: %timeit bincount_based(A,B)
...: %timeit org_app_modified(A,B)
...: %timeit cumsum_trick(A,B)
...:
1 loops, best of 3: 266 ms per loop
1 loops, best of 3: 411 ms per loop
1 loops, best of 3: 322 ms per loop
1 loops, best of 3: 1.01 s per loop
10 loops, best of 3: 196 ms per loop
1000 loops, best of 3: 835 µs per loop
关于python - 高效的一维数组比较、缩放和求和,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33923805/