python - 如何将 GridSearchCV 用于带有 SVC 估计器的 OneVsRestClassifier?

标签 python classification svc gridsearchcv

我正在尝试将 OneVsRestClassifier 与 SVC 一起用于图像的多分类问题 - 我从 CellProfiler 获得了图像的数值特征。我想使用 GridSearchCV 查找要使用的超参数,但我被卡住了。

有人对此有解决方案/建议吗?

我已通过 Google 阅读,但似乎无法解决我的问题。

    grid = GridSearchCV(pipe, scoring='f1',
                       param_grid=param_grid, cv=5,
                       return_train_score=True,
                       iid=False,
                       n_jobs=-1
                       )
    grid.fit(X_train, np.ravel(y_train))
    return grid
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import classification_report

pipe = make_pipeline(StandardScaler(),
                     OneVsRestClassifier(SVC(probability=True)))

param_grid = {
    'estimator__C': [0.001, 0.01, 0.1, 1, 10, 100],
    'estimator__kernel': ['linear', 'rbf', 'poly'],
    'estimator__degree': [2, 3, 4, 5, 7, 10],
    'estimator__gamma': [0.01, 0.02, 0.03, 0.04, 0.05, 1]
}

clf = grid_search_fit(pipe, param_grid)

preds = clf.predict(X_test)
print(classification_report(y_test, preds, target_names = ['empty', 'good', 'blurred']))
ValueError: Invalid parameter estimator for estimator Pipeline(memory=None,
         steps=[('standardscaler',
                 StandardScaler(copy=True, with_mean=True, with_std=True)),
                ('onevsrestclassifier',
                 OneVsRestClassifier(estimator=SVC(C=1.0, cache_size=200,
                                                   class_weight=None, coef0=0.0,
                                                   decision_function_shape='ovr',
                                                   degree=3,
                                                   gamma='auto_deprecated',
                                                   kernel='rbf', max_iter=-1,
                                                   probability=True,
                                                   random_state=None,
                                                   shrinking=True, tol=0.001,
                                                   verbose=False),
                                     n_jobs=None))],
         verbose=False). Check the list of available parameters with `estimator.get_params().keys()`.

最佳答案

我对你的代码做了如下修改:

  1. 删除了选项 iid=False
  2. 我稍微改变了你的 Pipeline 和 GridSearchCV 的形状

更改后的代码如下。您可以或多或少地像这样构建 Pipeline 和 Gridsearch。

from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import classification_report

pipe = Pipeline([
    ("scale", StandardScaler()),
    ('classify', OneVsRestClassifier(SVC(probability=True)))
])
    
param_grid = {
    'classify__estimator__C': [0.001, 0.01, 0.1, 1, 10, 100],
    'classify__estimator__kernel': ['linear', 'rbf', 'poly'],
    'classify__estimator__degree': [2, 3, 4, 5, 7, 10],
    'classify__estimator__gamma': [0.01, 0.02, 0.03, 0.04, 0.05, 1]
}

grid_search = GridSearchCV(
    pipe, param_grid, cv=5, scoring='f1', verbose=1, return_train_score=True, n_jobs=-1)

grid_search = grid_search.fit(X_train, np.ravel(y_train))

preds = clf.predict(X_test)
print(classification_report(y_test, preds, target_names = ['empty', 'good', 'blurred']))

关于python - 如何将 GridSearchCV 用于带有 SVC 估计器的 OneVsRestClassifier?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58723803/

相关文章:

javascript - 如何使对象在 Javascript 中可订阅?

python - 日期范围内的结束日期

machine-learning - Weka机器学习:how to interprete Naive Bayes classifier?

python - 如何将 Scikit Learn 分类器应用于大图像中的图 block /窗口

Python for 循环查找 SVC 的最佳值(C 和 gamma)

python - 如果值相同,如何对字典中的键进行排序?

python - 在列表列表中查找最大出现次数

javascript - 从 JavaScript 访问 WCF WebService - 对预检请求的响应未通过访问控制检查

python-3.x - 从 Cereal 图像中辨别有缺陷的 Cereal

windows - 使用 Visual Studio 2015 时进程 'microsoft.vshub.server.httphostx64.exe' 的高内存使用率