python - GridSearchCV 结果热图

标签 python matplotlib scikit-learn seaborn sklearn-pandas

我正在尝试为 sklearn 的 GridSearchCV 结果生成热图。我喜欢的东西sklearn-evaluation是因为生成热图真的很容易。但是,我遇到了一个问题。当我将参数设为 None 时,例如

max_depth = [3, 4, 5, 6, None]
在生成热图时,它显示错误:
TypeError: '<' not supported between instances of 'NoneType' and 'int'
有什么解决方法吗?
我找到了其他方法来生成热图,比如使用 matplotlib 和 seaborn,但没有什么能像 sklearn-evalutaion 那样给出漂亮的热图。
enter image description here

最佳答案

我摆弄着 grid_search.py文件 /lib/python3.8/site-packages/sklearn_evaluation/plot/grid_search.py .在第 192/193 行更改行

row_names = sorted(set([t[0] for t in matrix_elements.keys()]),
                   key=itemgetter(1))
col_names = sorted(set([t[1] for t in matrix_elements.keys()]),
                   key=itemgetter(1))
到:
row_names = sorted(set([t[0] for t in matrix_elements.keys()]),
                   key=lambda x: (x[1] is None, x[1]))
col_names = sorted(set([t[1] for t in matrix_elements.keys()]),
                   key=lambda x: (x[1] is None, x[1]))
全部搬家 None排序时到列表末尾基于之前的 answer
来自安德鲁·克拉克。
使用此调整,我的演示脚本如下所示:
import numpy as np
import sklearn.datasets as datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn_evaluation import plot

data = datasets.make_classification(n_samples=200, n_features=10, n_informative=4, class_sep=0.5)


X = data[0]
y = data[1]

hyperparameters = {
    "max_depth": [1, 2, 3, None],
    "criterion": ["gini", "entropy"],
    "max_features": ["sqrt", "log2"],
}

est = RandomForestClassifier(n_estimators=5)
clf = GridSearchCV(est, hyperparameters, cv=3)
clf.fit(X, y)
plot.grid_search(clf.cv_results_, change=("max_depth", "criterion"), subset={"max_features": "sqrt"})


import matplotlib.pyplot as plt

plt.show()
输出如下所示:
enter image description here

关于python - GridSearchCV 结果热图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68138679/

相关文章:

python - GitPython 通过管道输出到 stdout

python - python2中的string.decode()函数

Python:在 locals() 中查找键的类型

python - 如何以粗线绘制指定数据

python - 如何更改 matplotlib colorbar 标签的字体属性?

python - sklearn SelectKBest()中如何给score函数设置参数

python - SciKit-learn--高斯朴素贝叶斯实现

python - 在Python中将字典切片到某个键值以下

python - 在 ipython 笔记本中使用 matplotlib 内联时如何禁用 bbox_inches ='tight'

python-2.7 - 您可以从 scikit-learn 中的 DecisionTreeRegressor 获取选定的叶子吗