我有大量(准确地说是 26,214,400 个)数据集,我想对其执行线性回归,即 26,214,400 个数据集中的每个数据集都包含 n
x 值和 n
y 值,我想找到 y = m * x + b
。对于任何一组点,我都可以使用 sklearn
或 numpy.linalg.lstsq
,类似于:
A = np.vstack([x, np.ones(len(x))]).T
m, b = np.linalg.lstsq(A, y, rcond=None)[0]
有没有一种方法可以设置矩阵,从而避免 python 循环遍历 26,214,400 个项目?还是我必须使用循环,使用 Numba 之类的东西会更好?
最佳答案
我最终选择了 numba
路线,它在我的笔记本电脑上产生了大约 20 倍的速度,它使用了我所有的内核,所以我认为 CPU 越多越好。答案看起来像
import numpy as np
from numpy.linalg import lstsq
import numba
@numba.jit(nogil=True, parallel=True)
def fit(XX, yy):
""""Fit a large set of points to a regression"""
assert XX.shape == yy.shape, "Inputs mismatched"
n_pnts, n_samples = XX.shape
scale = np.empty(n_pnts)
offset = np.empty(n_pnts)
for i in numba.prange(n_pnts):
X, y = XX[i], yy[i]
A = np.vstack((np.ones_like(X), X)).T
offset[i], scale[i] = lstsq(A, y)[0]
return offset, scale
运行它:
XX, yy = np.random.randn(2, 1000, 10)
offset, scale = fit(XX, yy)
%timeit offset, scale = fit(XX, yy)
1.87 ms ± 37.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
非 jitted 版本有这个时间:
41.7 ms ± 620 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
关于python - 使用 numpy 进行大量回归的有效方法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61965827/