python - 使带有矩阵切片的循环更加高效

标签 python performance

我想让下面的代码更高效,但我不知道如何实现。我只想使用 numpy 和 native python 库。

iterations = 100
aggregation = 0
for i in range(iterations):
    aggregation += np.sum(np.linalg.norm(dat[dat_filter==i] - dat_points[i], axis=1))

dat 是一个 nxD 矩阵 dat_filter 是长度为 n 的向量,包含从 0 到 num_iterations 的标识符 dat_points 是 num_iterators x D 矩阵。

基本上,我正在计算其点属于某个类的矩阵 Dat 与该类的点之间的距离

最佳答案

向量化问题并不容易,因为数据部分的平方根不一定具有相同的长度。您可以对其部分进行矢量化以小幅加速:

import numpy as np

# Make some data
n = 200000
d = 100
iterations = 2000

np.random.seed(42)
dat = np.random.random((n, d))
dat_filter = np.random.randint(0, n_it, size=n)
dat_points = np.random.random((n_it, d))


def slow(dat, dat_filter, dat_points, iterations):
    aggregation = 0
    for i in range(iterations):
        # Wrote linalg.norm as standard numpy operations,
        # such that numba can be used on the code as well
        aggregation += np.sum(np.sqrt(np.sum((dat[dat_filter==i] - dat_points[i])**2, axis=1)))
    return aggregation

def fast(dat, dat_filter, dat_points, iterations):
    # Rearrange the arrays such that the correct operations are done
    sort_idx = np.argsort(dat_filter)
    filtered_dat_squared_sum = np.sum((dat - dat_points[dat_filter])**2, axis=1)[sort_idx]
    # Count the number of different 'iterations'
    counts = np.unique(dat_filter, return_counts=True)[1]
    aggregation = 0 
    idx = 0 
    for c in counts:
        aggregation += np.sum(np.sqrt(filtered_dat_squared_sum[idx:idx+c]))
        idx += c
    return aggregation

时间:

In [1]: %timeit slow(dat, dat_filter, dat_points, n_it)       
3.47 s ± 314 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [2]: %timeit fast(dat, dat_filter, dat_points, n_it)     
846 ms ± 81.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

将 numba 与 slow 函数一起使用会稍微加快速度,但仍然不如 fast 方法快。具有 fast 函数的 Numba 会使我测试的矩阵大小上的调用速度变慢。

关于python - 使带有矩阵切片的循环更加高效,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59601987/

相关文章:

python - 我迷失了尝试为我的 Flask 应用程序编写测试

c++ - 为什么 Core i5-6600 在非方矩阵乘法上比 Core i9-9960X 更快?

c++ - 在基于范围的 for 循环中使用转发引用有什么好处?

c - 我是否使用循环来迭代小尺寸数组(例如 2 或 3)?

python - 将范围指定为仅选择填充单元格/以空单元格结尾 Python

python - WSGI在应用程序的同一目录中找不到文件

python - 重组和重命名 Pandas 数据框中的几列

python - 在 pandas 数据框中找到特定 alpha 的临界值?

c# - 是否有更快的方法来突出显示正则表达式匹配项 (RichTextBox)

performance - 什么是 A* (AStar) 的良好基准?