r - 为模型选择实现 LOOCV(不使用 caret 包)

标签 r statistics data-science cross-validation

我有下面这个数据集,想用 R 找到最佳模型。

数据集(下文代码中称为 `total`,即 `total <- structure(...)`):

structure(list(V1 = c(1.43359910241166, 0.411971077467806, 0.236361845246534, 
-0.289263426819727, -1.23202861459847, -0.738796384986188, 0.200420172968439, 
1.55763841132305, 0.306848974278087, -1.06336757529454, 0.208462982445177, 
-0.161933544137143, -0.529226737265933, 1.06311471300635, 0.154281146875831, 
0.609577869238014, -0.13720552696616, 0.920650581183744, 1.18282854178987, 
-0.792945405521446, -0.609722647650392, -0.21688852962299, -3.06426186175807, 
0.5498848363865), V2 = c(-0.322161064448354, -0.203202315321523, 
-1.37681357972322, -2.09183896169083, -1.73416522569493, -0.167163678879473, 
-0.496644140621754, -0.378640254832213, 1.71897857982319, 0.987886990249993, 
-0.464176577243306, 0.313599912560739, 0.279305189424942, 0.621879051693468, 
-1.35413705469938, -0.904307866112488, 0.563960402008738, 0.942178870082166, 
1.05504527675313, 1.72684177309, 0.487583880103759, 0.366982237506534, 
0.341207392409481, 0.0878011635613361), V3 = c(-3.06259779143185, 
0.113156471002083, -0.596111339640452, -0.0549465535711572, -0.941898864240695, 
-0.653015082018507, -0.169956676284042, -0.35411953808696, 0.713862293279259, 
1.20019049753438, 0.295042002436139, -0.248609439893179, 1.9312167684667, 
0.674670687298312, 0.224140747830105, -0.59349261052001, 0.0558808922143246, 
0.749007982254512, 1.04584894162381, 0.280651184742914, -0.313568542992107, 
-1.54267082673779, 0.397265080266878, 0.850053716467332), V4 = c(-2.72697312636474, 
0.851743869193346, -0.0599187094978506, 0.341978048955579, -0.484015693411596, 
-0.131475393689722, -0.021866557862478, -1.8191792655517, 1.74883985589495, 
-0.446343374015597, 0.0107633789594956, 0.55528371030783, 0.31132242799237, 
0.0710046563366782, 0.701388784100771, 1.56870481640847, -0.841113890934613, 
-0.881987858407386, 1.37693978208629, -0.488560120797117, 0.366895195216852, 
0.0627972059134885, -0.655416452787133, 0.589188711953821), V5 = c(-1.79836984688233, 
0.50295466271361, -1.17227869532777, 0.661412408202374, 0.853890060320874, 
0.349725611664228, -0.308069063888987, -0.433246608902138, -0.178767449882736, 
1.34125510863996, 0.206474174580616, -0.657831069822233, 0.215632332747088, 
0.573672331330443, -0.202823754124207, 0.609758501891791, 0.222044482387977, 
2.56037433110525, -1.29345283990688, 0.174550400877521, -0.174265216769768, 
0.55419775558349, -0.458225457879011, -2.14861215865916), V6 = c(-0.18026818728965, 
-0.480816154309526, -0.50256960223903, -1.31874854057412, -0.896086924318379, 
-1.79382217103909, -1.60213450587948, -0.481119812364401, 0.377075792056211, 
1.34981730088023, 0.0611706096060544, 0.83874651540465, 0.58899516399665, 
1.24066391945654, -1.08080170411743, 0.597620326597847, -1.21365483260366, 
0.230893469563153, -0.576677068566099, 1.31703258659203, 0.35136844419016, 
0.925208426922233, 1.73348977742475, 0.514617170610343), V7 = c(0.692646184527114, 
1.64958468445801, -0.722861261417701, -0.411292490473929, -1.73926867251488, 
0.479847732965793, 0.224291785874008, -0.650661070391403, -0.20779505689401, 
-0.900990363217965, 0.712570690351891, 0.0291624484927884, 0.613871305452367, 
-0.901767959624604, -0.184130922600279, 2.60941994159236, 0.0144701586285878, 
1.00941096184201, -1.07148389565784, -0.439790917550134, -0.786567592396622, 
0.926243735906836, -1.39392614240757, 0.449016715055174), V8 = c(-0.218730876718155, 
0.279536175230915, -0.860839531512879, 1.62382620633742, -0.656202640703168, 
-3.05801703213563, 0.243884147081474, 0.926579301241956, 0.58184138659717, 
-0.0814078168437784, -0.0668035158044736, 0.00153834639170001, 
0.806767895958209, 0.834326360087515, -0.0790896439523125, 0.07028192584928, 
-0.619273530317688, 1.07556660504801, -1.13473924521572, 0.668145147063421, 
0.758090513962191, 0.456430947715887, -1.73160959029873, 0.179898464937389
), V9 = c(2.56974590352874, -0.263155790779132, 0.646658371629822, 
-0.752843366448987, 0.200047856906594, 0.659371008337854, 1.24620285734473, 
0.94634794321528, -1.3304334794271, 1.33090401796431, -0.819840444239054, 
0.272969704571894, -0.486961950780986, 0.169639870524667, -0.451658048721127, 
-1.04537018765646, -1.16107891054576, -1.20995090654021, -0.839823653138378, 
0.62253221198192, 0.622634591405887, -0.547608828939565, 0.786557248787584, 
-1.16488601898254), V10 = c(-2.26412916115509, 0.67348993363598, 
-0.342027192999345, 0.249815496496033, 0.30352488488975, -0.744451635640458, 
1.58487417838063, -1.01570448604582, -0.541105970352036, 1.13647671257197, 
-0.54886598448313, -0.962789161396563, -0.538065955333129, 0.0781727823942247, 
0.0970193660300894, 1.18927210039089, -0.6957686086705, -0.386785336508124, 
-0.35257548033064, 2.31937096293864, -0.549132531058022, -0.0974568592721698, 
1.43853645612397, -0.0316945106071529), V11 = c(-1.86095070927053, 
0.573330283491408, -1.03183858717977, -1.83745190916475, -0.077180684913356, 
-0.94533768863225, -0.641638632478328, 0.154349543995556, 1.89664953662371, 
1.3494700201932, 1.04343452008192, 1.03948878970461, 0.394740150081754, 
1.24869842481551, 0.33270007318232, 0.373677276693529, 0.670774298645023, 
-0.0191045174843475, 0.0901593335518681, -0.813757209813031, 
-0.527741614949631, -1.55637393322463, -0.0817683516977811, 0.225671587747989
), V12 = c(0.235155165117673, 0.0334071835637513, 0.141983465568844, 
0.441692874434554, 0.0707526888389656, 0.332161357520943, 0.0735800395703528, 
-0.281305763416249, 0.16538364649173, -1.15487983901285, 1.56899928098857, 
-0.567750194144175, 0.541218236160627, 1.48159680904495, -0.568523352759803, 
-0.0545712227404042, -2.93340050534491, 0.662421496450859, 1.11729205722267, 
-0.581175560009803, 0.792548304722282, 0.955149345977461, -0.821090667653583, 
-1.65064484659245), V13 = c(-1.97412125867671, 0.44572205242864, 
-0.274712915255066, -1.44692140049933, -1.18035700830368, -0.260286573948736, 
-0.95815595797825, -0.242760674716397, 0.477953228907608, 0.992878959448502, 
0.48518262700317, -0.882424015844636, 2.03856721097186, 0.782640940939034, 
0.00789969362112054, -0.295894328060507, 1.27922468162261, 0.51472928905797, 
0.0447383908218823, 0.165638463053774, -0.263332324321804, -1.15204704327981, 
-0.258342890933598, 1.95418085394235), V14 = c(-0.181993529177506, 
1.39403983793056, -0.152733307069606, -1.52421030170283, -0.924924418962197, 
-0.364387222675804, 1.10283509955152, 0.0727783277608945, -1.77522562543095, 
1.08664918075833, -1.04803884297856, -0.940631906527986, 1.12617755875177, 
1.21705368328955, -0.279102677856877, 0.343713803473868, 1.26542530994074, 
-0.774396836280874, 0.417125600747737, 1.49096714826284, 0.284166748008431, 
-1.53295609357739, 0.105608954195959, -0.407940490431605), V15 = c(-1.46474265513464, 
1.19486941463858, 0.244933071673175, -0.459011700723317, 0.241718140420906, 
0.282959623977014, 0.00585677416957126, -2.03400384857495, 0.537918956631718, 
-1.04030075327707, -0.557219563096931, -0.252427064540924, 0.547956268292219, 
-0.526158422645334, 0.251554548033225, -0.745912076395139, -0.0351666299711204, 
1.15204026955591, 0.842246979246097, 1.52268303136091, -1.90156582122334, 
-0.142035061237368, 0.385224459566802, 1.94858205925399), V16 = c(0.828548104520814, 
0.713189024971904, 0.774573684318552, -0.425568343697551, 0.259608074896051, 
-1.22029633555545, -0.344755278537263, 0.973749897026122, -0.474553098183039, 
0.0257155566445092, -0.476287023663646, 0.974669054546108, -1.77164686907544, 
1.56028342699847, 1.24959541751606, -0.574201649578301, 1.2099741843225, 
-0.0750690376790856, -0.0110241372862062, -0.984530244128971, 
-2.52086075001167, 0.0287667805602271, 0.731343831738835, -0.451224270663529
), V17 = c(-0.681074029216176, -0.0390433509889875, 0.0328512523391066, 
1.12428796011696, 0.176765286103444, -0.222850967042728, 0.988520019729737, 
2.09179105565111, 0.116819106946508, 0.51447781508645, 1.87648378755979, 
-1.08036997332246, -0.418517756914466, 0.291253915397003, -0.355756145391065, 
0.874359244531183, -2.35192438381252, -0.200559130397419, -1.29305021151605, 
-0.216777649470054, -1.43207151780606, -0.392317470556723, 0.447601162558867, 
0.149101980414553), V18 = c(-1.96475300593026, 0.422711683040055, 
-1.12996029903421, -2.33587910613298, 0.179352498545959, -0.600058127770143, 
-1.35077156778998, -0.727365308346169, 1.43052873254504, 1.07048786910024, 
1.15649152054786, 0.702163956193049, 0.599458156020645, 0.489172517239038, 
0.957116387643539, 0.335186798948586, -0.598777825023964, 0.10012893280699, 
0.0822063408722808, 0.393896776121708, 0.968441995451939, -0.625513747288306, 
-0.437871585012806, 0.883606407251895), V19 = c(0.203243289070699, 
0.206783154660488, 0.0730205054389099, 0.151752499129077, 0.339065300597841, 
0.198750153846351, 0.246574181097875, 0.219716854159337, 0.112571755773366, 
0.108437458425644, 0.159923853880819, 0.198217376539615, 1.27794667790059, 
0.0628191359027579, -0.023668700184257, 0.0103470645871769, -4.55192891533295, 
0.0932248108210876, 0.0372915017676821, 0.103290843005291, 0.1485089149749, 
0.167015138770557, 0.258108289841612, 0.198988855325523), V20 = c(-0.6885610185506, 
0.215106818871655, -1.26229703607397, -1.15415874394993, -0.770942786330788, 
-1.07811513531511, -1.34581518035362, 0.296281823344214, -0.525449013409778, 
1.52659228597052, 1.66011376586839, 0.204981756466606, 2.25710524990656, 
0.850893107617607, 0.181598239123184, 0.0790398588000734, -0.0665218787774753, 
0.411298611581292, 0.0839458342094344, -0.122405563089466, -1.6897393933796, 
1.24061257187769, -0.157685318761091, -0.145878855645788), outcome_var = c(-3, 
4, 1, -1, -1, -3, -1, -3, 3, 2, -2, -3, 1, 0, 0, 0, 3, 0, 2, 
2, 1, -3, 1, 0)), class = "data.frame", row.names = c(NA, -24L
)) 

