pytorch - 从检查点加载模型不起作用

标签 pytorch pytorch-lightning

我训练了一个我从 this repository 修改的 Vanilla vae .当我尝试使用经过训练的模型时,我无法使用 load_from_checkpoint 加载权重.我的检查点对象和我的 lightningModule 之间似乎不匹配目的。
我已经使用 VAEXperiment 设置了一个实验 (pytorch-lightning LightningModule) .我尝试将权重加载到网络中:

#building a new model
model = VanillaVAE(**config['model_params'])
model.build_layers()

#loading the weights
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_from_checkpoint(path_to_checkpoint, config['exp_params'])
我也试过:
checkpoint = torch.load(path_to_checkpoint, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
但我得到一个错误Unexpected key(s) in state_dict: "model.encoder.0.0.weight", "model.encoder.0.0.bias" ...
我也关注了这个问题
https://github.com/PyTorchLightning/pytorch-lightning/issues/924
https://github.com/PyTorchLightning/pytorch-lightning/issues/2798
为什么我会收到此错误?是因为我的模型中的编码器和解码器模块吗?根据 git 上的问题日志,似乎错误已解决。我究竟做错了什么?

最佳答案

从评论中发布答案:

experiment.load_state_dict(checkpoint['state_dict'])

关于pytorch - 从检查点加载模型不起作用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63243359/

相关文章:

python - 使用 PyTorchVideo 加载用于训练视频分类模型的动力学数据集时出错

machine-learning - Pytorch N - Beats 模型抛出错误 : 'str' object has no attribute '__name__'

python - 为什么PIL与Pytorch经常使用?

pytorch - 使用 pytorch 验证卷积定理

python - PyTorch - 自定义 ReLU 平方实现

python - 如何在同一张图中绘制 Tensorboard 中的多个标量而不向实验列表发送垃圾邮件?

python-3.x - “torchmetrics”不适用于 PyTorchLightning

python-3.x - 使用 Pytorch Lightning DDP 时记录事情的正确方法

pytorch - 并行模拟 torch.nn.Sequential 容器

python - PyTorch 中相同形状的掩蔽张量