python - NumPy:一次对许多小矩阵进行点积

标签 python arrays numpy vectorization dot-product

我有一长串 3×3 矩阵,例如,

import numpy as np

A = np.random.rand(25, 3, 3)

对于每个小矩阵,我想执行一个外积 dot(a, a.T)。列表理解

import numpy as np

B = np.array([
    np.dot(a, a.T) for a in A
    ])

有效,但表现不佳。一个可能的改进可能是只做一个dot产品,但我在为它正确设置A时遇到了麻烦。

有什么提示吗?

最佳答案

您可以获得转置矩阵列表作为A.swapaxes(1, 2),您想要的乘积列表作为A @A.swapaxes(1, 2).

import numpy as np

A = np.random.rand(25, 3, 3)

B = np.array([
    np.dot(a, a.T) for a in A
    ])

C = A @ A.swapaxes(1, 2)

(B==C).all()     # => True

@ operator只是 np.matmul 的语法糖,关于它 documentation说“如果任一参数是 N-D,N > 2,它被视为驻留在最后两个索引中的一堆矩阵并相应地广播。

关于python - NumPy:一次对许多小矩阵进行点积,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38110296/

相关文章:

javascript - MVC .NET 将带有 Jquery 自动完成功能的 ViewBag 字符串 [] 数组传递给服务器 ActionResult

python - Numpy:找到蒙版边缘的索引

python-3.x - 使用 numpy 数组比较两个相似的 PIL 图像不起作用

python - 在 Python numpy 掩码数组中用最近的邻居填充缺失值?

python - 从初始化程序运行实例方法时出现神秘的 "extra argument"错误

python - input() 函数对我不起作用(Python 3.3)

python - 从 INI 文件获取设置和配置以进行 Pyramid 功能测试

python - 为什么这个 join() 不起作用?

java - 在java中使用索引设置数组的值

Javascript 将远程 XML 文件读入数组