代码:
library(caret)  # trainControl(), train(); method = "leapSeq" also needs leaps

# Leave-one-out cross-validation: each of the n rows is held out exactly once.
train.control <- trainControl(method = "LOOCV")

# Stepwise (sequential replacement) subset selection via leaps, tuning the
# maximum model size nvmax over 1..5. By default caret keeps the size with the
# lowest resampled RMSE. `total` is the data frame shown above.
step.model <- train(outcome_var ~ ., data = total,
                    method = "leapSeq",
                    tuneGrid = data.frame(nvmax = 1:5),
                    trControl = train.control)

# LOOCV RMSE / R-squared / MAE for every candidate nvmax
step.model$results

# Variables selected for each model size up to the chosen nvmax
summary(step.model$finalModel)

结果:
20 Variables  (and intercept)
Forced in Forced out
V1      FALSE      FALSE
V2      FALSE      FALSE
V3      FALSE      FALSE
V4      FALSE      FALSE
V5      FALSE      FALSE
V6      FALSE      FALSE
V7      FALSE      FALSE
V8      FALSE      FALSE
V9      FALSE      FALSE
V10     FALSE      FALSE
V11     FALSE      FALSE
V12     FALSE      FALSE
V13     FALSE      FALSE
V14     FALSE      FALSE
V15     FALSE      FALSE
V16     FALSE      FALSE
V17     FALSE      FALSE
V18     FALSE      FALSE
V19     FALSE      FALSE
V20     FALSE      FALSE
1 subsets of each size up to 3
Selection Algorithm: 'sequential replacement'
         V1  V2  V3  V4  V5  V6  V7  V8  V9  V10 V11 V12 V13 V14 V15 V16 V17 V18 V19 V20
