python - 将张量堆叠在张量元组列表中

标签 python pytorch torch

我有一个 PyTorch 张量元组列表。它看起来像这样:

[
    (tensor([1, 2, 3]),  tensor([4, 5, 6, 7]),  tensor([8])),
    (tensor([9, 10,11]), tensor([11,12,13,14]), tensor([15])),
    (tensor([16,17,18]), tensor([19,20,21,22]), tensor([23])),
    ...
]

每列中的张量(即位于各自元组中 k 的张量)共享相同的形状。我想将张量堆叠在每列中,以便最终得到一个元组,每个值都是沿列维度连接的张量。

在这种情况下,输出元组将具有三个值,如下所示:

(
 tensor([[1,2,3], [9,10,11], [16,17,18]]),

 tensor([[4,5,6,7], [11,12,13,14], [19,20,21,22]],

 tensor([[8],[15],[23])
)

这是一个虚构的例子。我想对任意长度的元组和任意大小的张量执行此操作。使用 PyTorch 快速进行此类串联的最佳方法是什么?

最佳答案

如果有人陷入同样复杂的场景,我可以用一句可爱的俏皮话来解决它:

tuple(map(torch.stack, zip(*x)))

在这种情况下,x 是我上面提到的原始列表。这行代码将 x 转换为所需的确切格式。

关于python - 将张量堆叠在张量元组列表中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59149275/

相关文章:

c++ - 如何将Torch::Tensor转换为cv::InputArray?

Python 优先级

pytorch - 使用fastai的learn.lr_find()选择learning_rate

nlp - PyTorch - 稀疏张量没有步幅

Python matplotlib,图像数据的无效形状

python - 为什么我在 cmd 中安装任何 python 模块时收到这些错误 "WARNING: Ignoring invalid distribution -yproj "

python - conv2d 之后的 PyTorch CNN 线性层形状

python - 在MQL5中接受Python生成的套接字的输出

python - 嵌套 if 语句在列表列表中不返回任何内容

oop - 在 Python 模块中强制执行方法顺序