neural-network - PyTorch: How to convert the pretrained FC layers of a CNN to Conv layers

Tags: neural-network pytorch conv-neural-network

I want to convert a pretrained CNN (such as VGG-16) into a fully convolutional network in PyTorch. How can I do this?

Best answer

You can do it as follows (see the code comments for explanation):

import torch
import torch.nn as nn
from torchvision import models

# 1. LOAD PRE-TRAINED VGG16
# (newer torchvision releases prefer the `weights=` argument over `pretrained=True`)
model = models.vgg16(pretrained=True)

# 2. GET CONV LAYERS
features = model.features

# 3. GET FULLY CONNECTED LAYERS
fcLayers = nn.Sequential(
    # keep every classifier layer except the final one
    *list(model.classifier.children())[:-1]
)

# 4. CONVERT FULLY CONNECTED LAYERS TO CONVOLUTIONAL LAYERS

### convert first fc layer to conv layer with 512x7x7 kernel
fc = fcLayers[0].state_dict()
in_ch = 512
out_ch = fc["weight"].size(0)

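# 7x7 kernel with the default stride of 1, so the layer can slide over larger feature maps
firstConv = nn.Conv2d(in_ch, out_ch, kernel_size=7)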

### get the weights from the fc layer
firstConv.load_state_dict({"weight":fc["weight"].view(out_ch, in_ch, 7, 7),
                           "bias":fc["bias"]})

# CREATE A LIST OF CONVS
convList = [firstConv]

# Similarly convert the remaining linear layers to conv layers
for module in fcLayers[1:]:
    if isinstance(module, nn.Linear):
        # Convert the nn.Linear to a 1x1 nn.Conv2d
        fc = module.state_dict()
        in_ch = fc["weight"].size(1)
        out_ch = fc["weight"].size(0)
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

        # Reshape the (out_ch, in_ch) weight matrix into 1x1 kernels
        conv.load_state_dict({"weight": fc["weight"].view(out_ch, in_ch, 1, 1),
                              "bias": fc["bias"]})

        convList += [conv]
    else:
        # Append other layers such as ReLU and Dropout unchanged
        convList += [module]

# Set the conv layers as a nn.Sequential module
convLayers = nn.Sequential(*convList)  
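
As a quick sanity check of the idea behind the conversion, here is a minimal sketch on toy sizes. It is not part of the original answer: the layer sizes and the names `channels`, `height`, `width` and `hidden` are assumptions chosen for illustration, and a small random nn.Linear stands in for the VGG-16 classifier so that no pretrained weights need to be downloaded. It suggests that the reshaped convolution reproduces the fully connected layer on an input of the original size, and produces a spatial map of outputs on a larger input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumptions; VGG-16 would use 512 channels and a 7x7 feature map)
channels, height, width, hidden = 4, 3, 3, 8

# Stand-in for the first fully connected layer of a classifier head
fc = nn.Linear(channels * height * width, hidden)

# Equivalent convolution: the kernel covers the whole feature map, stride 1
conv = nn.Conv2d(channels, hidden, kernel_size=(height, width))
with torch.no_grad():
    conv.weight.copy_(fc.weight.view(hidden, channels, height, width))
    conv.bias.copy_(fc.bias)

# On an input of the original size, both layers should agree
x = torch.randn(2, channels, height, width)
with torch.no_grad():
    out_fc = fc(x.flatten(1))      # shape (2, hidden)
    out_conv = conv(x)             # shape (2, hidden, 1, 1)
same = torch.allclose(out_fc, out_conv.flatten(1), atol=1e-5)

# On a larger input, the conv version yields one output vector per window position
big = torch.randn(2, channels, height + 2, width + 4)
with torch.no_grad():
    out_big = conv(big)            # shape (2, hidden, 3, 5)

print(same, tuple(out_conv.shape), tuple(out_big.shape))
```

The same reasoning applies to the converted VGG-16 head: chaining `features` and `convLayers` (for example with nn.Sequential) should accept images larger than 224x224 and return a grid of feature vectors rather than a single one.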

This question and its answer come from Stack Overflow: https://stackoverflow.com/questions/44146655/