1  ( 1 ) " " " " "*" " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " " "
2  ( 1 ) " " " " "*" " " " " " " " " " " " " " " " " " " " " " " " " " " "*" " " " " " "
3  ( 1 ) " " " " " " " " " " " " " " " " " " "*" " " " " "*" " " " " " " "*" " " " " " "

这给出了我想要的输出。但现在我想不使用 caret 包,自己编写一个 LOOCV 函数,却得不到相同的结果:
# Leave-one-out cross-validated mean squared prediction error of a fitted
# model (written for lm() fits). Each observation is dropped in turn, the
# model is refitted on the remaining rows, and the held-out response is
# predicted.
#
# Note: this scores ONE fixed model specification. It does not rerun any
# variable selection inside each fold, so it cannot reproduce caret's
# "leapSeq" results. Also, caret reports RMSE, i.e. sqrt() of the value
# returned here.
#
# fit: a fitted model that stores its call and model frame (fit$model) and
#      has a predict() method, e.g. the result of lm().
# Returns the mean of the squared leave-one-out prediction errors (MSE).
# Errors if the fit has no stored model frame, dropped rows for NA, or was
# fitted with a `subset` argument (model-frame rows would not line up with
# the rows of the data the refit uses).
loocv <- function(fit) {
  mf <- fit$model
  if (is.null(mf)) {
    stop("loocv() needs the stored model frame; fit with model = TRUE.",
         call. = FALSE)
  }
  cl <- getCall(fit)
  if (!is.null(fit$na.action) || !is.null(cl$subset)) {
    stop("loocv() needs a fit without NA-dropped rows or a `subset` ",
         "argument; clean/subset the data before fitting.", call. = FALSE)
  }
  n <- nrow(mf)
  y <- mf[[1L]]
  # Refit where the original formula was created, so `data = ...` resolves
  # exactly as it did for the original fit (update() would use this frame).
  env <- environment(formula(fit))
  if (is.null(env)) {
    env <- parent.frame()
  }
  errors <- vapply(seq_len(n), function(i) {
    # Insert the subset as a literal logical vector: no free variables left.
    cl$subset <- seq_len(n) != i
    refit <- eval(cl, env)
    y[i] - unname(predict(refit, newdata = mf[i, , drop = FALSE]))
  }, numeric(1))
  mean(errors^2)
}

