python - 计算全连接层的尺寸?

标签 python pytorch conv-neural-network

我正在努力弄清楚如何计算全连接层的尺寸。我正在使用批量大小 (16) 输入 (448x448) 的图像。以下是我的卷积层的代码:

class ConvolutionalNet(nn.Module):
  def __init__(self, num_classes=182):
    super().__init__()

    self.layer1 = nn.Sequential(
        nn.Conv2d(3, 16, kernal_size=5, stride=1, padding=2),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.MaxPool2d(kernal_size=2, stride=2)
    )

    self.layer2 = nn.Sequential(
        nn.Conv2d(16, 32, kernal_size=5, stride=1, padding=2),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(kernal_size=2, stride=2)
    )

    self.layer3 = nn.Sequential(
        nn.Conv2d(32, 32, kernal_size=5, stride=1, padding=2),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(kernal_size=2, stride=2)
    )

    self.layer4 = nn.Sequential(
        nn.Conv2d(32, 64, kernal_size=5, stride=1, padding=2),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernal_size=2, stride=2)
    )

    self.layer5 = nn.Sequential(
        nn.Conv2d(64, 64, kernal_size=5, stride=1, padding=2),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernal_size=2, stride=2)
    )

我想添加一个全连接层:

self.fc = nn.Linear(?, num_classes)

有人能解释一下计算这个的最佳方法吗?另外,如果我有多个完全连接的层,例如(self.fc2, self.fc3),第二个参数是否始终等于类的数量。我是编码新手,发现很难理解这一点。

最佳答案

由于您将 padding 设置为等于 (kernel_size - 1)/2,因此转换层不会更改特征的宽度/高度。 kernel_size = stride = 2 的最大池化会将宽度/高度减少 2 倍(如果输入形状不均匀则向下舍入)。

使用448作为输入宽度/高度,输出宽度/高度将为448//2//2//2//2//2 = 448/32 = 14 (其中 // 是下限运算符)。

channel 数完全由最后一个卷积层决定,该层输出 64 个 channel 。

因此,您将拥有一个 [B,64,14,14] 形状的张量,因此 Linear 层应具有 in_features = 64*14*14 = 12544

请注意,您需要事先展平输入,例如。

self.layer6 = nn.Sequential(
    nn.Flatten(),
    nn.Linear(12544, num_classes)
)

关于python - 计算全连接层的尺寸?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71385657/

相关文章:

python - 如何将完整字符串输出转换为 torch 张量

python - keras 中的图像到图像映射

python - 如何知道列表是否在 Python 3 中填充了相同的数字?

python - 如何根据pytorch中的另一个张量选择索引

python - 从下拉列表中的数据库 mysql 中选择条目,python Flask

pytorch - Pytorch 中的 .ckpt 和 .pth 文件有什么区别?

tensorflow - 卷积神经网络 (CNN) 输入形状

python - 是否可以创建同一 CNN 的多个实例,这些实例接收多个图像并连接成一个密集层? (喀拉斯)

Python (Mutagen) - 无法从 MP4/MP3 文件中获取艺术家

python - 在列表中执行二分搜索 - Python