python - sklearn - 具有多个分数的交叉验证

标签 python numpy scikit-learn

我想计算不同分类器的交叉验证测试的 recallprecisionf-measurescikit-learn 自带 cross_val_score但不幸的是,这种方法不会返回多个值。

我可以通过调用 3 次 cross_val_score 来计算此类度量,但这并不高效。有没有更好的解决方案?

现在我写了这个函数:

from sklearn import metrics

def mean_scores(X, y, clf, skf):

    cm = np.zeros(len(np.unique(y)) ** 2)
    for i, (train, test) in enumerate(skf):
        clf.fit(X[train], y[train])
        y_pred = clf.predict(X[test])
        cm += metrics.confusion_matrix(y[test], y_pred).flatten()

    return compute_measures(*cm / skf.n_folds)

def compute_measures(tp, fp, fn, tn):
     """Computes effectiveness measures given a confusion matrix."""
     specificity = tn / (tn + fp)
     sensitivity = tp / (tp + fn)
     fmeasure = 2 * (specificity * sensitivity) / (specificity + sensitivity)
     return sensitivity, specificity, fmeasure

它基本上总结了混淆矩阵的值,一旦你有 false positivefalse positive 等,你可以轻松计算召回率、精度等......但我仍然不喜欢这个解决方案:)

最佳答案

现在在 scikit-learn 中:cross_validate 是一个新函数,可以根据多个指标评估模型。 GridSearchCVRandomizedSearchCV (doc) 也提供此功能。 一直是merged recently in master并将在 v0.19 中提供。

来自 scikit-learn doc :

The cross_validate function differs from cross_val_score in two ways: 1. It allows specifying multiple metrics for evaluation. 2. It returns a dict containing training scores, fit-times and score-times in addition to the test score.

典型用例如下:

from sklearn.svm import SVC
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
iris = load_iris()
scoring = ['precision', 'recall', 'f1']
clf = SVC(kernel='linear', C=1, random_state=0)
scores = cross_validate(clf, iris.data, iris.target == 1, cv=5,
                        scoring=scoring, return_train_score=False)

另见 this example .

关于python - sklearn - 具有多个分数的交叉验证,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/23339523/

相关文章:

python - 在贝塞尔曲线的开始/结束处强制法向加速度为零

python - 在哪里可以找到 numpy 百分位数的源代码

python - 用户警告 : Label not :NUMBER: is present in all training examples

python - 任何语言中最流畅的 REPL 控制台

python - 使用 numpy.savetxt 同时保持数组的形状

machine-learning - 新闻文章多类分类算法

machine-learning - scikit learn 的plot_learning_curve 得分是多少?

python - 没有沉重的数据库如何进行模糊字符串搜索?

python - 将 csv 文件转换为字典列表

python - 同步但不关闭 dbm