如何在不使用 caret 包的情况下使用 LOOCV 并找到最佳拟合模型?

最佳答案

对于 LOOCV 这样的交叉验证,应当针对每个测试折(fold)从头构建模型——包括在训练数据上重新进行变量选择。问题中的 loocv 函数只是对同一个固定模型反复重新拟合,并没有在每个折内重新选择变量,因此结果与 caret 不同。经过反复尝试,我认为 caret 在逐步模型选择中使用的是 leaps::regsubsets。

library(leaps)

# Hold-out predictions from leave-one-out cross-validation of stepwise
# (sequential replacement) subset selection, without caret. For every
# held-out row the variable selection itself is rerun on the remaining rows
# before the linear model is refitted, so the held-out row never influences
# which variables are chosen.
#
# data:     data frame holding the response and the candidate predictors.
# response: name of the response column; all other columns are candidates.
# nvmax:    model size (number of predictors) to select.
# Returns a numeric vector of length nrow(data): the LOOCV prediction for
# each row.
loocv_seqrep_pred <- function(data, response, nvmax) {
  predictors <- setdiff(names(data), response)
  n <- nrow(data)
  pred <- numeric(n)
  for (i in seq_len(n)) {
    train <- data[-i, , drop = FALSE]
    sel <- regsubsets(x = as.matrix(train[, predictors, drop = FALSE]),
                      y = train[[response]],
                      nvmax = nvmax, method = "seqrep")
    # Row `nvmax` of $which flags the variables of the size-nvmax model;
    # column 1 is the intercept.
    in_model <- summary(sel)$which[nvmax, -1]
    chosen <- names(in_model)[in_model]
    fit <- lm(reformulate(chosen, response = response), data = train)
    pred[i] <- predict(fit, newdata = data[i, , drop = FALSE])
  }
  pred
}

