python - 高效计算 Khatri-Rao 类总和(成对行总和)

标签 python performance numpy multidimensional-array linear-algebra

我正在尝试计算Khatri-Rao就像 sum (即成对行总和)一样,并且能够提出这个解决方案:

In [15]: arr1
Out[15]: 
array([[1, 2, 3],
       [2, 3, 4],
       [3, 4, 5]])

In [16]: arr2
Out[16]: 
array([[11, 12, 13],
       [12, 13, 14],
       [13, 14, 15]])

# for every row in `arr1`, sum it with all rows in `arr2` (in pairwise manner)
In [17]: np.repeat(arr1, arr2.shape[0], 0) + np.tile(arr2, (arr1.shape[0], 1))
Out[17]: 
array([[12, 14, 16],
       [13, 15, 17],
       [14, 16, 18],
       [13, 15, 17],
       [14, 16, 18],
       [15, 17, 19],
       [14, 16, 18],
       [15, 17, 19],
       [16, 18, 20]])

# thus `axis0` in the result will become `arr1.shape[0] * arr2.shape[0]`
In [18]: (np.repeat(arr1, arr2.shape[0], 0) + np.tile(arr2, (arr1.shape[0], 1))).shape
Out[18]: (9, 3)

它工作得很好。但是,我想知道这是否是执行此计算的优化方法。我还计算了(相当)大数组的计算时间

# inputs
In [69]: arr1 = np.arange(9000).reshape(100, 90)
In [70]: arr2 = np.arange(45000).reshape(500, 90)

In [71]: (np.repeat(arr1, arr2.shape[0], 0) + np.tile(arr2, (arr1.shape[0], 1))).shape
Out[71]: (50000, 90)

In [72]: %timeit np.repeat(arr1, arr2.shape[0], 0) + np.tile(arr2, (arr1.shape[0], 1))
22.5 ms ± 420 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

是否可以进一步优化它,也许使用更复杂的方法?

另外,我不完全确定是否 numpy.einsum()可以在这里利用..因为据我了解,它不能用于增加结果数组的形状,这就是这里发生的情况。我欢迎对我的解决方案进行更正、建议和改进:)

最佳答案

我们可以利用广播 -

(arr1[:,None] + arr2).reshape(-1,arr1.shape[1])

对于大型数组,我们可以使用 numexpr 获得进一步的加速。传输广播部分 -

import numexpr as ne

arr1_3D = arr1[:,None]
out = ne.evaluate('arr1_3D + arr2').reshape(-1,arr1.shape[1])

运行时测试 -

In [545]: arr1 = np.random.rand(500,500)

In [546]: arr2 = np.random.rand(500,500)

In [547]: %timeit (arr1[:,None] + arr2).reshape(-1,arr1.shape[1])
1 loop, best of 3: 215 ms per loop

In [548]: %%timeit
     ...: arr1_3D = arr1[:,None]
     ...: out = ne.evaluate('arr1_3D + arr2').reshape(-1,arr1.shape[1])
10 loops, best of 3: 174 ms per loop

关于python - 高效计算 Khatri-Rao 类总和(成对行总和),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48331736/

相关文章:

java - 如何控制 Java 中的文件压缩参数以更快地解压缩对象?

python - PIL : Image. fromarray(img.astype ('uint8' ), mode ='RGB' ) 返回灰度图像

python - 将 Matlab 代码翻译为 Python

python 子进程

python - 当它到达python中的excel文件末尾时退出while循环

python - 超越循环:高性能,大格式的数据文件解析

Python Numpy : Extracting a row from an array

python - 运行脚本的 bash 函数

python - 显示 pygtk 小部件后 5 秒运行函数

android - 如何在 android 中使用自定义 View 减少复杂布局的布局加载时间?