r - 如何从 party:::ctree 模型中删除训练数据?

标签 r memory-management classification decision-tree

我创建了几个 ctree 模型(大约 40 到 80 个),我想经常对其进行评估。

一个问题是模型对象非常大(40 个模型需要超过 2.8G 的内存),在我看来,它们存储了训练数据,可能作为 modelname@data 和 modelname@responses,而不仅仅是与预测新数据相关的信息。

大多数其他 R 学习包都具有是否将数据包含在模型对象中的可配置选项,但我在文档中找不到任何提示。我还尝试通过

分配空的 ModelEnv 对象
modelname@data <- new("ModelEnv")

但对相应 RData 文件的大小没有影响。

有人知道 ctree 是否真的存储训练数据以及如何从 ctree 模型中删除与新预测无关的所有数据,以便我可以将其中的许多数据放入内存中?

非常感谢,

斯特凡

<小时/>

感谢您的反馈,这已经非常有帮助了。

我使用dputstr来更深入地观察该对象,发现模型中不包含任何训练数据,但有一个响应 槽,其中似乎有训练标签和行名。无论如何,我注意到每个节点都有每个训练样本的权重向量。检查代码一段时间后,我最终用谷歌搜索了一下,在 party NEWS 日志中发现了以下评论:

         CHANGES IN party VERSION 0.9-13 (2007-07-23)

