python - 跟踪 sklearn 预处理中的输出列

标签 python scikit-learn preprocessor

如何跟踪 sklearn.compose.ColumnTransformer 生成的转换数组的列?我所说的“跟踪”是指执行逆变换所需的每一位信息都必须明确显示。这至少包括以下内容:

  1. 输出数组中每列的源变量是什么?
  2. 如果输出数组的某一列来自分类变量的 one-hot 编码,那么该类别是什么?
  3. 每个变量的确切估算值是多少?
  4. 用于标准化每个数值变量的(均值、标准差)是什么? (由于估算缺失值,这些可能与直接计算不同。)

我使用基于 this answer 的相同方法。我的输入数据集也是一个通用的 pandas.DataFrame ,具有多个数字和分类列。是的,这个答案可以改变原始数据集。但我忘记了输出数组中的列。我需要这些信息来进行同行评审、报告撰写、演示和进一步的模型构建步骤。我一直在寻找一种系统的方法,但没有成功。

最佳答案

上面提到的答案是基于this在 Sklearn 中。

您可以使用以下代码片段获得前两个问题的答案。

def get_feature_names(columnTransformer):

    output_features = []

    for name, pipe, features in columnTransformer.transformers_:
        if name!='remainder':
            for i in pipe:
                trans_features = []
                if hasattr(i,'categories_'):
                    trans_features.extend(i.get_feature_names(features))
                else:
                    trans_features = features
            output_features.extend(trans_features)

    return output_features
import pandas as pd
pd.DataFrame(preprocessor.fit_transform(X_train),
            columns=get_feature_names(preprocessor))

enter image description here

transformed_cols = get_feature_names(preprocessor)

def get_original_column(col_index):
    return transformed_cols[col_index].split('_')[0]

get_original_column(3)
# 'embarked'

get_original_column(0)
# 'age'
def get_category(col_index):
    new_col = transformed_cols[col_index].split('_')
    return 'no category' if len(new_col)<2 else new_col[-1]

print(get_category(3))
# 'Q'

print(get_category(0))
# 'no category'

在当前版本的 Sklearn 中,跟踪某项功能是否进行了插补或缩放并非易事。

关于python - 跟踪 sklearn 预处理中的输出列,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58911481/

相关文章:

python - 为什么 Decimal ('0' ) > 9999.0 在 Python 中为真?

python - 对具有相同日期的单元格求和,并按多列分组

PythonAnywhere:django.db.utils.OperationalError:没有这样的表:

python-3.x - Sklearn PCA 分解解释_方差_比率_

python-3.x - 使用相关和随机语料库计算 TF-IDF 单词得分

c++ - ##__VA_ARGS__ 是什么意思?

c - 将十六进制字符串转换为字节数组的预处理器宏

preprocessor - C++预处理

python - 添加新小部件时,滚动区域无法扩展(滚动)

Python sklearn.datasets.dump_svmlight_file 未能输出列的正确索引