python - 使用 joblib 在 sklearn 中重用 cross_val_score 拟合的模型

标签 python scikit-learn joblib

<分区>

我在 python 中创建了以下函数:

def cross_validate(algorithms, data, labels, cv=4, n_jobs=-1):
    print "Cross validation using: "
    for alg, predictors in algorithms:
        print alg
        print
        # Compute the accuracy score for all the cross validation folds. 
        scores = cross_val_score(alg, data, labels, cv=cv, n_jobs=n_jobs)
        # Take the mean of the scores (because we have one for each fold)
        print scores
        print("Cross validation mean score = " + str(scores.mean()))

        name = re.split('\(', str(alg))
        filename = str('%0.5f' %scores.mean()) + "_" + name[0] + ".pkl"
        # We might use this another time 
        joblib.dump(alg, filename, compress=1, cache_size=1e9)  
        filenameL.append(filename)
        try:
            move(filename, "pkl")
        except:
            os.remove(filename) 

        print 
    return

我认为为了进行交叉验证,sklearn 必须适合您的功能。

但是,当我稍后尝试使用它时(f 是我在上面保存的 pkl 文件 joblib.dump(alg, filename, compress=1, cache_size=1e9)):

alg = joblib.load(f)  
predictions = alg.predict_proba(train_data[predictors]).astype(float)

我在第一行中没有收到任何错误(所以看起来负载正在工作),但随后它告诉我NotFittedError: Estimator not fitted, callfitbefore exploiting the model。 在下一行。

我做错了什么?我不能重用适合计算交叉验证的模型吗?我看了Keep the fitted parameters when using a cross_val_score in scikits learn但要么我不明白答案,要么这不是我要找的。我想要的是用 joblib 保存整个模型,这样我以后就可以使用它而无需重新拟合。

最佳答案

交叉验证必须适合您的模型并不完全正确;而是 k 折交叉验证在部分数据集上适合您的模型 k 次。如果你想要模型本身,你实际上需要在整个数据集上再次拟合模型;这实际上不是交叉验证过程的一部分。所以调用

实际上并不是多余的
alg.fit(data, labels)

在交叉验证后适合您的模型。

另一种方法是不使用专用函数 cross_val_score,您可以将其视为交叉验证网格搜索的特例(参数空间中有一个点)。在这种情况下,GridSearchCV 将默认在整个数据集上重新拟合模型(它有一个参数 refit=True),并且还有 predictpredict_proba API 中的方法。

关于python - 使用 joblib 在 sklearn 中重用 cross_val_score 拟合的模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36215700/

相关文章:

python - Pygame 缩放健康栏

python - GridSearchCV 结果热图

machine-learning - Sklearn 批量梯度下降的实现

Python,与 joblib : Delayed with multiple arguments 并行化

python - 用平均值填充 nan 值的更快方法

python - Numpy:将数组元素设置为另一个数组

python - 使用python替换特定行中的字符串

python - 随机森林修剪

python - 使用 joblib 加载腌制的 scikit-learn 模型时出现 KeyError

Python - Urllib.Request - 更改下载文件的位置