python - 如何使用 scikit learn pipelines 简化数据预处理

标签 python machine-learning scikit-learn

我有 2 个 dfs。 df1 是猫的示例,df2 是狗的示例。

我必须对这些 dfs 进行一些预处理,目前我正在通过调用不同的函数来进行这些预处理。我想使用 scikit learn pipelines。

其中一个函数是一个特殊的编码器函数,它将查看 df 中的列并返回一个特殊值。我在 scikit learn 中的我看到被使用之类的类中重写了该函数:

class Encoder(BaseEstimator, TransformerMixin):

    def __init__(self):
        self.values = []
        super().__init__()

    def fit(self, X, y=None):
        return self

    def encode(self,row):
        result = []
        for base in row:
            result.append(bases[base])

        self.values.append(result)

    def transform(self, X):
        assert isinstance(X, pd.DataFrame)
        X["seq_new"].apply(self.encode)

        return self.values

所以现在我会得到 2 个列表:

encode = Encoder()
X1 = encode.transform(df1)
X2 = encode.transform(df2)

下一步是:

features = np.concatenate((X1, X1), axis=0)

下一步构建标签:

Y_dog = [[1]] * len(X1)
Y_cat = [[0]] * len(X2)
labels = np.concatenate((Y_dog, Y_cat), axis=0)

以及其他一些操作,然后我将执行 model_selection.train_test_split() 将数据拆分为训练和测试。

我如何在 scikit 管道中调用所有这些函数?我找到的示例是从已经完成训练/测试拆分的地方开始的。

最佳答案

sklearn.pipeline.Pipeline 的特点是每一步都需要实现 fittransform。因此,举例来说,如果您知道您始终需要执行串联步骤,并且您真的很想将其放入 Pipeline (我不会,但这只是我的拙见),您需要使用适当的 fittransform 方法创建一个 Concatenator class 。 p>

类似这样的事情:

class Encoder(object):
    def fit(self, X, *args, **kwargs):
        return self
    def transform(self, X):
        return X*2

class Concatenator(object):
    def fit(self, X, *args, **kwargs):
        return self
    def transform(self, Xs):
        return np.concatenate(Xs, axis=0)

class MultiEncoder(Encoder):
    def transform(self, Xs):
        return list(map(super().transform, Xs))

pipe = sklearn.pipeline.Pipeline((
    ("encoder", MultiEncoder()),
    ("concatenator", Concatenator())
))

pipe.fit_transform((
    pd.DataFrame([[1,2],[3,4]]), 
    pd.DataFrame([[5,6],[7,8]])
))

关于python - 如何使用 scikit learn pipelines 简化数据预处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52978272/

相关文章:

Python - POST 请求时 403 访问被拒绝

python - Matplotlib 动画 : how to dynamically extend x limits?

python - 为什么这一行会产生错误?

python-3.x - 如何在 Azure ML 实验脚本上导入自定义函数?

python - 显示逻辑回归分类器 sklearn 的训练迭代分数

python - python unittest doc推荐的惰性导入方法如何做?

python - 如何消除具有共享轴的子图上的额外空白?

matlab - 在 LIBSVM matlab 中执行额外验证

python-3.x - 如何构建用于创建模型的分类器?

python - 使用 joblib 加载腌制的 scikit-learn 模型时出现 KeyError