machine-learning - 在 Pytorch 中从 state-dict 部分加载两个预训练模型的最佳方法是什么?

标签 machine-learning deep-learning pytorch pre-trained-model

我正在尝试加载除最后一层之外的两个单独训练的模型,并希望结合这两个模型单独训练最后一层。我定义了一个新的 nn.Module 类,并将这些预训练模型加载到该类中,并在前向路径中尝试返回最后一层之前的值。

class New_net(nn.Module):
    def __init__(self):
        super(New_net, self).__init__()
        self.net1 = net1()
        self.net2 = net2()
        self.fc= nn.Linear(512, 2)
        self._initialize_weights()

    def _initialize_weights(self):
        checkpoint = torch.load('save_model/checkpoint_net1.t7')
        self.net1.load_state_dict(checkpoint['state_dict'])

        checkpoint = torch.load('save_model/checkpoint_net2.t7')
        self.net2.load_state_dict(checkpoint['state_dict'])       

    def forward(self, x):
        x1 = self.net1(x)
        x2 = self.net2(x)
        x=torch.cat((x1,x2),dim=1)
        x=self.fc(x)
        return x

但似乎没有准确加载模型。正确的做法是什么

最佳答案

我想到了。我没有进行权重初始化,而是执行了以下操作

#load net1 model partially
checkpoint = torch.load('save_model/checkpoint_net1.t7')
pretrained_dict=checkpoint['state_dict']

net1_dict=net.net1.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net1_dict}
net1_dict.update(pretrained_dict)
net.net1.load_state_dict(net1_dict)

#load net2 model partially
checkpoint = torch.load('save_model/checkpoint_net2.t7')
pretrained_dict=checkpoint['state_dict']
net2_dict=net.net2.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net2_dict}
net2_dict.update(pretrained_dict)
net.net2.load_state_dict(net2_dict)

关于machine-learning - 在 Pytorch 中从 state-dict 部分加载两个预训练模型的最佳方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62649109/

相关文章:

tensorflow - Channels first vs Channels last——这是什么意思?

tensorflow - 使用机器学习去除手写签名图像中的背景

machine-learning - 卷积神经网络中不同内核大小(1x1、3x3、5x5)之间有什么区别?

machine-learning - 在 cifar-10 上的 Keras 中实现 AlexNet 的准确性较差

python - PyTorch 数据加载器中的 "number of workers"参数实际上是如何工作的?

python - 一次迭代两个 Pytorch 张量?

python - pycharm中找不到Pytorch,无法安装

machine-learning - 有人会如何创建一种机器学习算法来从书籍/小说中提取说话者?

python - Tensorflow 抛出分布式函数错误

neural-network - Style Transfer 和 GAN 之间的关系是什么?