给定 torch 张量:
# example tensor size 2 x 4
a = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
另一个每 n 行重复:
# example tensor size 4 x 3 where every 2 rows repeated
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
如何执行矩阵乘法:
>>> torch.mm(a, b)
tensor([[ 28., 38., 48.],
[ 68., 94., 120.]])
不将整个重复行张量复制到内存中或迭代?
即仅存储前 2 行:
# example tensor size 2 x 3 where only the first two rows from b are actually stored in memory
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
因为这些行将会重复。
有一个函数
torch.expand()
但是当重复不止一行时,这确实有效,而且,正如这个问题:
Repeating a pytorch tensor without copying memory
表明并且我自己的测试确认,在调用时通常最终会将整个张量复制到内存中
.to(device)
也可以迭代地执行此操作,但这相对较慢。
是否有某种方法可以有效地执行此操作,而不将整个重复行张量存储在内存中?
编辑说明:
抱歉,最初没有澄清:一个被用作第一个张量的第一个维度,以保持示例简单,但我实际上正在寻找任意两个张量 a 和 b 的一般情况的解决方案,使得它们的维度与矩阵乘法兼容,并且 b 的行每 n 行重复一次。我已更新示例以反射(reflect)这一点。
最佳答案
假设第一个维度为a
如您的示例中所示,为 1,您可以执行以下操作:
a = torch.Tensor([[1, 2, 3, 4]])
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
torch.mm(a.reshape(-1, 2), b_abbreviated).sum(axis=0, keepdim=True)
在这里,您不是重复行,而是乘以 a
分成 block ,然后按列将它们相加以获得相同的结果。
如果第一个维度为a
不一定是1,您可以尝试以下操作:
torch.cat(torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]), dim=1).sum(
dim=0, keepdim=True).reshape(a.shape[0], -1)
在这里,您执行以下操作:
- 与
torch.mm(a.reshape(-1,2),b_abbreviated
,您再次拆分a
的每一行分成大小为 2 的 block ,并将它们一层一层地堆叠起来,然后将每一行堆叠在另一行之上。 - 与
torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0])
,然后将这些堆栈按行分离,以便拆分的每个结果组件对应于单行的 block 。 - 与
torch.cat(torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]), dim=1)
然后这些堆栈按列连接。 - 与
.sum(dim=0, keepdim=True)
,结果对应于a
中各个行的单独 block 。已相加。 - 与
.reshape(a.shape[0], -1)
,行a
按列连接的数据再次按行堆叠。
与直接矩阵乘法相比,它似乎相当慢,这并不奇怪,但我还没有与显式迭代进行比较。可能有更好的方法可以做到这一点,如果我想到的话将进行编辑。
关于python - 如何有效地乘以具有重复行的 torch 张量,而不将所有行存储在内存中或迭代?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67496315/