我正在尝试加载除最后一层之外的两个单独训练的模型,并希望结合这两个模型单独训练最后一层。我定义了一个新的 nn.Module 类,并将这些预训练模型加载到该类中,并在前向路径中尝试返回最后一层之前的值。
class New_net(nn.Module):
def __init__(self):
super(New_net, self).__init__()
self.net1 = net1()
self.net2 = net2()
self.fc= nn.Linear(512, 2)
self._initialize_weights()
def _initialize_weights(self):
checkpoint = torch.load('save_model/checkpoint_net1.t7')
self.net1.load_state_dict(checkpoint['state_dict'])
checkpoint = torch.load('save_model/checkpoint_net2.t7')
self.net2.load_state_dict(checkpoint['state_dict'])
def forward(self, x):
x1 = self.net1(x)
x2 = self.net2(x)
x=torch.cat((x1,x2),dim=1)
x=self.fc(x)
return x
但似乎没有准确加载模型。正确的做法是什么
最佳答案
我想到了。我没有进行权重初始化,而是执行了以下操作
#load net1 model partially
checkpoint = torch.load('save_model/checkpoint_net1.t7')
pretrained_dict=checkpoint['state_dict']
net1_dict=net.net1.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net1_dict}
net1_dict.update(pretrained_dict)
net.net1.load_state_dict(net1_dict)
#load net2 model partially
checkpoint = torch.load('save_model/checkpoint_net2.t7')
pretrained_dict=checkpoint['state_dict']
net2_dict=net.net2.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in net2_dict}
net2_dict.update(pretrained_dict)
net.net2.load_state_dict(net2_dict)
关于machine-learning - 在 Pytorch 中从 state-dict 部分加载两个预训练模型的最佳方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62649109/