r - 带有 R 包 nnet 的神经网络 : rubbish prediction due to overfitting?

标签 r machine-learning neural-network nnet

试图弄清楚我是否遇到 R 问题或一般神经网络问题。

假设我有这些数据:

set.seed(123)
n = 1e3
x = rnorm(n)
y = 1 + 3*sin(x/2) + 15*cos(pi*x) + rnorm(n = length(x))
df = data.frame(y,x)
df$train = sample(c(TRUE, FALSE), length(y), replace=TRUE, prob=c(0.7,0.3))
df_train = subset(df, train = TRUE)
df_test = subset(df, train = FALSE)

然后你训练神经网络,它在坚持中看起来不错:

library(nnet)
nn = nnet(y~x, data = df_train, size = 60, linout=TRUE) 
yhat_nn = predict(nn, newdata = df_test)
plot(df_test$x,df_test$y)
points(df_test$x, yhat_nn, col = 'blue')

enter image description here

好吧,所以我想,让我们生成新数据,然后使用经过训练的网络进行预测。但预测离我们还很遥远:

x2 = rnorm(n)
y2 = 1 + 3*sin(x2/2) + 15*cos(pi*x2) + rnorm(n = length(x2))
df2 = data.frame(y2,x2)
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')

enter image description here

这是因为我对训练集过度拟合了吗?我认为通过将原始数据拆分为测试训练可以避免过度拟合。

最佳答案

致命问题是您的新数据框df2没有正确的变量名称。因此,predict.nnet 无法找到正确的值。

names(df)
#[1] "y"     "x"     "train"

names(df2)
#[1] "y2"     "x2"

构建用于预测的数据框时要小心。

## the right way
df2 <- data.frame(y = y2, x = x2)

## and it solves the mystery
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')

prediction on df2


另一个小问题是您对子集的使用。应该是

## not train = TRUE or train = FALSE
df_train <- subset(df, train == TRUE) ## or simply subset(df, train)
df_test <- subset(df, train == FALSE) ## or simply subset(df, !train)

这有有趣的效果:

nrow(subset(df, train == TRUE))
#[1] 718

nrow(subset(df, train = TRUE))  ## oops!!
#[1] 1000

完整的 R session

set.seed(123)
n = 1e3
x = rnorm(n)
y = 1 + 3*sin(x/2) + 15*cos(pi*x) + rnorm(n = length(x))
df = data.frame(y,x)
df$train = sample(c(TRUE, FALSE), length(y), replace=TRUE, prob=c(0.7,0.3))
df_train = subset(df, train == TRUE)  ## fixed
df_test = subset(df, train == FALSE)  ## fixed
library(nnet)
nn = nnet(y~x, data = df_train, size = 60, linout=TRUE) 
yhat_nn = predict(nn, newdata = df_test)
plot(df_test$x,df_test$y)
points(df_test$x, yhat_nn, col = 'blue')
x2 = rnorm(n)
y2 = 1 + 3*sin(x2/2) + 15*cos(pi*x2) + rnorm(n = length(x2))
df2 = data.frame(y = y2, x = x2)  ## fixed
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')

关于r - 带有 R 包 nnet 的神经网络 : rubbish prediction due to overfitting?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72748078/

相关文章:

statistics - 模拟/统计/预测研究生学位?

python - 同一算法 scikit 神经网络多次运行的不同预测

python - 如何提高前馈神经网络的准确性?

用 mutate 替换部分字符串(在管道中)

r - ggplot中的条形图

r - 以最少的输入将函数应用于 R 中的最后结果

java - weka 培训和 java 培训的不同结果

r - 使用 Fread 读取带有双引号和逗号的字段时出错

machine-learning - 解释决策树回归的 Graphviz 输出

machine-learning - 如何使用 to_categorical 将 [[4,7,10],[10,20,30]] 转换为一种热编码