python - 使用scikit-learn使用管道时出错

标签 python scikit-learn pipeline

我正在尝试使用 StandardScaler 执行缩放并定义 KNeighborsClassifier(创建缩放器和估计器的管道)

最后,我想为上面的内容创建一个网格搜索交叉验证器,其中 param_grid 将是一个字典,其中包含 n_neighbors 作为超参数和 k_vals 作为值。

def kNearest(k_vals):

    skf = StratifiedKFold(n_splits=5, random_state=23)

    svp = Pipeline([('ss', StandardScaler()),
                ('knc', neighbors.KNeighborsClassifier())])

    parameters = {'n_neighbors': k_vals}

    clf = GridSearchCV(estimator=svp, param_grid=parameters, cv=skf)

    return clf

但是这样做会给我一个错误提示

Invalid parameter n_neighbors for estimator Pipeline. Check the list of available parameters with `estimator.get_params().keys()`.

我已阅读文档,但仍然不太明白错误所指示的内容以及如何修复它。

最佳答案

你是对的,scikit-learn 并没有对此进行详细记录。 (在类文档字符串中对它的引用为零。)

如果在网格搜索中使用管道作为估计器,则在指定参数网格时需要使用特殊语法。具体来说,您需要使用步骤名称,后跟双下划线,后跟参数名称,因为您将其传递给估计器。即

'<named_step>__<parameter>': value

就您而言:

parameters = {'knc__n_neighbors': k_vals}

应该可以解决问题。

此处 knc 是管道中的命名步骤。有一个属性将这些步骤显示为字典:

svp.named_steps

{'knc': KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
            metric_params=None, n_jobs=1, n_neighbors=5, p=2,
            weights='uniform'),
 'ss': StandardScaler(copy=True, with_mean=True, with_std=True)}

正如您的回溯所暗示的那样:

svp.get_params().keys()
dict_keys(['memory', 'steps', 'ss', 'knc', 'ss__copy', 'ss__with_mean', 'ss__with_std', 'knc__algorithm', 'knc__leaf_size', 'knc__metric', 'knc__metric_params', 'knc__n_jobs', 'knc__n_neighbors', 'knc__p', 'knc__weights'])

对此的一些官方引用:

关于python - 使用scikit-learn使用管道时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48726695/

相关文章:

python - 模型输入必须来自 `tf.keras.Input` ...,它们不能是前一个非输入层的输出

python Pandas :get rolling value of one Dataframe by rolling index of another Dataframe

python - Pytest - 从另一个装置调用一个装置

python - scikit学习模型的预测是线程安全的吗?

python - 当我拥有所需的 DLL 时,为什么会出现此导入错误?

bash - 如何让 bash 从 stdin 执行 ELF 二进制文件?

powershell - 如果有子目录,为什么 "get-childItem -recurse | select-string foo"不会导致错误?

python - sklearn 管道的并行化

python - 从类内部加载 MongoEngine 文档

python - GridSearchCV 引发 SIGABRT(-6) 错误,n_jobs != 1