python - 如何使用 scikit-learn 在分类问题中为 F1 分数做 GridSearchCV?

标签 python machine-learning neural-network multilabel-classification hyperparameters

我正在使用 scikit-learn 中的神经网络解决多分类问题,我正在尝试弄清楚如何优化我的超参数(层数、感知器和其他最终的东西)。

我发现 GridSearchCV 是实现此目的的方法,但我使用的代码返回平均准确度,而我实际上想测试 F1 分数。有谁知道如何编辑此代码以使其适用于 F1 分数?

在开始时,当我不得不评估精度/准确度时,我认为仅采用混淆矩阵并从中得出结论就“足够”了,同时通过反复试验改变层数和感知器的数量在我的神经网络中一次又一次。

今天我发现不止于此:GridSearchCV。我只需要弄清楚如何评估 F1 分数,因为我需要研究确定神经网络在层、节点和最终其他替代方案方面的准确性...

mlp = MLPClassifier(max_iter=600)
clf = GridSearchCV(mlp, parameter_space, n_jobs= -1, cv = 3)
clf.fit(X_train, y_train.values.ravel())

parameter_space = {
    'hidden_layer_sizes': [(1), (2), (3)],
}

print('Best parameters found:\n', clf.best_params_)

means = clf.cv_results_['mean_test_score']
stds = clf.cv_results_['std_test_score']
for mean, std, params in zip(means, stds, clf.cv_results_['params']):
    print("%0.3f (+/-%0.03f) for %r" % (mean, std * 2, params))

输出:

Best parameters found:
 {'hidden_layer_sizes': 3}
0.842 (+/-0.089) for {'hidden_layer_sizes': 1}
0.882 (+/-0.031) for {'hidden_layer_sizes': 2}
0.922 (+/-0.059) for {'hidden_layer_sizes': 3}

所以这里我的输出给出了平均准确度(我发现这是 GridSearchCV 的默认值)。我该如何更改它以返回平均 F1 分数而不是准确度?

最佳答案

您可以使用 make_scorer 创建自己的度量函数。在这种情况下,您可以使用 sklearn 的 f1_score,但如果您愿意,也可以使用自己的:

from sklearn.metrics import f1_score, make_scorer

f1 = make_scorer(f1_score , average='macro')


一旦你制作了你的记分器,你可以将它直接插入到网格创建中作为 scoring 参数:

clf = GridSearchCV(mlp, parameter_space, n_jobs= -1, cv = 3, scoring=f1)


另一方面,我使用 average='macro' 作为 f1 多类参数。这会计算每个标签的指标,然后找到它们的未加权平均值。但是还有其他选项可以计算具有多个标签的 f1。你可以找到他们here


注意:答案经过完全编辑以便更好地理解

关于python - 如何使用 scikit-learn 在分类问题中为 F1 分数做 GridSearchCV?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56084591/

相关文章:

python - Python 中的多输出多类机器学习

python - TensorFlow 和 Keras 的相同实现之间的不同行为

python - 运行 python unit-test 在要导入的正确目录中找不到我的文件

machine-learning - 新闻文章在线聚类

machine-learning - 波形比较

PHP脚本无法正确运行python脚本,但它在终端上正确运行

machine-learning - TensorFlow 学习率衰减 - 如何正确提供衰减的步数?

python - 抓取元素上缺少类/id 的数据

python - NFQUEUE/IPtables - Suricata 内嵌 Python 拦截 DNS

c++ - 对继承类的虚表的 undefined reference