python - scikit-learn 在其树结构中的哪个位置保存每个叶节点的决策标签?

标签 python scikit-learn decision-tree tree-structure

我已经使用 scikit-learn 训练了一个随机森林模型,现在我想将其树结构保存在一个文本文件中,以便我可以在其他地方使用它。 根据this link树对象由许多并行数组组成,每个数组都包含有关树的不同节点的一些信息(例如,左 child ,右 child ,它检查的特征,...)。但是好像没有关于每个叶节点对应的类标签的信息!上面链接中提供的示例甚至都没有提到它。

有谁知道类标签存储在 scikit-learn 决策树结构中的什么位置?

最佳答案

查看 sklearn.tree.DecisionTreeClassifier.tree_.value 的文档:

from sklearn.datasets import load_iris
from sklearn.cross_validation import cross_val_score
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()

clf.fit(iris.data, iris.target)

print(clf.classes_)

[0, 1, 2]

print(clf.tree_.value)

[[[ 50.  50.  50.]]

 [[ 50.   0.   0.]]

 [[  0.  50.  50.]]

 [[  0.  49.   5.]]

 [[  0.  47.   1.]]

 [[  0.  47.   0.]]

 [[  0.   0.   1.]]

 [[  0.   2.   4.]]

 [[  0.   0.   3.]]

 [[  0.   2.   1.]]

 [[  0.   2.   0.]]

 [[  0.   0.   1.]]

 [[  0.   1.  45.]]

 [[  0.   1.   2.]]

 [[  0.   0.   2.]]

 [[  0.   1.   0.]]

 [[  0.   0.  43.]]]

clf.tree_.value 中的每一行“包含每个节点的常量预测值”,(help(clf.tree_)) 对应于索引- clf.classes_ 的索引。

参见 this answer了解(几乎)更多详细信息。

关于python - scikit-learn 在其树结构中的哪个位置保存每个叶节点的决策标签?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44158993/

相关文章:

python 复制带有时间戳的文件

python - 延时图像的亮度/直方图归一化

python - 你如何解决 pyinstaller for scipy 中的 'hidden imports not found!' 警告?

python - 使用 scikit-learn PCA 找到具有最高方差的维度

r - 拨浪鼓错误 : the FUN argument to prp is not a function 中的决策树可视化错误

python - Django 模型没有将数据实时保存到数据库

python - 使用 flask.redirect 时 Flask 302 错误代码

python - 如何创建平衡的 k 均值地理空间聚类?

python - 构建决策树回归模型并预测样本输出 - 机器学习

numpy - 训练模型时发生不兼容行维度的值错误