python - Numpy 切片慢?

标签 python arrays numpy matrix numba

您好,我正在使用 numpy + numba 运行科学计算。 我已经意识到 numpy 数组就地添加非常慢......与 matlab 相比

这是matlab代码:

tic;
% A,B are 2-d matrices, ind may not be distinct
for ii=1:N 
    A(ind(ii),:) =  A(ind(ii),:) +  B(ii,:);
end
toc;

这是 numpy 代码:

s = time.time()
# A,B are numpy.ndarray, ind may not be distinct
for k in xrange(N):
     A[ind[k],:] += B[k,:];
print time.time() - s

结果显示 numpy 代码比 matlab 慢 10 倍...这让我很困惑。

此外,当我将加法从 for 循环中拉出,并仅将单个矩阵加法与 numpy.add 进行比较时,numpy 和 matlab 在速度上似乎具有可比性。

我知道的一个因素是 matlab 使用 JIT for version>=2012a 来加速 for 循环,但我在 python 代码上尝试了 numba,它仍然没有加速。我认为这与 numba 根本没有触及 numpy.add 函数有关,因此性能根本没有改变。

我猜测 matlab 为这种情况做了一些糟糕的缓存,因此它大大击败了 numpy。

关于如何加速 numpy 有什么建议吗?

最佳答案

尝试

A[ind] += B[:N]

即没有任何循环。

如果 ind 可能有重复的元素,您可以使用 np.add.at :

np.add.at(A, ind, B[:N])

关于python - Numpy 切片慢?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/19402069/

相关文章:

python - 检查 float 是否接近存储在数组中的任何 float

python - 我如何在 Numpy 中向量化这个双 for 循环?

python - 如何将高斯法线与直方图相匹配?

python原始字符串赋值

python - 使用 Jinja 按嵌套字典值过滤

javascript - 如何在我的网站上实现实时更新(使用 Flask)?

javascript - 推送复合数组元素

python - 如何管理 python 线程结果?

ios - 从文本字段在 tableView 中添加新单元格

python - numpy randint 和 floor of rand 之间的区别