python - 有机会让它更快吗? (numpy.einsum)

标签 python numpy

我正在尝试将三个数组 (A x B x A) 与维度 (19000, 3) x (19000, 3, 3) x (19000, 3) 相乘,这样最后我得到一个大小为 (19000) 的一维数组,所以我只想沿着最后一个/两个维度相乘。

我已经让它与 np.einsum() 一起工作,但我想知道是否有任何方法可以让它更快,因为这是我整个代码的瓶颈。

np.einsum('...i,...ij,...j', A, B, A)

我已经用两个单独的 np.einsum() 调用尝试过,但这给了我相同的性能:

np.einsum('...i, ...i', np.einsum('...i,...ij', A, B), A)

此外,我已经尝试过 @ 运算符并添加了一些额外的轴,但这并没有使速度更快:

(A[:, None]@B@A[...,None]).squeeze()

我试图让它与 np.inner()、np.dot()、np.tensordot() 和 np.vdot() 一起工作,但是这些从来没有给我相同的结果,所以我不能比较一下。

还有其他想法吗?有什么方法可以获得更好的性能?

我已经快速浏览了 Numba,但由于 Numba 不支持 np.einsum() 和许多其他 NumPy 函数,我将不得不重写大量代码。

最佳答案

你可以使用 Numba

一开始总是一个好主意,看看 np.einsum 做了什么。使用 optimize==optimal,找到一种收缩方式通常非常好,它具有更少的 FLOP。在这种情况下,实际上只有一个小的优化可能,中间数组相对较大(我会坚持使用朴素的版本)。还应该提到的是,尺寸非常小(固定?)的收缩是一种非常特殊的情况。这也是为什么在这里很容易胜过 np.einsum 的原因(展开等...,如果编译器知道循环仅包含 3 个元素,编译器就会这样做)

import numpy as np

A=np.random.rand(19000, 3)
B=np.random.rand(19000, 3, 3)

print(np.einsum_path('...i,...ij,...j', A, B, A,optimize="optimal")[1])

"""
  Complete contraction:  si,sij,sj->s
         Naive scaling:  3
     Optimized scaling:  3
      Naive FLOP count:  5.130e+05
  Optimized FLOP count:  4.560e+05
   Theoretical speedup:  1.125
  Largest intermediate:  5.700e+04 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   3                  sij,si->js                                 sj,js->s
   2                    js,sj->s                                     s->s

"""

Numba 实现

import numba as nb

#si,sij,sj->s
@nb.njit(fastmath=True,parallel=True,cache=True)
def nb_einsum(A,B):
    #check the input's at the beginning
    #I assume that the asserted shapes are always constant
    #This makes it easier for the compiler to optimize 
    assert A.shape[1]==3
    assert B.shape[1]==3
    assert B.shape[2]==3

    #allocate output
    res=np.empty(A.shape[0],dtype=A.dtype)

    for s in nb.prange(A.shape[0]):
        #Using a syntax like that is also important for performance
        acc=0
        for i in range(3):
            for j in range(3):
                acc+=A[s,i]*B[s,i,j]*A[s,j]
        res[s]=acc
    return res

时间

#warmup the first call is always slower 
#(due to compilation or loading the cached function)
res=nb_einsum(A,B)
%timeit nb_einsum(A,B)
#43.2 µs ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit np.einsum('...i,...ij,...j', A, B, A,optimize=True)
#450 µs ± 8.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.einsum('...i,...ij,...j', A, B, A)
#977 µs ± 4.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
np.allclose(np.einsum('...i,...ij,...j', A, B, A,optimize=True),nb_einsum(A,B))
#True

关于python - 有机会让它更快吗? (numpy.einsum),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63284736/

相关文章:

python - 如何生成彼此不相交的正方形(位置随机,大小相等,随机旋转)?

arrays - 可见的弃用警告...?

python - 将对称矩阵(二维数组)的上/下三角部分转换为一维数组并将其返回为二维格式

python - 如何在通过将 txt 文件加载到 scipy 程序形成的图上的特定间隔内插入点?

python - 如何从图像像素阵列创建平均 rgb 向量?

python - 在 VS Code 中,可以在集成的 Python 终端(如 Spyder 中)中运行 Python 代码吗?

c++ - 正则表达式:在开头以外的任何地方接受空格

python - 如何为每个ID分配一个组号(n=1,2,3.....)?

python - 如何查找组中的缺失值

memory - 转置数组并实际重新排序内存