我有一个 PyTorch 张量元组列表。它看起来像这样:
[
(tensor([1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8])),
(tensor([9, 10,11]), tensor([11,12,13,14]), tensor([15])),
(tensor([16,17,18]), tensor([19,20,21,22]), tensor([23])),
...
]
每列中的张量(即位于各自元组中 k 的张量)共享相同的形状。我想将张量堆叠在每列中,以便最终得到一个元组,每个值都是沿列维度连接的张量。
在这种情况下,输出元组将具有三个值,如下所示:
(
tensor([[1,2,3], [9,10,11], [16,17,18]]),
tensor([[4,5,6,7], [11,12,13,14], [19,20,21,22]],
tensor([[8],[15],[23])
)
这是一个虚构的例子。我想对任意长度的元组和任意大小的张量执行此操作。使用 PyTorch 快速进行此类串联的最佳方法是什么?
最佳答案
如果有人陷入同样复杂的场景,我可以用一句可爱的俏皮话来解决它:
tuple(map(torch.stack, zip(*x)))
在这种情况下,x
是我上面提到的原始列表。这行代码将 x
转换为所需的确切格式。
关于python - 将张量堆叠在张量元组列表中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59149275/