python - 如何获取Spark决策树模型的节点信息

标签 python scala pyspark apache-spark-mllib apache-spark-ml

我想获得有关 Spark MLlib 的决策树生成模型的每个节点的更多详细信息。我可以使用 API 获得的最接近的是 print(model.toDebugString()),它返回类似这样的内容(取自 PySpark 文档)

  DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0

我如何修改 MLlib 源代码以获得例如每个节点的杂质和深度? (如果需要,我如何在 PySpark 中调用新的 Scala 函数?)

最佳答案

我将尝试通过描述我如何使用 PySpark 2.4.3 来完成@mostOfMajority 的回答。

根节点

给定一个训练有素的决策树模型,这是获取其根节点的方法:

def _get_root_node(tree: DecisionTreeClassificationModel):
    return tree._call_java('rootNode')

杂质

我们可以通过从根节点向下遍历树来得到杂质。它的pre-order transversal可以这样做:

def get_impurities(tree: DecisionTreeClassificationModel) -> List[float]:
    def recur(node):
        if node.numDescendants() == 0:
            return []
        ni = node.impurity()
        return (
            recur(node.leftChild()) + [ni] + recur(node.rightChild())
        )
    return recur(_get_root_node(tree))

例子

In [1]: print(tree.toDebugString)
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f90ba6dbb0fe) of depth 3 with 7 nodes
  If (feature 0 <= 6.5)
   If (feature 0 <= 3.5)
    Predict: 1.0
   Else (feature 0 > 3.5)
    If (feature 0 <= 5.0)
     Predict: 0.0
    Else (feature 0 > 5.0)
     Predict: 1.0
  Else (feature 0 > 6.5)
   Predict: 0.0


In [2]: cat.get_impurities(tree)
Out[2]: [0.4444444444444444, 0.5, 0.5]

关于python - 如何获取Spark决策树模型的节点信息,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49805284/

相关文章:

python - 从 stat().st_mtime 到 datetime?

python - Azure 数字孪生 Python API - 创建了模型孪生,但无法检索它

python - 使用带有Python和OpenCV的Raspberry Pi和Android IP摄像机进行对象检测

apache-spark - 为什么 dropna() 不起作用?

python - 如何通过pyspark读取gz压缩文件

python - 在 openSUSE 上安装最新的 Python

scala - RedisClient失败策略

algorithm - 将前 10% 的未排序 RDD 作为 Spark 中的另一个 RDD 返回的有效方法?

scala - 在Spark Dataframe中,如何在两个数据框中获取重复记录和不同记录?

python - 在 PySpark 中使用 'window' 函数按天分组时出现问题