python - 如何使管道跳过该步骤(使用 "passthrough")以及 param_grid 中适用于该步骤的所有参数?

标签 python scikit-learn gridsearchcv

我正在使用 PCA 在 sklearn 中创建管道,并使用“passthrough”跳过此步骤。 对于 PCA,我正在测试 n_components 参数的多个值。

from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV

X_train, y_train = make_regression(n_samples=100, n_features=10)


param_grid = {
    'reduce_dim': [PCA(), 'passthrough'],
    'reduce_dim__n_components': [1,2,3]
}

pipeline = Pipeline(
        steps=[
            ('reduce_dim', None), 
            ('regressor', LinearRegression())
        ]
    )

grid_search = GridSearchCV(
    estimator=pipeline, 
    param_grid=param_grid, 
    verbose=10
)
grid_search.fit(X_train, y_train)

我想要实现的是使用 n_components=[1,2,3] 进行 PCA 的 3 次拟合,以及不使用 PCA 的 1 次拟合。

Fitting 5 folds for each of 4 candidates, totalling 20 fits

我得到的是 PCA 的 3 个拟合和没有 PCA 的 3 个拟合(我不需要在没有 PCA 的情况下测试 n_components 的所有三种可能性):

Fitting 5 folds for each of 6 candidates, totalling 30 fits

然后是一个运行时错误,基本上表明我无法将 n_components 值分配给“passthrough”(str 对象)

[CV 1/5; 4/6] START reduce_dim=passthrough, reduce_dim__n_components=1...
AttributeError: 'str' object has no attribute 'set_params'

如何使管道跳过该步骤(在这种情况下是 reduce_dim)以及适用于该步骤的所有参数?

我知道我可以像这样使用 param_grid:

param_grid = [
    {
        'reduce_dim': [PCA()],
        'reduce_dim__n_components': [1,2,3]
    },
    {}
]

但是可以用更优雅的方式来完成吗,因为在更复杂的场景中代码会变得非常困惑。

最佳答案

您想要的参数网格也可以在单个参数的单个字典中定义:

param_grid = {
    'reduce_dim' = [PCA(n_components=1), PCA(n_components=2), PCA(n_components=3), 'passthrough']
}

这样做的优点是避免定义多个可能不那么“困惑”的字典。

关于python - 如何使管道跳过该步骤(使用 "passthrough")以及 param_grid 中适用于该步骤的所有参数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68207661/

相关文章:

python - 随机搜索CV : All estimators failed to fit

python - 如何将自定义的QSplitter类显示到QMainWindow上

python - "Invalid Index to Scalar Variable"- 使用 Scikit Learn 时 "accuracy_score"

python - GridsearchCV 上的预处理

Scikit-learn 实现狄利克雷过程高斯混合模型 : Gibbs sampling or Variational inference?

python - 参数不会在 scikit-learn GridSearchCV 中自定义估计器

python - 使用 Scikit-Learn 在 RegressorChain 上进行 GridSearch?

python - 如何修复 'TypeError: __init__() got an unexpected keyword argument ' 发件人”

python - 比较两个列表以选择最大值

python - SQLAlchemy 和 Falcon - session 初始化