我有以下代码来计算包含 1 和 0 的 2D numpy 数组中 1 的平均位置。问题是它非常慢,我想知道是否可以有更快的方法?
row_sum = 0
col_sum = 0
ones_count = 0
for row_count, row in enumerate(array):
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
average_position_ones = (row_sum / ones_count, col_sum / ones_count)
最佳答案
这里有 3 种更快计算的方法 row_sum
, col_sum
和ones_count
.
基线
为了测试,我使用这个数组
import numpy as np
import numba as nb
np.random.seed(1)
n = 10**4
array = np.random.randint(0,2,(n,n))
现在您的确切代码为 20.3 s ± 397 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在我的机器上。
Lazy One Liner Numpy 版本:
%timeit np.stack(np.where(array)).sum(axis=1),array.sum()
需要1.13 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在我的机器上。
这里np.stack(np.where(array)).sum(axis=1)
就是你所说的row_sum
和col_sum
和array.sum()
给你 ones_count
避免循环抛出两次
您可以使用您的确切代码 numba.jit
@nb.njit
def test():
row_sum = 0
col_sum = 0
ones_count = 0
for row_count, row in enumerate(array):
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
return row_sum,col_sum,ones_count
%timeit test()
这有点快。需要 50 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
在我的机器上。但绝对不值得付出努力。
多核版本
对代码稍作修改即可使用 numba
运行多线程
@nb.njit(parallel=True)
def test2():
row_sum = 0
col_sum = 0
ones_count = 0
for row_count in nb.prange(len(array)):
row = array[row_count]
for col_count, col in enumerate(row):
if col == 1:
row_sum += row_count
col_sum += col_count
ones_count += 1
return row_sum,col_sum,ones_count
%timeit test2()
现在,与懒惰的 numpy
相比,这确实提供了一点速度。版本。需要13.3 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在我的 10 核机器上。虽然它没有使用全部 10 个核心。
请注意,并行修改内容时必须小心。您可以创建一个竞争条件。而这里的情况并非如此,只是因为 numba
针对本案具体情况采取应对措施。
进一步优化
正如 Jérôme Richard 在评论中指出的那样。最后一个版本可以通过使用 uint8 代替默认的 int64 来进行优化。只需调用.astype(np.uint8)
在阵列上。然后需要 9.38 ms ± 935 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
在我的机器上。
关于python - 计算数组中某个值的平均位置快速方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71882122/