我有一个张量 A = 40x1
.
我需要将这个张量与其他 3 个张量相乘:B = 40x100x384, C = 40x10, D=40x10
.
例如张量 B
,我们有 40 100x384
矩阵,我需要将这些矩阵中的每一个与 A
中的相应元素相乘
在 pytorch 中执行此操作的最佳方法是什么?假设我们可以有更多的矩阵,如 B、C、D,它们将始终采用 40xKxL
样式或40xJ
最佳答案
如果我理解正确的话,您需要将每个第 i 个矩阵 K x L
乘以 A
中相应的第 i 个标量。
一种可能的方法是:
(A * B.view(len(A), -1)).view(B.shape)
或者您可以使用 broadcasting 的力量:
A = A.reshape(len(A), 1, 1)
# now A is (40, 1, 1) and you can do
A*B
A*C
A*D
基本上,A
中每个等于 1 的尾随维度都会被拉伸(stretch)和复制以匹配其他矩阵。
关于python - Pytorch如何将除第一维之外的可变大小的张量相乘,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56133218/