python - 使用fast-ai时如何修复 'Error(s) in loading state_dict for AWD_LSTM'

标签 python machine-learning torch fast-ai modelstatedictionary

我正在使用 fast-ai 库来训练 IMDB 评论数据集的样本。我的目标是实现情感分析,我只想从一个小数据集开始(这个数据集包含 1000 条 IMDB 评论)。我使用 this tutorial 在虚拟机中训练了模型.

我保存了 data_lmdata_clas 模型,然后是编码器 ft_enc,之后保存了分类器学习器 sentiment_model.然后,我从虚拟机中获取了这 4 个文件,并将它们放入我的机器中,并希望使用这些预训练模型来对情绪进行分类。

这就是我所做的:

# Use the IMDB_SAMPLE file
path = untar_data(URLs.IMDB_SAMPLE)

# Language model data
data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')

# Sentiment classifier model data
data_clas = TextClasDataBunch.from_csv(path, 'texts.csv', 
                                       vocab=data_lm.train_ds.vocab, bs=32)

# Build a classifier using the tuned encoder (tuned in the VM)
learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)
learn.load_encoder('ft_enc')

# Load the trained model
learn.load('sentiment_model')

之后,我想使用该模型来预测句子的情绪。执行此代码时,我遇到以下错误:

RuntimeError: Error(s) in loading state_dict for AWD_LSTM:
   size mismatch for encoder.weight: copying a param with shape torch.Size([8731, 400]) from checkpoint, the shape in current model is torch.Size([8888, 400]).
   size mismatch for encoder_dp.emb.weight: copying a param with shape torch.Size([8731, 400]) from checkpoint, the shape in current model is torch.Size([8888, 400]). 

回溯是:

Traceback (most recent call last):
  File "C:/Users/user/PycharmProjects/SentAn/mainApp.py", line 51, in <module>
    learn = load_models()
  File "C:/Users/user/PycharmProjects/SentAn/mainApp.py", line 32, in load_models
    learn.load_encoder('ft_enc')
  File "C:\Users\user\Desktop\py_code\env\lib\site-packages\fastai\text\learner.py", line 68, in load_encoder
    encoder.load_state_dict(torch.load(self.path/self.model_dir/f'{name}.pth'))
  File "C:\Users\user\Desktop\py_code\env\lib\site-packages\torch\nn\modules\module.py", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))

所以加载编码器时出现错误。但是,我也尝试删除 load_encoder 行,但下一行 learn.load('sentiment_model') 发生了同样的错误。

我在fast-ai论坛上搜索了一下,发现其他人也有这个问题,但没有找到解决方案。在 this post用户说这可能与不同的预处理有关,尽管我无法理解为什么会发生这种情况。

有人知道我做错了什么吗?

最佳答案

看来data_clas和data_lm的词汇量大小不同。我猜这个问题是由 data_clas 和 data_lm 中使用的不同预处理引起的。为了检查我的猜测,我简单地使用了

data_clas.vocab.itos = data_lm.vocab.itos

下一行之前

learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.3)

这已修复了错误。

关于python - 使用fast-ai时如何修复 'Error(s) in loading state_dict for AWD_LSTM',我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55847371/

相关文章:

image-processing - 人脸识别的误报

python-3.x - PyTorch 中的左移张量

python - codeigniter 作为 web 和 python 作为 web 服务

python - 如何跳过要求输入密码的 Fabric 连接?

Python:两个列表列表的交集

algorithm - 哪个是正确的 tripletLoss 反向传播公式?

lua - torch/nn - 按元素连接张量数组

python tsne.transform 不存在?

machine-learning - liblbfgs 在 C++ 中编译

machine-learning - GridSearchCV + StratifiedKfold(如果是 TFIDF)