python - 如何在 PyTorch 中指定多个转换层后的展平层输入大小?

标签 python python-3.x pytorch

这是我的问题,我在 CIFAR10 数据集上做了一个小测试,如何在 PyTorch 中指定展平层输入大小?如下所示,输入大小为 16*5*5,但是我不知道如何计算它,我想通过某个函数获取输入大小。有人可以在这个 Net 类中编写一个简单的函数并解决这个问题?

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)  
        self.conv2 = nn.Conv2d(6,16,5)

        # HERE , the input size is 16*5*5, but I don't know how to get it.
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(x.size()[0],-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

最佳答案

Pytorch 默认没有 Flatten Layer。您可以创建一个如下所示的类。干杯

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.flatten   = Flatten()  ## describing the layer
        self.conv1 = nn.Conv2d(3,6,5)  
        self.conv2 = nn.Conv2d(6,16,5)

        # HERE , the input size is 16*5*5, but I don't know how to get it.
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        #x = x.view(x.size()[0],-1)
        x = self.flatten(x)   ### using of flatten layer
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

关于python - 如何在 PyTorch 中指定多个转换层后的展平层输入大小?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52474439/

相关文章:

python - 如果 X 或 Y 或 Z 则使用 *that* 一个?

Python NumPy - FFT 和逆 FFT?

python - 只给我的卷积神经网络提供 channel 数而不提供高度和宽度如何工作?

python - 如何将 L1 正则化添加到 PyTorch NN 模型?

python - 如何在 Python 中读出 *new* OS 环境变量?

c++ - Pytorch C++ 运行时错误 : Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_index_select

python - Google App Engine 中的自引用 ReferenceProperty

python - matplotlib,如何在给定条件下绘制 3d 2 变量函数

python-3.x - 禁用 ModelMultipleChoiceField CheckBoxSelectMultiple Django 中的选择

python - Python 3.2 中的 HEX 解码