python - 如何挤压除一支 torch 之外的所有 torch ?

标签 python pytorch

torch.squeeze可以将张量的形状转换为尺寸不为 1 的张量。

我想在除一维之外的所有维度上挤压张量(在本例中,不是挤压 dim=0)。

我在文档中只能看到

dim (int, optional) – if given, the input will be squeezed only in this dimension

我想要相反的:

t = torch.zeros(5, 1, 6, 1, 7, 1)

squeezed = torch.magic_squeeze(keep_dim=3)

assert squeezed == (5, 6, 1, 7)

这可以做到吗?

最佳答案

Reshape 会让您完成您想做的事情:

import torch

t = torch.zeros(5, 1, 6, 1, 7, 1)
t = t.reshape((5, 6, 1, 7))
>>> torch.Size([5, 6, 1, 7])

关于python - 如何挤压除一支 torch 之外的所有 torch ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66226505/

相关文章:

json - ML 模型输入和输出类比

pytorch - 为什么多层感知器在 CartPole 中优于 RNN?

python - 带有 Multipart/form-data 的 POST 请求。内容类型不正确

Python-FTP下载目录中的所有文件

python - 为什么在 pyTorch 强化学习示例的 nn.Module 中返回 self.head(x.view(x.size(0), -1))

python - Pytorch:除了一个求和之外的所有求和?

tensorflow - Pytorch Autograd : what does runtime error "grad can be implicitly created only for scalar outputs" mean

python - 如何使用原始输入来定义变量

python - 有用的发布-订阅语义

Python:如何正确处理 pandas DataFrame 中的 NaN,以便在 scikit-learn 中进行特征选择