python - 在 pytorch v1.0 Sequential 模块中使用 flatten

标签 python pytorch conv-neural-network

由于我的CUDA版本是8,所以我用的是torch 1.0.0

我需要为 Sequential 模型使用 Flatten 层。这是我的代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
print(torch.__version__)
# 1.0.0
from collections import OrderedDict

layers = OrderedDict()
layers['conv1'] = nn.Conv2d(1, 5, 3)
layers['relu1'] = nn.ReLU()
layers['conv2'] = nn.Conv2d(5, 1, 3)
layers['relu2'] = nn.ReLU()
layers['flatten'] = nn.Flatten()
layers['linear1'] = nn.Linear(3600, 1)
model = nn.Sequential(
layers
).cuda()

它给我以下错误:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-38-080f7c5f5037> in <module>
      6 layers['conv2'] = nn.Conv2d(5, 1, 3)
      7 layers['relu2'] = nn.ReLU()
----> 8 layers['flatten'] = nn.Flatten()
      9 layers['linear1'] = nn.Linear(3600, 1)
     10 model = nn.Sequential(

AttributeError: module 'torch.nn' has no attribute 'Flatten'

如何在 pytorch 1.0.0 中展平我的转换层输出?

最佳答案

只需新建一个 Flatten 层即可。

from collections import OrderedDict

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

layers = OrderedDict()
layers['conv1'] = nn.Conv2d(1, 5, 3)
layers['relu1'] = nn.ReLU()
layers['conv2'] = nn.Conv2d(5, 1, 3)
layers['relu2'] = nn.ReLU()
layers['flatten'] = Flatten()
layers['linear1'] = nn.Linear(3600, 1)
model = nn.Sequential(
layers
).cuda()

关于python - 在 pytorch v1.0 Sequential 模块中使用 flatten,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61039700/

相关文章:

Python pandas 数据框,逐行聚合直至停止标准

Python:列表分配超出范围

python - 如何打印张量而不显示梯度

python - 运行时错误 : one of the variables needed for gradient computation has been modified by an inplace operation

neural-network - caffe 丢失错误 : Check failed: The data and label should have the same first dimension

python - 使用 tf.metric 模块中的变量时出现 TensorFlow FailedPreconditionError

python - 谷歌地图 API 的命中率限制,但不知道为什么

python - Pytorch:从矩阵元素的总和反向传播到叶变量

python - 如何对连续音频进行分类

python - float 打印不一致。为什么它有时会起作用?