检索插入符号中保留折叠的预测

标签 r cross-validation r-caret ensemble-learning

我想知道如何恢复交叉验证预测。我有兴趣手动构建堆叠模型 ( like here in point 3.2.1 ),我需要模型对每个保留折叠的预测。我附上了一个简短的例子。

# load the library
library(caret)
# load the iris dataset
data(cars)
# define folds
cv_folds <- createFolds(cars$Price, k = 5, list = TRUE)
# define training control
train_control <- trainControl(method="cv", index = cv_folds, savePredictions = 'final')
# fix the parameters of the algorithm
# train the model
model <- caret::train(Price~., data=cars, trControl=train_control, method="gbm", verbose = F)
# looking at predictions
model$pred

# verifying the number of observations
nrow(model$pred[model$pred$Resample == "Fold1",])
nrow(cars)

我想知道在折叠 1-4 上估计模型和在折叠 5 上评估等的预测是什么。查看 model$pred 似乎没有给我我想要的东西需要。

最佳答案

当使用 createFolds 函数创建的折叠在插入符号中执行 CV 时,默认情况下使用训练索引。所以当你这样做的时候:

cv_folds <- createFolds(cars$Price, k = 5, list = TRUE)

你收到火车套装折叠

lengths(cv_folds)
#output
Fold1 Fold2 Fold3 Fold4 Fold5 
  161   160   161   160   162

每个包含 20% 的数据

然后您在 trainControl 中指定了这些折叠:

train_control <- trainControl(method="cv", index = cv_folds, savePredictions = 'final')

trainControl 的帮助下:

index - a list with elements for each resampling iteration. Each list element is a vector of integers corresponding to the rows used for training at that iteration.

indexOut - a list (the same length as index) that dictates which data are held-out for each resample (as integers). If NULL, then the unique set of samples not contained in index is used.

因此,每次模型都建立在 160 行上并在其余行上进行验证。这就是为什么

nrow(model$pred[model$pred$Resample == "Fold1",])

返回 643

你应该做的是:

cv_folds <- createFolds(cars$Price, k = 5, list = TRUE, returnTrain = TRUE)

现在:

lengths(cv_folds)
#output
Fold1 Fold2 Fold3 Fold4 Fold5 
  644   643   642   644   643 

在训练模型之后:

nrow(model$pred[model$pred$Resample == "Fold1",])
#output
160

关于检索插入符号中保留折叠的预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48427517/

相关文章:

c++ - 处理大数据网络文件计算n个最近节点的高效算法

r - 使用插入符号指定交叉验证折叠

R-Caret 相当于 Scikit-Learn 的 DummyClassifier?

r - macOS 和 CentOS 上的并行 Caret 与 doSNOW 集群

R混淆矩阵敏感性和特异性标记

r - 如何计算数据框中值序列的出现次数?

r - 在 rmarkdown ioslides 中包含 block 引用

machine-learning - 训练-测试分离的缺点

string - 如何在 R 中显示带引号的文本?

machine-learning - weka 是否在交叉验证中平衡跨类的训练/测试集?