`glmnet` 的岭回归给出的系数与我通过 "textbook definition"计算的系数不同?

标签 r machine-learning regression linear-regression glmnet

我正在使用 glmnet R 包运行 Ridge 回归。我注意到我从 glmnet::glmnet 函数获得的系数与我通过定义计算系数(使用相同的 lambda 值)获得的系数不同。有人可以解释一下为什么吗?

数据(包括:响应 Y 和设计矩阵 X)被缩放。

library(MASS)
library(glmnet)

# Data dimensions
p.tmp <- 100
n.tmp <- 100

# Data objects
set.seed(1)
X <- scale(mvrnorm(n.tmp, mu = rep(0, p.tmp), Sigma = diag(p.tmp)))
beta <- rep(0, p.tmp)
beta[sample(1:p.tmp, 10, replace = FALSE)] <- 10
Y.true <- X %*% beta
Y <- scale(Y.true + matrix(rnorm(n.tmp))) # Y.true + Gaussian noise

# Run glmnet 
ridge.fit.cv <- cv.glmnet(X, Y, alpha = 0)
ridge.fit.lambda <- ridge.fit.cv$lambda.1se

# Extract coefficient values for lambda.1se (without intercept)
ridge.coef <- (coef(ridge.fit.cv, s = ridge.fit.lambda))[2:(p.tmp+1)]

# Get coefficients "by definition"
ridge.coef.DEF <- solve(t(X) %*% X + ridge.fit.lambda * diag(p.tmp)) %*% t(X) %*% Y

# Plot estimates
plot(ridge.coef, type = "l", ylim = range(c(ridge.coef, ridge.coef.DEF)),
     main = "black: Ridge `glmnet`\nred: Ridge by definition")
lines(ridge.coef.DEF, col = "red")

enter image description here

最佳答案

如果您阅读 ?glmnet ,你会看到高斯响应的惩罚目标函数是:

1/2 * RSS / nobs + lambda * penalty

万一岭罚1/2 * ||beta_j||_2^2被使用,我们有

1/2 * RSS / nobs + 1/2 * lambda * ||beta_j||_2^2

正比于

RSS + lambda * nobs * ||beta_j||_2^2

这与我们通常在课本上看到的有关岭回归的内容不同:

RSS + lambda * ||beta_j||_2^2

你写的公式:

##solve(t(X) %*% X + ridge.fit.lambda * diag(p.tmp)) %*% t(X) %*% Y
drop(solve(crossprod(X) + diag(ridge.fit.lambda, p.tmp), crossprod(X, Y)))

为教科书结果;对于 glmnet我们应该期待:

##solve(t(X) %*% X + n.tmp * ridge.fit.lambda * diag(p.tmp)) %*% t(X) %*% Y
drop(solve(crossprod(X) + diag(n.tmp * ridge.fit.lambda, p.tmp), crossprod(X, Y)))

因此,教科书使用惩罚最小二乘法,但是glmnet使用惩罚均方误差

请注意,我没有将您的原始代码用于 t() , "%*%"solve(A) %*% b ;使用 crossprodsolve(A, b)更有效率!请参阅最后的跟进部分。


现在让我们做一个新的比较:

library(MASS)
library(glmnet)

# Data dimensions
p.tmp <- 100
n.tmp <- 100

# Data objects
set.seed(1)
X <- scale(mvrnorm(n.tmp, mu = rep(0, p.tmp), Sigma = diag(p.tmp)))
beta <- rep(0, p.tmp)
beta[sample(1:p.tmp, 10, replace = FALSE)] <- 10
Y.true <- X %*% beta
Y <- scale(Y.true + matrix(rnorm(n.tmp)))

# Run glmnet 
ridge.fit.cv <- cv.glmnet(X, Y, alpha = 0, intercept = FALSE)
ridge.fit.lambda <- ridge.fit.cv$lambda.1se

# Extract coefficient values for lambda.1se (without intercept)
ridge.coef <- (coef(ridge.fit.cv, s = ridge.fit.lambda))[-1]

# Get coefficients "by definition"
ridge.coef.DEF <- drop(solve(crossprod(X) + diag(n.tmp * ridge.fit.lambda, p.tmp), crossprod(X, Y)))

# Plot estimates
plot(ridge.coef, type = "l", ylim = range(c(ridge.coef, ridge.coef.DEF)),
     main = "black: Ridge `glmnet`\nred: Ridge by definition")
lines(ridge.coef.DEF, col = "red")

enter image description here

注意我设置了intercept = FALSE当我调用 cv.glmnet (或 glmnet )。这比它在实践中的影响具有更多的概念意义。从概念上讲,我们的教科书计算没有截距,所以我们想在使用 glmnet 时去掉截距。 .但实际上,因为你的 XY被标准化,截距的理论估计值为 0。即使有 intercepte = TRUE ( glment 默认值),您可以检查拦截的估计值是 ~e-17 (数值为 0),因此其他系数的估计不会受到显着影响。另一个答案只是显示这个。


跟进

As for the using crossprod and solve(A, b) - interesting! Do you by chance have any reference to simulation comparison for that?

t(X) %*% Y将首先进行转置 X1 <- t(X) , 然后做 X1 %*% Y , 而 crossprod(X, Y)不会做转置。 "%*%" DGEMM 的包装器对于案例 op(A) = A, op(B) = B , 而 crossprodop(A) = A', op(B) = B 的包装器.同样tcrossprod对于 op(A) = A, op(B) = B' .

crossprod(X) 的主要用途用于 t(X) %*% X ;同样的 tcrossprod(X)对于 X %*% t(X) , 在这种情况下 DSYRK 而不是 DGEMM叫做。您可以阅读Why the built-in lm function is so slow in R?第一部分出于原因和基准。

请注意,如果 X不是方阵,crossprod(X)tcrossprod(X)由于它们涉及不同数量的浮点运算,因此它们的速度并不相同,您可以阅读 Any faster R function than “tcrossprod” for symmetric dense matrix multiplication?边注

关于 solvel(A, b)solve(A) %*% b ,请阅读How to compute diag(X %% solve(A) %% t(X)) efficiently without taking matrix inverse?第一节

关于 `glmnet` 的岭回归给出的系数与我通过 "textbook definition"计算的系数不同?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48845108/

相关文章:

r - 将 .maf 文件保存为表

r - devtools::install_github() - 忽略 SSL 证书验证失败

tensorflow - 如何在 Tensorflow 检测模型上使用 Lucid Interpretability 工具?

python - 我们是否需要 GPU 系统来训练深度学习模型?

r - 在R中的多列上执行lm()和segmented()

r - 逻辑回归替代解释

R:如何找到数据帧两行元素的交集?

r - 带有 HTML 的 Shiny 、ggvis 和 add_tooltip

c# - 使用ID3算法、Accord.Net框架进行预测

r - 贝塔回归库克距离