r - Caret 如何通过 K 折交叉验证生成 OLS 模型?

标签 r linear-regression cross-validation r-caret

假设我有一些通用数据集,OLS 回归是最佳选择。因此,我生成了一个包含一些一阶项的模型,并决定使用 R 中的 Caret 进行回归系数估计和误差估计。

在插入符号中,这最终是:

k10_cv = trainControl(method="cv", number=10)
ols_model = train(Y ~ X1 + X2 + X3, data = my_data, trControl = k10_cv, method = "lm")

从那里,我可以使用 summary(ols_model) 提取回归信息。并且还可以通过调用 ols_model 获取更多信息.

当我只看ols_model时,RMSE/R-square/MAE 是否通过典型的 k 倍 CV 方法计算?另外,当我在summary(ols_model)中看到的模型时生成后,该模型是在整个数据集上训练的还是每个折叠生成的模型的平均值?

如果不是,为了用方差换取偏差,有没有办法在 Caret 中获取一次一次训练一次的 OLS 模型?

最佳答案

这是您的示例的可重现数据。

library("caret")
my_data <- iris

k10_cv <- trainControl(method="cv", number=10)

set.seed(100)
ols_model <- train(Sepal.Length ~  Sepal.Width + Petal.Length + Petal.Width,
                  data = my_data, trControl = k10_cv, method = "lm")


> ols_model$results
  intercept      RMSE  Rsquared       MAE     RMSESD RsquaredSD      MAESD
1      TRUE 0.3173942 0.8610242 0.2582343 0.03881222 0.04784331 0.02960042

1) 上面的 ols_model$results 基于下面每个不同重采样的平均值:

> (ols_model$resample)
        RMSE  Rsquared       MAE Resample
1  0.3386472 0.8954600 0.2503482   Fold01
2  0.3154519 0.8831588 0.2815940   Fold02
3  0.3167943 0.8904550 0.2441537   Fold03
4  0.2644717 0.9085548 0.2145686   Fold04
5  0.3769947 0.8269794 0.3070733   Fold05
6  0.3720051 0.7792611 0.2746565   Fold06
7  0.3258501 0.8095141 0.2647466   Fold07
8  0.2962375 0.8530810 0.2731445   Fold08
9  0.3059100 0.8351535 0.2611982   Fold09
10 0.2615792 0.9286246 0.2108592   Fold10

> mean(ols_model$resample$RMSE)==ols_model$results$RMSE
[1] TRUE

2) 模型在整个训练集上进行训练。您可以使用 lm 或指定 method = "none" 来检查这一点> 对于 trainControl

 coef(lm(Sepal.Length ~  Sepal.Width + Petal.Length + Petal.Width, data = my_data))
 (Intercept)  Sepal.Width Petal.Length  Petal.Width 
   1.8559975    0.6508372    0.7091320   -0.5564827 

ols_model$finalModel 相同。

关于r - Caret 如何通过 K 折交叉验证生成 OLS 模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52460078/

相关文章:

scala - Spark - 从 CSV 文件创建(标签、特征)对的 RDD

python - 如何在pytorch中使用SGD成功训练简单的线性回归模型?

r - R中Tune函数中 'dispersion'的含义

r - 可以使用经过验证的模型来预测整个数据集吗?

python - 使用 RandomizedSearchCV 在 sklearn 中进行超参数调整需要花费大量时间

r - 将父表的子集分配给 R 中的对象

python - 独立 R 脚本加载依赖项的性能

r - 如何在时间序列预测图中的 x 轴显示日期而不是周期?

r - 以编程方式将列名传递给 data.table

r - geom_smooth() 中的 "prediction from a rank-deficient fit may be misleading"