python - 使用和不使用基于目标的编码的流水线

标签 python machine-learning encoding scikit-learn

如果我同时使用简单编码器和目标编码器,我对组装管道的最佳方式感到困惑。我找到了这个例子 here ,这说明问题与必须将目标变量与要编码的变量一起传递有关。

from examples.source_data.loaders import get_mushroom_data
from sklearn.compose import ColumnTransformer
from category_encoders import TargetEncoder

# get data from the mushroom dataset
X, y, _ = get_mushroom_data()

# encode the specified columns
ct = ColumnTransformer(
    [
        ('Target encoding', TargetEncoder(), ['bruises', 'odor'])
    ], remainder='passthrough'
)
encoded = ct.fit_transform(X=X, y=y)

但是,我不想直接执行 fit_transform,而是想将其添加为我的管道的一部分,以便我可以在交叉验证方案中执行此操作。

所以,不起作用的代码是:

pipeline_ordinal = Pipeline(steps=[('imputer', SimpleImputer(strategy='constant', fill_value='missing'))
    ,('ord encoding', ce.ordinal.OrdinalEncoder())])

pipeline_loo = Pipeline(steps=[('imputer', SimpleImputer(strategy='constant', fill_value='missing'))
    ,('loo encoding', ce.LeaveOneOutEncoder())])

preprocessor = ColumnTransformer(
    transformers=[('simple', pipeline_ordinal, ['x1','x2','x3']),
                  ('targetbased', pipeline_loo, ['x4','x5','y'])
                 ])

rf = RandomForestRegressor()

pipe = Pipeline(steps=[('preprocessor', preprocessor),('regression', rf)])

gs = GridSearchCV(pipe, param_grid=params, cv = cv)

gs.fit(X, y)

关于将这一切修补在一起的更好方法有什么想法吗?

编辑:

问题在于将 X 传递给 gs.fit()。照原样,上面的代码说:ValueError: A given column is not a column of the dataframe

如果我尝试变聪明并在 X 中发送“y”,那么它会告诉我 ValueError: cannot reindex from a duplicate axis

最佳答案

目标变量 y 被传递并在 gs.fit(X, y) 中被特殊处理。您不需要(也不应该)将其指定为 ColumnTransformer 中的列。

(pipeline_ordinalpipeline_loo 都可以访问 y,尽管前者实际上不会使用它。)

关于python - 使用和不使用基于目标的编码的流水线,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62880686/

相关文章:

python - 如何单独捕获这些异常?

python - Tkinter 图像脚本的路径双正斜杠

machine-learning - 线性回归: Substituting the non-numerical discrete domain of a predictor with numerical one

javascript - 如何设置PhantomJS内部编码?

python - 使用远程驱动程序设置 chrome 选项

python - 使用自己的数据集训练网络

python - 文本数据的多标签核外学习 : ValueError on partial fit

php - ASCII 编码字符串的问题 - PHP

mysql - 带有非英文字符和方法帖子的 Spring MVC

python - 计算每组连续 1 的最大数量