o   update `mvt.f'

o   improve the memory footprint of RandomForest objects
    substancially (by removing the weights slots from each node).

事实证明,party 包中有一个名为 R_remove_weights 的 C 函数可以删除这些权重,其定义如下:

SEXP R_remove_weights(SEXP subtree, SEXP removestats) {
    C_remove_weights(subtree, LOGICAL(removestats)[0]);
    return(R_NilValue);
}

它也工作得很好:

# cc is my model object

sum(unlist(lapply(slotNames(cc), function (x)  object.size(slot(cc, x)))))
# returns: [1] 2521256
save(cc, file="cc_before.RData")

.Call("R_remove_weights", cc@tree, TRUE, PACKAGE="party")
# returns NULL and removes weights and node statistics

sum(unlist(lapply(slotNames(cc), function (x)  object.size(slot(cc, x)))))
# returns: [1] 1521392
save(cc, file="cc_after.RData")

正如您所看到的,它大大减少了对象大小,从大约 2.5MB 减少到 1.5MB。

但奇怪的是,相应的 RData 文件非常大,而且对它们没有影响:

$ ls -lh cc*
-rw-r--r-- 1 user user 9.6M Aug 24 15:44 cc_after.RData
-rw-r--r-- 1 user user 9.6M Aug 24 15:43 cc_before.RData

解压文件后发现2.​​5MB的对象占用了近100MB的空间:

$ cp cc_before.RData cc_before.gz
$ gunzip cc_before.gz 
$ ls -lh cc_before*
-rw-r--r-- 1 user user  98M Aug 24 15:45 cc_before

有什么想法,什么可能导致这种情况?

最佳答案

我找到了解决当前问题的方法,因此如果有人可能遇到同样的问题,我会写下这个答案。我将描述我的过程,所以可能有点漫无目的,所以请耐心等待。

在没有任何线索的情况下,我考虑了对插槽进行核攻击并删除权重,以使对象尽可能小,并至少节省一些内存,以防找不到修复方法。因此,我删除了 @data@responses 作为开始,没有它们,预测仍然很好,但对 .RData 文件大小没有影响。

我反其道而行之,创建了一个空的 ctree 模型,只需将树插入其中即可:

> library(party)

## create reference predictions for the dataset
> predictions.org <- treeresponse(c1, d)

## save tree object for reference
save(c1, "testSize_c1.RData")

检查原始对象的大小:

$ ls -lh testSize_c1.RData 
-rw-r--r-- 1 user user 9.6M 2011-08-25 14:35 testSize_c1.RData

现在,让我们创建一个空的 CTree 并仅复制树:

## extract the tree only 
> c1Tree <- c1@tree

## create empty tree and plug in the extracted one 
> newCTree <- new("BinaryTree")
> newCTree@tree <- c1Tree

## save tree for reference 
save(newCTree, file="testSize_newCTree.RData")

这个新的树对象现在小得多:

$ ls -lh testSize_newCTree.RData 
-rw-r--r-- 1 user user 108K 2011-08-25 14:35 testSize_newCTree.RData

但是,它不能用于预测:

## predict with the new tree
> predictions.new <- treeresponse(newCTree, d)
Error in object@cond_distr_response(newdata = newdata, ...) : 
  unused argument(s) (newdata = newdata)

我们没有设置@cond_distr_response,这可能会导致错误,因此也复制原始的并再次尝试预测:

## extract cond_distr_response from original tree
> cdr <- c1@cond_distr_response
> newCTree@cond_distr_response <- cdr

## save tree for reference 
save(newCTree, file="testSize_newCTree_with_cdr.RData")

## predict with the new tree
> predictions.new <- treeresponse(newCTree, d)

## check correctness
> identical(predictions.org, predictions.new)
[1] TRUE

这工作得很好,但现在 RData 文件的大小又回到了原来的值:

$ ls -lh testSize_newCTree_with_cdr.RData 
-rw-r--r-- 1 user user 9.6M 2011-08-25 14:37 testSize_newCTree_with_cdr.RData

简单地打印插槽,将其显示为绑定(bind)到环境的函数:

> c1@cond_distr_response
function (newdata = NULL, mincriterion = 0, ...) 
{
    wh <- RET@get_where(newdata = newdata, mincriterion = mincriterion)
    response <- object@responses
    if (any(response@is_censored)) {
        swh <- sort(unique(wh))
        RET <- vector(mode = "list", length = length(wh))
        resp <- response@variables[[1]]
        for (i in 1:length(swh)) {
            w <- weights * (where == swh[i])
            RET[wh == swh[i]] <- list(mysurvfit(resp, weights = w))
        }
        return(RET)
    }
    RET <- .Call("R_getpredictions", tree, wh, PACKAGE = "party")
    return(RET)
}
<environment: 0x44e8090>

因此,最初问题的答案似乎是对象的方法将环境绑定(bind)到它,然后将环境与对象一起保存在相应的 RData 文件中。这也可以解释为什么读取 RData 文件时会加载多个包。

因此,为了摆脱环境,我们无法复制方法,但没有它们我们也无法预测。相当“肮脏”的解决方案是模拟原始方法的功能并直接调用底层C代码。经过深入研究源代码,这确实是可能的。正如上面复制的代码所示,我们需要调用 get_where,它确定输入到达的树的终端节点。然后,我们需要调用 R_getpredictions 来确定每个输入样本的终端节点的响应。棘手的部分是我们需要以正确的输入格式获取数据,因此必须调用 ctree 中包含的数据预处理:

## create a character string of the formula which was used to fit the free
## (there might be a more neat way to do this)
> library(stringr)
> org.formula <- str_c(
                   do.call(str_c, as.list(deparse(c1@data@formula$response[[2]]))),
                   "~", 
                   do.call(str_c, as.list(deparse(c1@data@formula$input[[2]]))))

## call the internal ctree preprocessing 
> data.dpp <- party:::ctreedpp(as.formula(org.formula), d)

## create the data object necessary for the ctree C code
> data.ivf <- party:::initVariableFrame.df(data.dpp@menv@get("input"), 
                                           trafo = ptrafo)

## now call the tree traversal routine, note that it only requires the tree
## extracted from the @tree slot, not the whole object
> nodeID <- .Call("R_get_nodeID", c1Tree, data.ivf, 0, PACKAGE = "party")

## now determine the respective responses
> predictions.syn <- .Call("R_getpredictions", c1Tree, nodeID, PACKAGE = "party")

## check correctness
> identical(predictions.org, predictions.syn)
[1] TRUE

我们现在只需要保存提取的树和公式字符串就可以预测新数据:

> save(c1Tree, org.formula, file="testSize_extractedObjects.RData")

我们可以进一步删除不必要的权重,如上面更新的问题中所述:

> .Call("R_remove_weights", c1Tree, TRUE, PACKAGE="party")
> save(c1Tree, org.formula, file="testSize_extractedObjects__removedWeights.RData")

现在让我们再次看看文件大小:

$ ls -lh testSize_extractedObjects*
-rw-r--r-- 1 user user 109K 2011-08-25 15:31 testSize_extractedObjects.RData
-rw-r--r-- 1 user user  43K 2011-08-25 15:31 testSize_extractedObjects__removedWeights.RData

最后,使用该模型只需要 43K,而不是(压缩后)9.6M。我现在应该能够在 3G 堆空间中容纳任意数量的内存。万岁!

关于r - 如何从 party:::ctree 模型中删除训练数据?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/7149901/

相关文章:

machine-learning - k 最近邻的分类属性的距离度量

node.js - 在 Node.js 上训练分类器(自然 - NLP)以查找意外句子

python-3.x - Wine 质量数据集分析

r - 如何在 data.table 中编写累积计算

ios - 快速图像内存

r - 使用两种方法调用 ggplot() 时出现美学错误

performance - 可分配数组性能

Swift 中的 iOS 内存警告

r - 如何从 r 中的数据框中的 2 列中提取唯一级别

r - 如何使用 MCMCpack 获得差异的后验?