python - 没有循环的多个 numpy 点产品

标签 python loops numpy matrix

是否可以在没有循环的情况下计算多个点积? 假设你有以下内容:

a = randn(100, 3, 3)
b = randn(100, 3, 3)

我想得到一个形状为 (100, 3, 3) 的数组 z,这样对于所有 i

z[i, ...] == dot(a[i, ...], b[i, ...])

换句话说,它验证:

for va, vb, vz in izip(a, b, z):
    assert (vq == dot(va, vb)).all()

直接的解决方案是:

z = array([dot(va, vb) for va, vb in zip(a, b)])

使用隐式循环(列表理解 + 数组)。

有没有更有效的方法来计算 z?

最佳答案

np.einsum 在这里很有用。尝试运行此复制+粘贴代码:

import numpy as np

a = np.random.randn(100, 3, 3)
b = np.random.randn(100, 3, 3)

z = np.einsum("ijk, ikl -> ijl", a, b)

z2 = np.array([ai.dot(bi) for ai, bi in zip(a, b)])

assert (z == z2).all()

einsum 是编译后的代码,运行速度非常快,甚至与 np.tensordot 相比(这里并不完全适用,但通常适用)。以下是一些统计数据:

In [8]: %timeit z = np.einsum("ijk, ikl -> ijl", a, b)
10000 loops, best of 3: 105 us per loop


In [9]: %timeit z2 = np.array([ai.dot(bi) for ai, bi in zip(a, b)])
1000 loops, best of 3: 1.06 ms per loop

关于python - 没有循环的多个 numpy 点产品,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24090889/

相关文章:

python - pandas/numpy int64 中意外的 32 位整数溢出(python 3.6)

python tkinter 禁用文本小部件中的换行符

python - 在 QTableView 中使用 QSortFilterProxyModel 对两行进行排序

python - Python 脚本的 Node.js child_process 执行导致错误导入模块

python与SQL服务器的连接

python 列表突变(for in loop vs range(len))

具有文件名数组的 Java InputData 读/写循环

python - numpy View 如何知道它引用的值在原始 numpy 数组中的位置?

php - 为什么不能在表达式中组合 'continue' 关键字?

python - 使用 scipy.signal.spectral.lombscargle 进行周期发现