python - 广播 numpy 点积

标签 python arrays python-3.x numpy

假设我有一个数组 a矩阵,例如形状(N, 3, 3) = N 个 3x3 矩阵和数组 b形状(i, 3, k) .

我想要点积 a * b 的这些行为

  1. 如果i = N ,结果应该是 (N, 3, k)数组,其中第一个元素是 a[0].dot(b[0]) ,第二个a[1].dot(b[1]) ,等等。
  2. 如果i = 1 ,然后 a 的每个元素必须点乘 b[0]最终的形状应该再次是 (N, 3, k) .

如果我尝试仅使用 numpy.dot(),结果几乎不错,但生成的形状不是我所期望的。有没有一种方法可以轻松有效地做到这一点,并使其适用于任何维度?

最佳答案

你是对的,在(N, 3, 3) * (N, 3, k)中如果您无法使用 np.dot直接,因为结果将是 (N, 3, N, k) 。您实际上必须提取 N来自轴 0 和 2 的对角线元素,但这是大量不必要的计算。 (N, 3, 3) * (1, 3, k)可以使用np.dot解决案例如果您后申请 squeeze删除不必要的第三轴:result = a.dot(b).squeeze() .

好消息是您不需要 np.dot 得到点积。以下是三种选择:

最简单的是,使用 @ 运算符,相当于 np.matmul ,这需要主维度一起广播:

a @ b
np.matmul(a, b)

如果您的矩阵不在最后两个维度,您可以将它们转置为最后两个维度。这可能效率低下,因为它可能会复制数据。在这些情况下,请继续阅读。

另一个流行的解决方案是使用 np.einsum 显式指定匹配轴和要求和的轴:

np.einsum('ijk,ikl->ijl', a, b)

广播将同时处理这两个问题 i = Ni = 1案例。

当然,您始终可以使用 * / np.multiply 手动获取和积。和 np.sum :

(a[..., None] * b[:, None, ...]).sum(axis=2)

关于python - 广播 numpy 点积,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60874227/

相关文章:

python - Unicode、Python 3 和程序员之间又一场斗争。解码字符串

python - 使用 NLTK 为中文运行 StanfordPOSTagger 时出现意外格式

python - 从特定日期开始以不同颜色绘制折线图

python - requests.get 不返回网页

javascript - lodash 在数组中查找数组返回未定义

python - 在 Python 中使用对应于 X、Y 和 Z 的三个不同数组构建等高线图

javascript - 在 Javascript 中向嵌套数组添加元素

python - 如何消除 PyInstaller 单文件夹构建中的困惑?

python - 在 Pig 中使用 Python UDF 时,如何让 Hadoop 找到导入的 Python 模块?

linux - 为什么 PyQt5 QFileDialog.getExistingDirectory 看不到 ~/.config/subdirectory?