python - 如何使用 scikit-learn 加载之前保存的模型并使用新的训练数据扩展模型

标签 python machine-learning scikit-learn

我正在使用 scikit-learn,其中我保存了一个逻辑回归模型,其中包含一元组作为训练集 1 中的特征。是否可以加载此模型,然后使用第二个训练集(训练设置 2)?如果是,该怎么办?这样做的原因是因为我对每个训练集使用两种不同的方法(第一种方法涉及特征损坏/正则化,第二种方法涉及 self 训练)。

为了清楚起见,我添加了一些简单的示例代码:

from sklearn.linear_model import LogisticRegression as log
from sklearn.feature_extraction.text import CountVectorizer as cv
import pickle

trainText1 # Training set 1 text instances    
trainLabel1 # Training set 1 labels 
trainText2 # Training set 2 text instances    
trainLabel2 # Training set 2 labels 

clf = log()
# Count vectorizer used by the logistic regression classifier 
vec = cv() 

# Fit count vectorizer with training text data from training set 1
vec.fit(trainText1) 

# Transforms text into vectors for training set1
train1Text1 = vec.transform(trainText1) 

# Fitting training set1 to the linear logistic regression classifier 
clf.fit(trainText1,trainLabel1)

# Saving logistic regression model from training set 1
modelFileSave = open('modelFromTrainingSet1', 'wb')
pickle.dump(clf, modelFileSave)
modelFileSave.close()  

# Loading logistic regression model from training set 1    
modelFileLoad = open('modelFromTrainingSet1', 'rb')
clf = pickle.load(modelFileLoad)

# I'm unsure how to continue from here....

最佳答案

LogisticRegression 在内部使用不支持增量拟合的 liblinear 求解器。相反,您可以使用 SGDClassifier(loss='log') 作为 partial_fit 方法,尽管在实践中可以用于此目的。其他超参数有所不同。小心仔细地网格搜索它们的最佳值。请阅读 SGDClassifier 文档了解这些超参数的含义。

CountVectorizer 不支持增量拟合。您必须重复使用训练集 #1 上安装的矢量化器来转换 #2。这意味着集合 #2 中尚未在 #1 中看到的任何标记都将被完全忽略。这可能不是您所期望的。

为了缓解这种情况,您可以使用无状态的 HashingVectorizer,但代价是不知道这些功能的含义。阅读 the documentation了解更多详情。

关于python - 如何使用 scikit-learn 加载之前保存的模型并使用新的训练数据扩展模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/26714214/

相关文章:

python - 文本分类与推荐

python-3.x - 属性错误: 'list' object has no attribute 'create_png'

python - 将 dtype= 对象转换为 dtype ='|S5'

python - scikit-learn OpenMP libsvm

python - genericsetup 导入步骤的名称是否有一个很好的引用列表

python - Matplotlib - 基于光谱颜色的曲线下颜色

python - 当值与另一列匹配时回填 Pandas 系列中的值

python - 将 CudaNdarraySharedVariable 转换为 TensorVariable

python - sklearn 中的 log_loss : Multioutput target data is not supported with label binarization

python - 在 Python 中访问静态属性