我正在尝试获取 nx2x3 数组和 nx3 数组中每个元素的点积(n 的值始终在两者之间共享)。
例如:
import numpy as np
a = np.arange(12).reshape(4,3)
b = np.arange(24).reshape(4,2,3)
我试图获取的数组将包含这些:
print(np.dot(b[0],a[0]))
print(np.dot(b[1],a[1]))
print(np.dot(b[2],a[2]))
print(np.dot(b[3],a[3]))
我确信有一种方法可以使用 einsum
或 tensordot
来实现此目的,但我无法让它工作。
最佳答案
您可以这样使用einsum
:
>>> np.einsum('ij,ikj->ik', a, b)
array([[ 5, 14],
[ 86, 122],
[275, 338],
[572, 662]])
这里发生的只是 a
的轴 0 与 b
的轴 0 相乘,a
的轴 1 与 axis 相乘2 个b
。沿后一个轴的值被求和并返回一个二维数组。
(tensordot
并不能完美地应用于这个特定问题,因为我们需要沿两个轴进行乘法,并沿一个轴进行求和。这些操作仅与 tensordot
成对出现。)
关于python - 混合形状阵列列的点积,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31436553/