scikit-learn - 如何使用 scikit-learn 从决策树中获取区间限制?

标签 scikit-learn decision-tree tree-structure

假设我正在使用泰坦尼克号数据集,仅包含变量年龄:

import pandas as pd

data = pd.read_csv('https://www.openml.org/data/get_csv/16826755/phpMYEkMl')[["age", "survived"]]
data = data.replace('?', np.nan)
data = data.fillna(0)
print(data)

结果:

         age  survived
0         29         1
1     0.9167         1
2          2         0
3         30         0
4         25         0
...      ...       ...
1304    14.5         0
1305       0         0
1306    26.5         0
1307      27         0
1308      29         0

[1309 rows x 2 columns]

现在我训练一个决策树来预测年龄的生存情况:

from sklearn.tree import DecisionTreeClassifier
tree_model = DecisionTreeClassifier(max_depth=3)
tree_model.fit(data['age'].to_frame(),data["survived"])

如果我打印树的结构:

from sklearn import tree
print(tree.export_text(tree_model))

我得到:

|--- feature_0 <= 0.08
|   |--- class: 0
|--- feature_0 >  0.08
|   |--- feature_0 <= 8.50
|   |   |--- feature_0 <= 1.50
|   |   |   |--- class: 1
|   |   |--- feature_0 >  1.50
|   |   |   |--- class: 1
|   |--- feature_0 >  8.50
|   |   |--- feature_0 <= 60.25
|   |   |   |--- class: 0
|   |   |--- feature_0 >  60.25
|   |   |   |--- class: 0

这意味着每个节点的最终划分是:

0-0.08; 0.08-1.50; 1.50-8.50; 8.50-60; >60

我的问题是,如何在如下所示的数组中捕获这些限制:

[-np.inf, 0.08, 1.5, 8.5, 60, np.inf]

谢谢!

最佳答案

决策分类器(在本例中为tree_model)有一个名为tree_的属性,它允许访问低级属性。

print(tree_model.tree_.threshold)

array([ 0.08335, -2.     ,  8.5    ,  1.5    , -2.     , -2.     ,
       60.25   , -2.     , -2.     ])
print(tree_model.tree_.feature)

array([ 0, -2,  0,  0, -2, -2,  0, -2, -2], dtype=int64)

数组featurethreshold仅适用于分割节点。因此,这些数组中叶节点的值是任意的。

要获取特征的划分/阈值,您可以使用feature数组过滤阈值。

threshold = tree_model.tree_.threshold
feature = tree_model.tree_.feature
feature_threshold = threshold[feature == 0]
thresholds = sorted(feature_threshold)
print(thresholds)

[0.08335000276565552, 1.5, 8.5, 60.25]

要拥有np.inf,您需要自己添加它。

thresholds = [-np.inf] + thresholds + [np.inf]
print(thresholds)

[-inf, 0.08335000276565552, 1.5, 8.5, 60.25, inf]

引用:Understanding the decision tree structure .

关于scikit-learn - 如何使用 scikit-learn 从决策树中获取区间限制?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/75663472/

相关文章:

python - 使用 Dask 或 Joblib 进行并行 Sklearn 模型构建

java - 使用 Weka 进行错误的类别预测

python - 如何分割数据集 - 标签数 = 150 与样本数 = 600 不匹配

c# - AdaBoost 反复选择相同的弱学习器

java - 树结构的正则表达式?

python - Scikit-Learn LinearRegression 在非常简单的数据集上表现不佳,

python - 从 Scikit (Python) 中的管道检索中间特征

python - 如何预测值(value)?

mysql - Rails 3、MySQL、树形结构

machine-learning - 贝叶斯超参数优化