python - 如何有效地乘以具有重复行的 torch 张量,而不将所有行存储在内存中或迭代?

标签 python pytorch matrix-multiplication tensor torch

给定 torch 张量:

# example tensor size 2 x 4
a = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

另一个每 n 行重复:

# example tensor size 4 x 3 where every 2 rows repeated
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])

如何执行矩阵乘法:

>>> torch.mm(a, b)
tensor([[ 28.,  38.,  48.],
        [ 68.,  94., 120.]])

不将整个重复行张量复制到内存中或迭代?

即仅存储前 2 行:

# example tensor size 2 x 3 where only the first two rows from b are actually stored in memory
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])

因为这些行将会重复。

有一个函数

torch.expand()

但是当重复不止一行时,这确实有效,而且,正如这个问题:

Repeating a pytorch tensor without copying memory

表明并且我自己的测试确认,在调用时通常最终会将整个张量复制到内存中

.to(device)

也可以迭代地执行此操作,但这相对较慢。

是否有某种方法可以有效地执行此操作,而不将整个重复行张量存储在内存中?

编辑说明:

抱歉,最初没有澄清:一个被用作第一个张量的第一个维度,以保持示例简单,但我实际上正在寻找任意两个张量 a 和 b 的一般情况的解决方案,使得它们的维度与矩阵乘法兼容,并且 b 的行每 n 行重复一次。我已更新示例以反射(reflect)这一点。

最佳答案

假设第一个维度为a如您的示例中所示,为 1,您可以执行以下操作:

a = torch.Tensor([[1, 2, 3, 4]])
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
torch.mm(a.reshape(-1, 2), b_abbreviated).sum(axis=0, keepdim=True)

在这里,您不是重复行,而是乘以 a分成 block ,然后按列将它们相加以获得相同的结果。


如果第一个维度为a不一定是1,您可以尝试以下操作:

torch.cat(torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]), dim=1).sum(
dim=0, keepdim=True).reshape(a.shape[0], -1)

在这里,您执行以下操作:

  • torch.mm(a.reshape(-1,2),b_abbreviated ,您再次拆分 a 的每一行分成大小为 2 的 block ,并将它们一层一层地堆叠起来,然后将每一行堆叠在另一行之上。
  • torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]) ,然后将这些堆栈按行分离,以便拆分的每个结果组件对应于单行的 block 。
  • torch.cat(torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]), dim=1)然后这些堆栈按列连接。
  • .sum(dim=0, keepdim=True) ,结果对应于a中各个行的单独 block 。已相加。
  • .reshape(a.shape[0], -1) ,行a按列连接的数据再次按行堆叠。

与直接矩阵乘法相比,它似乎相当慢,这并不奇怪,但我还没有与显式迭代进行比较。可能有更好的方法可以做到这一点,如果我想到的话将进行编辑。

关于python - 如何有效地乘以具有重复行的 torch 张量,而不将所有行存储在内存中或迭代?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67496315/

相关文章:

python - Pandas Dataframe,当某些行可能有超过 1 ","时,如何将一列分成两列 ","

python - 如何将 Web API 与 Django 结合使用

python - RNN 中的隐藏大小与输入大小

nlp - Gensim 的 word2vec 从 epoch 1 开始损失为 0?

c++ - BLAS 产品 dgemm 使用 CblasTrans 选项时行为异常

python - PyTorch:如何通过广播两个不同形状的张量相乘

python - 将数组数据与子数组数据匹配

python - 使用多处理时克服内存限制

python - 尽管形状相同 : if not (target. size() == input.size()) 但出现类型错误: 'int' 对象不可调用

matrix - 使用cuda乘以数百个矩阵