pytorch - 如何在 pytorch 中展平张量?

标签 pytorch

给定一个多维的张量,我如何将它展平以使其具有单个维度?

例如:

>>> t = torch.rand([2, 3, 5])
>>> t.shape
torch.Size([2, 3, 5])

我如何将它压平以具有形状:
torch.Size([30])

最佳答案

TL;博士:torch.flatten()
使用 torch.flatten() 这是在 v0.4.1 中引入的并记录在 v1.0rc1 :

>>> t = torch.tensor([[[1, 2],
                       [3, 4]],
                      [[5, 6],
                       [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])


对于 v0.4.1 及更早版本,使用 t.reshape(-1) .

t.reshape(-1) :

如果请求的 View 在内存中是连续的
这将等同于 t.view(-1) 并且内存不会被复制。

否则等于t. contiguous() .view(-1) .

其他非选项:
  • t.view(-1) won't copy memory, but may not work depending on original size and stride
  • t.resize(-1)RuntimeError (见下文)
  • t.resize(t.numel()) warning about being a low-level method
    (见下面的讨论)

  • (注意: pytorchreshape() 可能会改变数据,但 numpy 's reshape() won't 。)
    t.resize(t.numel())需要一些讨论。 torch.Tensor.resize_ documentation说:

    The storage is reinterpreted as C-contiguous, ignoring the current strides (unless the target size equals the current size, in which case the tensor is left unchanged)



    鉴于当前的步幅将被新的 (1, numel()) 忽略。大小,元素的顺序可能会以与 reshape(-1) 不同的顺序出现.但是,“大小”可能意味着内存大小,而不是张量的大小。

    如果 t.resize(-1) 就好了既方便又高效,但使用 torch 1.0.1.post2 , t = torch.rand([2, 3, 5]); t.resize(-1)给出:
    RuntimeError: requested resize to -1 (-1 elements in total), but the given 
    tensor has a size of 2x2 (4 elements). autograd's resize can only change the 
    shape of a given tensor, while preserving the number of elements.
    

    我为此提出了一个功能请求 here ,但共识是 resize()是一种低级方法,并且 reshape()应该优先使用。

    关于pytorch - 如何在 pytorch 中展平张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55546873/

    相关文章:

    tensorflow - 在训练用于语义分割的深度学习模型时,处理背景像素类 (ignore_label) 的最佳方法是什么?

    c++ - Libtorch C++ 和 pytorch 的不同输出

    multithreading - 如何在 PyTorch 中禁用多线程?

    python - 如何使用 Pytorch 中的预训练权重以 4 个 channel 作为输入修改 resnet 50?

    python - Pytorch BiLSTM POS 标记问题 : RuntimeError: input. size(-1) 必须等于 input_size。预计6个,实测12个

    python - Pytorch教程代码错误: "NameError: name ' net' is not defined"

    deep-learning - 如何将 Torch 图像切片为 numpy 图像

    python - 比较 scipy、torch 和 Fourier 周期卷积时的不一致

    python - 无法让 pytorch 与张量板一起工作

    python-3.x - ModuleNotFoundError : No module named 'past' when installing tensorboard with pytorch 1. 2