r - 从 R 中的 rpart 对象中提取拆分值

标签 r decision-tree rpart

我找不到 rpart 对象中节点的拆分值(或其他数据)。我在 summary(sample_model) 中看到它,但不在列表或数据框中

一些示例数据

foo.df <- structure(list(type = c("fudai", "fudai", "fudai", "fudai", "fudai", 
                              "fudai", "fudai", "tozama", "fudai", "fudai", "tozama", "tozama", 
                              "fudai", "tozama", "fudai", "fudai", "tozama", "fudai", "fudai", 
                              "tozama", "fudai", "fudai", "fudai", "tozama", "fudai", "fudai", 
                              "tozama", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama", 
                              "fudai", "fudai", "fudai", "fudai", "fudai", "fudai", "tozama", 
                              "tozama", "fudai", "tozama", "tozama", "tozama", "tozama", "fudai", 
                              "fudai", "tozama", "tozama"), distance = c(12.5366985071383, 
                                                                         272.697138147139, 40.4780423740381, 109.806349869662, 147.781805212839, 
                                                                         89.4280438527415, 49.1425850803745, 555.414271440522, 119.365138867582, 
                                                                         182.902536555383, 310.019126513348, 277.122207392514, 214.510428881317, 
                                                                         235.111617874157, 104.494518693549, 50.7561853895564, 343.308898045237, 
                                                                         151.796857505073, 36.0391449169937, 30.8214406651022, 343.294467363406, 
                                                                         135.841501028422, 154.798119311647, 317.739208576563, 3.33794280697559, 
                                                                         98.9182898110913, 422.915369767251, 194.957988642709, 87.6548263591412, 
                                                                         187.571370158631, 236.292608259126, 17.915709270268, 193.548578374405, 
                                                                         262.190146422316, 21.6219797945323, 121.199009527283, 261.670997612517, 
                                                                         202.2051991431, 125.418459536787, 275.964068539003, 190.112226847932, 
                                                                         20.1753302760961, 488.80323504215, 579.25515722891, 233.500797034697, 
                                                                         207.588349435329, 183.770003408524, 168.739293254246, 313.140075747773, 
                                                                         131.69228390613), age = c(1756, 1711, 1712, 1746, 1868, 1866, 
                                                                                                   1682, 1617, 1771, 1764, 1672, 1636, 1864, 1704, 1762, 1868, 1694, 
                                                                                                   1749, 1703, 1616, 1691, 1702, 1723, 1683, 1742, 1691, 1623, 1721, 
                                                                                                   1704, 1745, 1749, 1723, 1639, 1661, 1843, 1845, 1669, 1698, 1698, 
                                                                                                   1664, 1868, 1633, 1783, 1642, 1615, 1648, 1734, 1758, 1725, 1635
                                                                         )), class = c("tbl_df", "tbl", "data.frame"), row.names = c(NA, 
                                                                                                                                     -50L))

还有一个基本模型

library("rpart")
sample_model <- rpart(formula = type ~ ., 
                  data = sample_data, 
                  method = "class",
                  control = rpart.control(xval = 50, minbucket = 5, cp = 0.05),
                  parms = list(split = "gini"))

rpart 文档说 sample_model$frame 中应该有一个名为“splits”的列,但它不在那里。引用:“拆分,每个节点的左右拆分标签的两列矩阵”https://www.rdocumentation.org/packages/rpart/versions/4.1-15/topics/rpart.object

sample_model$framesample_model 中的那些列在哪里?但是,我在

中看到了我想要的数据
summary(sample_model)

这是怎么回事?

最佳答案

文档确实过时了。这是通过检查 summary.rpart 函数派生的提取器:


rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    ix <- cumsum(c(1L, nn))
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = F))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    left <- character(nrow(splits))
    side <- splits[, 2L]
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
                   paste("<", format(signif(splits[i, 4L], digits)))
                 else if (side[i] == 1L)
                   paste(">=", format(signif(splits[i, 4L], digits)))
                 else {
                   catside <- fit$csplit[splits[i, 4L], 1:side[i]]
                   paste(c("L", "-", "R")[catside], collapse = "", sep = "")
                 }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = F))
  }
}

过滤 type == "main" 以仅获取主要拆分:

> fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart_splits(fit)
      var      type node ix    left count ncat    improve index       adj
1   Start      main    1  1  >= 8.5    81    1 6.76232996   8.5 0.0000000
2  Number   primary    1  1   < 5.5    81   -1 2.86679493   5.5 0.0000000
3     Age   primary    1  1  < 39.5    81   -1 2.25021152  39.5 0.0000000
4  Number surrogate    1  1   < 6.5     0   -1 0.80246914   6.5 0.1578947
5   Start      main    2  2 >= 14.5    62    1 1.02052786  14.5 0.0000000
6     Age   primary    2  2    < 55    62   -1 0.68486352  55.0 0.0000000
7  Number   primary    2  2   < 4.5    62   -1 0.29753321   4.5 0.0000000
8  Number surrogate    2  2   < 3.5     0   -1 0.64516129   3.5 0.2413793
9     Age surrogate    2  2    < 16     0   -1 0.59677419  16.0 0.1379310
10    Age      main    5  4    < 55    33   -1 1.24675325  55.0 0.0000000
11  Start   primary    5  4 >= 12.5    33    1 0.28877005  12.5 0.0000000
12 Number   primary    5  4  >= 3.5    33    1 0.17532468   3.5 0.0000000
13  Start surrogate    5  4   < 9.5     0   -1 0.75757576   9.5 0.3333333
14 Number surrogate    5  4  >= 5.5     0    1 0.69696970   5.5 0.1666667
15    Age      main   11  6  >= 111    21    1 1.71428571 111.0 0.0000000
16  Start   primary   11  6 >= 12.5    21    1 0.79365079  12.5 0.0000000
17 Number   primary   11  6  >= 3.5    21    1 0.07142857   3.5 0.0000000

关于r - 从 R 中的 rpart 对象中提取拆分值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56209774/

相关文章:

r - R用唯一的随机数替换NA

r - 警告消息 : "missing values in resampled performance measures" in caret train() using rpart

r - 为R中的rpart/ctree包获取预测数据集的每一行的决策树规则/路径模式

r - 使用插入符号库修剪树返回复杂的树

r - R lubridate 中两个日期之间的严格差异

r - 如何在 Delaunay 三角剖分中设置三角形边的最大长度?

r - 在R时间中找到第二个间隔

Python 复杂对象的决策树分类

algorithm - 优化大量基于二元谓词的规则的计算

machine-learning - scikit、分类列、决策树