python - 提取每个终端节点的路径

标签 python json nested xgboost

我有一个 python 嵌套字典结构,如下所示。 这是一个小示例,但我有更大的示例,可以有不同级别的嵌套。

从中,我需要提取一个列表:

  1. 每个终端“叶”节点一条记录
  2. 表示通向该节点的逻辑路径的字符串、列表或对象
    • (例如“nodeid_3:X < 0.500007 和 X < 0.279907”)

这个周末我花了大部分时间试图让一些东西发挥作用,并且意识到我在递归方面有多糟糕。

# Extract json string
json_string = booster.get_dump(with_stats=True, dump_format='json')[0]

# Convert to python dictionary
json.loads(json_string)

{u'children': [{u'children': [
    {u'cover': 2291, u'leaf': -0.0611795, u'nodeid': 3},
    {u'cover': 1779, u'leaf': -0.00965727, u'nodeid': 4}],
   u'cover': 4070,
   u'depth': 1,
   u'gain': 265.811,
   u'missing': 3,
   u'no': 4,
   u'nodeid': 1,
   u'split': u'X',
   u'split_condition': 0.279907,
   u'yes': 3},
  {u'cover': 3930, u'leaf': -0.0611946, u'nodeid': 2}],
 u'cover': 8000,
 u'depth': 0,
 u'gain': 101.245,
 u'missing': 1,
 u'no': 2,
 u'nodeid': 0,
 u'split': u'X',
 u'split_condition': 0.500007,
 u'yes': 1}

最佳答案

您的数据结构是递归的。如果一个节点有一个键,那么我们可以认为它不是终端。

要分析数据,您需要一个递归函数来跟踪祖先(路径)。

我会这样实现:

def find_path(obj, path=None):
    path = path or []
    if 'children' in obj:
        child_obj = {k: v for k, v in obj.items()
                     if k in ['nodeid', 'split_condition']}
        child_path = path + [child_obj]
        children = obj['children']
        for child in children:
            find_path(child, child_path)
    else:
        pprint.pprint((obj, path))

如果您调用:

find_path(data)

您得到 3 个结果:

({'cover': 2291, 'leaf': -0.0611795, 'nodeid': 3},
 [{'nodeid': 0, 'split_condition': 0.500007},
  {'nodeid': 1, 'split_condition': 0.279907}])
({'cover': 1779, 'leaf': -0.00965727, 'nodeid': 4},
 [{'nodeid': 0, 'split_condition': 0.500007},
  {'nodeid': 1, 'split_condition': 0.279907}])
({'cover': 3930, 'leaf': -0.0611946, 'nodeid': 2},
 [{'nodeid': 0, 'split_condition': 0.500007}])

当然,您可以用 yield 替换对 pprint.pprint() 的调用,从而将此函数转换为生成器:

def iter_path(obj, path=None):
    path = path or []
    if 'children' in obj:
        child_obj = {k: v for k, v in obj.items()
                     if k in ['nodeid', 'split_condition']}
        child_path = path + [child_obj]
        children = obj['children']
        for child in children:
            # for o, p in iteration_path(child, child_path):
            #     yield o, p
            yield from iter_path(child, child_path)
    else:
        yield obj, path

请注意递归调用中 yield from 的用法。您可以像下面这样使用这个生成器:

for obj, path in iter_path(data):
    pprint.pprint((obj, path))

您还可以更改 child_obj 对象的构建方式以满足您的需求。

要保持对象的顺序:反转 if 条件:if 'children' not in obj: ...

关于python - 提取每个终端节点的路径,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44620114/

相关文章:

python - 如何根据另一个列表中的值重新排列列表中的值?

javascript 数组序列化

java - 编辑 .json 文件 : adding lines in certain parts of the file?

Java:石头、剪刀、布即将完成,但出现 fatal error

java - 如何替换验证过程中的大量嵌套if?

php - 使用准备好的语句在 mysqli 中嵌套 SELECT

python - Panda 0.22 dataframe.drop 比它应该多的行

python - scrapy 无法提交表单

python - 获取整数和 float 的属性

python - 如何序列化一个数组?