# Tune the model size over the same grid as caret's tuneGrid and compute the
# error metrics by hand (caret::RMSE / caret::MAE are not needed).
y <- total[["outcome_var"]]
sizes <- 1:5
cv <- data.frame(nvmax = sizes, RMSE = NA_real_, MAE = NA_real_)
for (k in seq_along(sizes)) {
  pred <- loocv_seqrep_pred(total, "outcome_var", nvmax = sizes[k])
  cv$RMSE[k] <- sqrt(mean((pred - y)^2))
  cv$MAE[k] <- mean(abs(pred - y))
}
cv
# For nvmax = 3: RMSE 1.945036, MAE 1.442353 (same as caret)

# Best size = lowest LOOCV RMSE, caret's default selection rule
cv[which.min(cv$RMSE), ]
caret 的结果:
# caret's LOOCV results (excerpt: only the row for the selected nvmax = 3 is
# shown); RMSE and MAE equal the hand-computed values above.
step.model$results
# nvmax     RMSE    Rsquared      MAE
#     3 1.945036 0.238655497 1.442353

关于r - 为模型选择实现 LOOCV(不使用 caret 包),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62217645/

相关文章:

statistics - 可以在 Mathematica 中扩展 PDF、CDF、FindDistributionParameters 等的功能吗?

r - 如何在 R 中对异方差使用 White 校正

python - 根据条件从其他列中的最小日期推算日期值

r - 使用数据框中的缺失值创建 ts 时间序列

r - 使用替代语法方法在 R 中定义函数

r - 是否可以向 RTextTools 包提供自定义停用词列表?

java - 循环和统计。未打印正确的值

python - Pandas 简单的并行/多进程计算

python - 如何判断哪个 Keras 模型更好?

r - 图表上的非字母字符箭头标签