我想获得有关 Spark MLlib 的决策树生成模型的每个节点的更多详细信息。我可以使用 API 获得的最接近的是 print(model.toDebugString())
,它返回类似这样的内容(取自 PySpark 文档)
DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.0)
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
我如何修改 MLlib 源代码以获得例如每个节点的杂质和深度? (如果需要,我如何在 PySpark 中调用新的 Scala 函数?)
最佳答案
我将尝试通过描述我如何使用 PySpark 2.4.3 来完成@mostOfMajority 的回答。
根节点
给定一个训练有素的决策树模型,这是获取其根节点的方法:
def _get_root_node(tree: DecisionTreeClassificationModel):
return tree._call_java('rootNode')
杂质
我们可以通过从根节点向下遍历树来得到杂质。它的pre-order transversal可以这样做:
def get_impurities(tree: DecisionTreeClassificationModel) -> List[float]:
def recur(node):
if node.numDescendants() == 0:
return []
ni = node.impurity()
return (
recur(node.leftChild()) + [ni] + recur(node.rightChild())
)
return recur(_get_root_node(tree))
例子
In [1]: print(tree.toDebugString)
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f90ba6dbb0fe) of depth 3 with 7 nodes
If (feature 0 <= 6.5)
If (feature 0 <= 3.5)
Predict: 1.0
Else (feature 0 > 3.5)
If (feature 0 <= 5.0)
Predict: 0.0
Else (feature 0 > 5.0)
Predict: 1.0
Else (feature 0 > 6.5)
Predict: 0.0
In [2]: cat.get_impurities(tree)
Out[2]: [0.4444444444444444, 0.5, 0.5]
关于python - 如何获取Spark决策树模型的节点信息,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49805284/