python - 使用来自 gridsearchcv 的最佳参数

标签 python scikit-learn grid-search

我不知道在这里问这个问题是否正确,但无论如何我都会问。如果不允许,请告诉我。

我已使用 GridSearchCV 调整参数以找到最佳精度。这就是我所做的:

from sklearn.grid_search import GridSearchCV
parameters = {'min_samples_split':np.arange(2, 80), 'max_depth': np.arange(2,10), 'criterion':['gini', 'entropy']}
clfr = DecisionTreeClassifier()
grid = GridSearchCV(clfr, parameters,scoring='accuracy', cv=8)
grid.fit(X_train,y_train)
print('The parameters combination that would give best accuracy is : ')
print(grid.best_params_)
print('The best accuracy achieved after parameter tuning via grid search is : ', grid.best_score_)

这给了我以下结果:

The parameters combination that would give best accuracy is : 
{'max_depth': 5, 'criterion': 'entropy', 'min_samples_split': 2}
The best accuracy achieved after parameter tuning via grid search is :  0.8147086914995224

现在,我想在调用可视化决策树的函数时使用这些参数

函数看起来像这样

def visualize_decision_tree(decision_tree, feature, target):
    dot_data = export_graphviz(decision_tree, out_file=None, 
                         feature_names=feature,  
                         class_names=target,  
                         filled=True, rounded=True,  
                         special_characters=True)  
    graph = pydotplus.graph_from_dot_data(dot_data)  
    return Image(graph.create_png())

现在我正在尝试使用 GridSearchCV 提供的最佳参数以下列方式调用该函数

dtBestScore = DecisionTreeClassifier(parameters = grid.best_params_)
dtBestScore = dtBestScore.fit(X=dfWithTrainFeatures, y= dfWithTestFeature)
visualize_decision_tree(dtBestScore, list(dfCopy.columns.delete(0).values), 'survived')

我在第一行代码中遇到错误

TypeError: __init__() got an unexpected keyword argument 'parameters'

有什么方法可以设法使用网格搜索提供的最佳参数并自动使用它?而不是查看结果并手动设置每个参数的值?

最佳答案

试试 python kwargs:

DecisionTreeClassifier(**grid.best_params)

参见 http://pythontips.com/2013/08/04/args-and-kwargs-in-python-explaine ‌ d 了解更多关于 kwargs 的信息。

关于python - 使用来自 gridsearchcv 的最佳参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41475539/

相关文章:

python - "for item in L"循环中的语法无效

python - Plotly:如何在 Excel 中嵌入完全交互式的 Plotly 图形?

python - 使用 Python 和正则表达式在字符串中查找 C 关键字

python - 导入 WriteToDatastore 时出错(Apache Beam/Google DataFlow)

python - 类型错误 : 'ShuffleSplit' object is not iterable

python - 将特征哈希应用于 DataFrame 中的特定列

python - 为什么具有铰链损失的 SGDClassifier 比 scikit-learn 中的 SVC 实现更快

python-3.x - pandas_ml 坏了吗?

python - n_jobs=-1 的 GridSearchCV 不适用于决策树/随机森林分类

machine-learning - 如果我们在管道中包含 Transformer,scikit-learn 的 `cross_val_score` 和 `GridsearchCV` 的 k 倍交叉验证分数是否有偏差?