python - 可广播的 Numpy 点

标签 python algorithm numpy matrix-multiplication blas

我有一个维度为 (n0, n2) 的数组 H 和一个维度为 (n0, n1, n2, n3) 的数组 W,我想执行以下操作:

(H[:, None, :, None] * W).sum(axis=(0, 2))

据我所知,上面一行没有使用 BLAS 库。有没有一种方法可以使用 numpy.dot 或使用 BLAS 进行相同计算的类似函数(并且仍然无需在内存中多次复制数组 H)?

最佳答案

您已经确定了一种方法;我知道另外两个。

举个小例子

In [365]: n0,n1,n2,n3=2,3,4,5
In [366]: H=np.ones((n0,n2));W=np.ones((n0,n1,n2,n3))

比较时间是:

In [362]: timeit np.tensordot(H,W,[(0,1),(0,2)])
10000 loops, best of 3: 32.8 µs per loop

In [363]: timeit np.einsum('ik,ijkl',H,W)
100000 loops, best of 3: 10.7 µs per loop

In [364]: timeit (H[:,None,:,None]*W).sum(axis=(0,2))
10000 loops, best of 3: 29.5 µs per loop

tensordot reshape 和转置输入,以便它可以调用 np.doteinsum 解码字符串,并在 C 中执行它自己的 nditer

https://stackoverflow.com/a/31129207/901925具有另一个多维的时间,涉及(100,)*(10,100,100)*(100,)数组。

关于python - 可广播的 Numpy 点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31151202/

相关文章:

algorithm - Order Of Growth 复杂的循环

c++ - 数论算法

python - numpy.ndindex 和数组切片

python - 以编程方式创建图像的最佳方式

python - 统计groupby中每年出现的次数

python - 将字符串转换为 HH :MM time in Python

python - 类型错误 : 'module' object is not callable . MFCC

python - Outlook 中对话历史记录文件夹的 API 是什么?

python - twoSum 找到所有可能的唯一对

python - 增加窗口的滚动均值