python - 计算数组中某个值的平均位置快速方法

标签 python arrays numpy

我有以下代码来计算包含 1 和 0 的 2D numpy 数组中 1 的平均位置。问题是它非常慢,我想知道是否可以有更快的方法?

row_sum = 0
col_sum = 0
ones_count = 0

for row_count, row in enumerate(array):
    for col_count, col in enumerate(row):
        if col == 1:
            row_sum += row_count
            col_sum += col_count
            ones_count += 1

average_position_ones = (row_sum / ones_count, col_sum / ones_count)

最佳答案

这里有 3 种更快计算的方法 row_sum , col_sumones_count .

基线

为了测试,我使用这个数组

import numpy as np
import numba as nb

np.random.seed(1)

n = 10**4
array = np.random.randint(0,2,(n,n))

现在您的确切代码为 20.3 s ± 397 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)在我的机器上。

Lazy One Liner Numpy 版本:

%timeit np.stack(np.where(array)).sum(axis=1),array.sum()需要1.13 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)在我的机器上。

这里np.stack(np.where(array)).sum(axis=1)就是你所说的row_sumcol_sumarray.sum()给你 ones_count

避免循环抛出两次

您可以使用您的确切代码 numba.jit

@nb.njit
def test():
    row_sum = 0
    col_sum = 0
    ones_count = 0

    for row_count, row in enumerate(array):
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1

    return row_sum,col_sum,ones_count

%timeit test()

这有点快。需要 50 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)在我的机器上。但绝对不值得付出努力。

多核版本

对代码稍作修改即可使用 numba 运行多线程

@nb.njit(parallel=True)
def test2():
    row_sum = 0
    col_sum = 0
    ones_count = 0
    
    for row_count in nb.prange(len(array)):
        row = array[row_count]
        for col_count, col in enumerate(row):
            if col == 1:
                row_sum += row_count
                col_sum += col_count
                ones_count += 1

    return row_sum,col_sum,ones_count

%timeit test2()

现在,与懒惰的 numpy 相比,这确实提供了一点速度。版本。需要13.3 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)在我的 10 核机器上。虽然它没有使用全部 10 个核心。

请注意,并行修改内容时必须小心。您可以创建一个竞争条件。而这里的情况并非如此,只是因为 numba针对本案具体情况采取应对措施。

进一步优化

正如 Jérôme Richard 在评论中指出的那样。最后一个版本可以通过使用 uint8 代替默认的 int64 来进行优化。只需调用.astype(np.uint8)在阵列上。然后需要 9.38 ms ± 935 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)在我的机器上。

关于python - 计算数组中某个值的平均位置快速方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71882122/

相关文章:

python - 对于 AI 和数据挖掘的东西,我应该学习哪一组 python 库

Android Arrays 内容无故更改

C - 在输入流中分隔字符串

python - 如何异步运行 sublime text3 插件命令

python - 如何增加 QListWidget 中缩略图的大小

python - 使用 Python/NumPy 对图像阈值进行矢量化

python - numpy数组中的np.nan是否占用内存?

python - 取平均值时如何避免在 2d numpy 数组中被零除?

python - matplotlib - 基于附加数据的绘图的多个标记

javascript - 关联数组 : define the undefined?