python - scikit learn - Feature importance calculation in decision trees

Tags: python scikit-learn decision-tree feature-selection

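I am trying to understand how feature importance is calculated for decision trees in scikit-learn. This question has been asked before, but I am unable to reproduce the results the algorithm provides.

For example: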

from sklearn.tree import DecisionTreeClassifier, export_graphviz

X = [[1,0,0], [0,0,0], [0,0,1], [0,1,0]]

y = [1,0,1,1]

# No random_state is set, so ties between equally good splits may be broken differently between runs
clf = DecisionTreeClassifier()
clf.fit(X, y)

# Unnormalized importances: the raw weighted impurity decrease per feature
feat_importance = clf.tree_.compute_feature_importances(normalize=False)
print("feat importance = " + str(feat_importance))

# Write the fitted tree to a Graphviz .dot file (the test/ directory must already exist)
export_graphviz(clf, out_file='test/tree.dot')

The resulting feature importances:

feat importance = [0.25       0.08333333 0.04166667]

which gives the following decision tree:

decision tree

Now, this answer to a similar question suggests the importance is calculated as

formula_a

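where G is the node impurity, in this case the Gini impurity. As far as I understand it, this is the impurity reduction. However, for feature 1 this should be: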

formula_b

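This answer suggests the importance is weighted by the probability of reaching the node (approximated by the proportion of samples reaching that node). Again, for feature 1 this should be: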

formula_c

Both formulas give the wrong result. How is the feature importance calculated correctly?

Accepted answer

I think feature importance depends on the implementation, so we need to look at the scikit-learn documentation.

The feature importances. The higher, the more important the feature. The importance of a feature is computed as the (normalized) total reduction of the criterion brought by that feature. It is also known as the Gini importance

That reduction, or weighted information gain, is defined as:

The weighted impurity decrease equation is the following:

N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)

where N is the total number of samples, N_t is the number of samples at the current node, N_t_L is the number of samples in the left child, and N_t_R is the number of samples in the right child.

http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier
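As a rough sketch of that formula in plain Python, working from per-node class counts: the helper names `gini` and `weighted_impurity_decrease` are made up for illustration and are not part of scikit-learn.

```python
def gini(class_counts):
    """Gini impurity of a node, given the number of samples in each class."""
    total = sum(class_counts)
    return 1.0 - sum((count / total) ** 2 for count in class_counts)


def weighted_impurity_decrease(n_total, parent_counts, left_counts, right_counts):
    """N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)."""
    n_t = sum(parent_counts)
    n_t_l = sum(left_counts)
    n_t_r = sum(right_counts)
    return (n_t / n_total) * (
        gini(parent_counts)
        - (n_t_r / n_t) * gini(right_counts)
        - (n_t_l / n_t) * gini(left_counts)
    )


# Hypothetical root split matching the question's data: 4 samples
# (1 of class 0, 3 of class 1), split into a left child with 3 samples
# (1 of class 0, 2 of class 1) and a pure right child with 1 sample.
root_decrease = weighted_impurity_decrease(4, [1, 3], [1, 2], [0, 1])
print(round(root_decrease, 4))
```

Summing this quantity over every node that splits on a given feature should give that feature's unnormalized importance.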

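Since each feature is used only once in your case, each feature's importance must equal the equation above, evaluated at the node that splits on that feature.

For X[2]: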

feature_importance = (4/4) * (0.375 - (0.75 * 0.444)) = 0.042

For X[1]:

feature_importance = (3/4) * (0.444 - (2/3 * 0.5)) = 0.083

For X[0]:

feature_importance = (2/4) * (0.5) = 0.25
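The hand calculation above can be checked against a fitted tree. The sketch below walks the arrays exposed on `clf.tree_` (`children_left`, `children_right`, `feature`, `impurity`, `weighted_n_node_samples`) and applies the same equation at every split node. The variable name `manual` and the loop structure are my own, so treat this as one way to reproduce the numbers rather than scikit-learn's actual implementation.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = [[1, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 0]]
y = [1, 0, 1, 1]

clf = DecisionTreeClassifier().fit(X, y)
tree = clf.tree_

n_total = tree.weighted_n_node_samples[0]  # samples at the root, i.e. N
manual = np.zeros(len(X[0]))               # one accumulator per feature

for node in range(tree.node_count):
    left = tree.children_left[node]
    right = tree.children_right[node]
    if left == -1:  # leaf node: no split, so no contribution
        continue
    n_t = tree.weighted_n_node_samples[node]
    n_t_l = tree.weighted_n_node_samples[left]
    n_t_r = tree.weighted_n_node_samples[right]
    decrease = (n_t / n_total) * (
        tree.impurity[node]
        - (n_t_r / n_t) * tree.impurity[right]
        - (n_t_l / n_t) * tree.impurity[left]
    )
    # Credit the decrease to the feature this node splits on
    manual[tree.feature[node]] += decrease

print(manual)                                              # manual, unnormalized
print(tree.compute_feature_importances(normalize=False))   # library, unnormalized
print(manual / manual.sum())                               # manual, normalized
print(clf.feature_importances_)                            # library, normalized
```

Because no `random_state` is set and the three features are symmetric in this data, which feature receives which value may differ between runs. The set of values (0.25, 0.083, 0.042) should stay the same.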

Source: a similar question on Stack Overflow about python - scikit learn - feature importance calculation in decision trees: https://stackoverflow.com/questions/49170296/
