python - PyTorch:如何通过广播两个不同形状的张量相乘

标签 python pytorch shapes matrix-multiplication array-broadcasting

我有以下两个 PyTorch 张量 A 和 B。

A = torch.tensor(np.array([40, 42, 38]), dtype = torch.float64)

tensor([40., 42., 38.], dtype=torch.float64)
B = torch.tensor(np.array([[[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5]], [[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8]], [[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11]]]), dtype = torch.float64)

tensor([[[ 1.,  2.,  3.,  4.,  5.],
         [ 1.,  2.,  3.,  4.,  5.],
         [ 1.,  2.,  3.,  4.,  5.],
         [ 1.,  2.,  3.,  4.,  5.],
         [ 1.,  2.,  3.,  4.,  5.]],

        [[ 4.,  5.,  6.,  7.,  8.],
         [ 4.,  5.,  6.,  7.,  8.],
         [ 4.,  5.,  6.,  7.,  8.],
         [ 4.,  5.,  6.,  7.,  8.],
         [ 4.,  5.,  6.,  7.,  8.]],

        [[ 7.,  8.,  9., 10., 11.],
         [ 7.,  8.,  9., 10., 11.],
         [ 7.,  8.,  9., 10., 11.],
         [ 7.,  8.,  9., 10., 11.],
         [ 7.,  8.,  9., 10., 11.]]], dtype=torch.float64)

张量 A 的形状为:

torch.Size([3])

张量 B 的形状为:

torch.Size([3, 5, 5])

如何以这种方式将张量 A 与张量 B(使用广播)相乘,例如。张量 A 中的第一个值(即 40.)乘以张量 B 中第一个“嵌套”张量中的所有值,即。

tensor([[[ 1.,  2.,  3.,  4.,  5.],
         [ 1.,  2.,  3.,  4.,  5.],
         [ 1.,  2.,  3.,  4.,  5.],
         [ 1.,  2.,  3.,  4.,  5.],
         [ 1.,  2.,  3.,  4.,  5.]],

分别对张量 A 中的其他 2 个值和张量 B 中的其他两个嵌套张量依此类推。

如果 A 和 B 都是形状为 (3,) 的数组,我可以用 numpy 数组做这个乘法(通过广播)——即。 A*B - 但我似乎无法用 PyTorch 张量找出与此对应的东西。非常感谢任何帮助。

最佳答案

在 pytorch(以及 numpy)中应用广播时,您需要从最后维度开始(查看https://pytorch.org/docs/stable/notes/broadcasting.html)。如果它们不匹配,则需要 reshape 张量。在您的情况下,它们不能直接广播:

      [3]  # the two values in the last dimensions are not one and do not match
[3, 5, 5]

相反,您可以在乘积之前重新定义 A = A[:, None, None] 以获得形状

[3, 1, 1]
[3, 5, 5]

满足广播条件

关于python - PyTorch:如何通过广播两个不同形状的张量相乘,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65121614/

相关文章:

python - .backward() 之后 pytorch grad 为 None

python - Pandas :在群体内规范化

python - 可以扩展 collections.deque 以构建 "file buffer"吗?

python - 如何在 Python 中测试一个数字是否为平方数?

python - 由于内存问题,如何保存仅与预训练bert模型的分类器层相关的参数?

python - PyTorch 中的嵌入会创建范数大于 max_norm 的嵌入

python - 我需要 Python 的帮助来检查列表的边界

java - 在 Java 中解析十六进制值

java - 确定线是否完全在path2d形状内(在java中)

python - 修改 Tkinter 标签