R:使用 regr.svm 任务在 mlr 中使用新因子水平进行预测

标签 r mlr

我正在使用 mlr 包从 SVM 进行预测。如果我的验证集包含训练数据中不存在的因子水平,则无论我在创建 SVM 学习器时如何设置 fix.factors.prediction,预测都会失败。

处理这个问题的正确方法是什么?使用 e1071::svm() 将返回新因子水平的响应,但如何使用 mlr 方法执行相同操作?

示例

library(mlr)
library(dplyr)

set.seed(575)
data(iris)

# Split data
train_set <- sample_frac(iris, 4/5)
valid_set <- setdiff(iris, train_set)

# Remove all "setosa" values from the training set
train_set[train_set$Species == "setosa", "Species"] <- 
  sample(c("virginica", "versicolor"), 
         sum(train_set$Species == "setosa"), replace = TRUE)    
# Fit model
iris_task <- makeRegrTask(data = train_set, target = "Petal.Width")

svm_lrn <- makeLearner("regr.svm", fix.factors.prediction = TRUE)

svm_mod <- train(svm_lrn, iris_task)

# Predict on new factor levels
predict(svm_mod, newdata = valid_set)

Error in (function (..., row.names = NULL, check.rows = FALSE, check.names = TRUE, : arguments imply differing number of rows: 29, 20

使用 makeLearner("regr.svm", fix.factors.prediction = FALSE) 时,我在调用 predict 时收到以下错误:

Error in scale.default(newdata[, object$scaled, drop = FALSE], center = object$x.scale$"scaled:center", : length of 'center' must equal the number of columns of 'x'

有效的事情

当对训练集中的因子水平进行子集化时,我可以生成预测:

predict(svm_mod, newdata = valid_set %>% 
          filter(Species %in% train_set$Species))

使用不同的学习器时不会出现错误:

nnet_lrn <- makeLearner("regr.nnet", fix.factors.prediction = TRUE)
nnet_mod <- train(nnet_lrn, iris_task)
predict(nnet_mod, newdata = valid_set)

或者直接从包中使用相同的学习器时:

e1071_mod <- 
  e1071::svm(Petal.Width ~ Sepal.Length + Sepal.Width +
               Petal.Length + Species, train_set)
predict(e1071_mod, newdata = valid_set)

session 信息

R version 3.4.4 (2018-03-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 14.04.6 LTS

Matrix products: default
BLAS: /usr/lib/libblas/libblas.so.3.0
LAPACK: /usr/lib/lapack/liblapack.so.3.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] dplyr_0.8.0.1     mlr_2.14.0.9000   ParamHelpers_1.12

loaded via a namespace (and not attached):
 [1] parallelMap_1.4    Rcpp_1.0.1         pillar_1.4.1      
 [4] compiler_3.4.4     class_7.3-14       tools_3.4.4       
 [7] tibble_2.1.3       gtable_0.3.0       checkmate_1.9.3   
[10] lattice_0.20-38    pkgconfig_2.0.2    rlang_0.3.99.9003 
[13] Matrix_1.2-14      fastmatch_1.1-0    rstudioapi_0.8    
[16] yaml_2.2.0         parallel_3.4.4     e1071_1.7-1       
[19] nnet_7.3-12        grid_3.4.4         tidyselect_0.2.5  
[22] glue_1.3.1         data.table_1.12.2  R6_2.4.0          
[25] XML_3.98-1.20      survival_2.41-3    ggplot2_3.2.0.9000
[28] purrr_0.3.2        magrittr_1.5       backports_1.1.4   
[31] scales_1.0.0.9000  BBmisc_1.11        splines_3.4.4     
[34] assertthat_0.2.1   colorspace_1.3-2   stringi_1.4.3     
[37] lazyeval_0.2.2     munsell_0.5.0      crayon_1.3.4 

最佳答案

好吧,这有点具有挑战性。前面的一些事情:

  • e1071::svm() 无法处理 newdata 中缺失的因子水平 ( Error in predict.svm: test data does not match model )
  • 您的示例的手动执行仅会运行,因为您没有删除 train_data 中未使用的因子级别
  • 参数fix.factor.predictions没有做它应该做的事情。我在 this branch 中发布了临时修复。 这个修复非常脏,只是一个概念证明。我可能会把它清理干净。

手动执行无效的证明:

library(mlr)
#> Loading required package: ParamHelpers
#> Registered S3 methods overwritten by 'ggplot2':
#>   method         from 
#>   [.quosures     rlang
#>   c.quosures     rlang
#>   print.quosures rlang
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

set.seed(575)
data(iris)

# Split data
train_set <- sample_frac(iris, 4 / 5)
valid_set <- setdiff(iris, train_set)

# Remove all "setosa" values from the training set
train_set[train_set$Species == "setosa", "Species"] <-
  sample(c("virginica", "versicolor"),
    sum(train_set$Species == "setosa"), replace = TRUE)

# this is important
train_set = droplevels(train_set)

e1071_mod <- e1071::svm(Petal.Width ~ Sepal.Length + Sepal.Width +
  Petal.Length + Species, train_set)
predict(e1071_mod, newdata = valid_set)
#> Error in scale.default(newdata[, object$scaled, drop = FALSE], center = object$x.scale$"scaled:center", : length of 'center' must equal the number of columns of 'x'

reprex package 于 2019-06-13 创建(v0.3.0)

使用 mlr 中提供的修复的工作示例:

remotes::install_github("mlr-org/mlr@fix-factors")
#> Downloading GitHub repo mlr-org/mlr@fix-factors
library(mlr)
#> Loading required package: ParamHelpers
#> Registered S3 methods overwritten by 'ggplot2':
#>   method         from 
#>   [.quosures     rlang
#>   c.quosures     rlang
#>   print.quosures rlang
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

set.seed(575)
data(iris)

# Split data
train_set <- sample_frac(iris, 4 / 5)
valid_set <- setdiff(iris, train_set)

# Remove all "setosa" values from the training set
train_set[train_set$Species == "setosa", "Species"] <-
  sample(c("virginica", "versicolor"),
    sum(train_set$Species == "setosa"), replace = TRUE)

# this is important
train_set = droplevels(train_set)

# Fit model
iris_task <- makeRegrTask(data = train_set, target = "Petal.Width")

svm_lrn <- makeLearner("regr.svm", fix.factors.prediction = TRUE)

svm_mod <- train(svm_lrn, iris_task)

# Predict on new factor levels
predict(svm_mod, newdata = valid_set)
#> Prediction: 30 observations
#> predict.type: response
#> threshold: 
#> time: 0.00
#>   truth  response
#> 1   0.3 0.2457751
#> 2   0.1 0.2730398
#> 3   0.2 0.2717464
#> 4   0.1 0.2717748
#> 5   0.1 0.2651599
#> 6   0.4 0.2582568
#> ... (#rows: 30, #cols: 2)

reprex package 于 2019-06-13 创建(v0.3.0)

关于R:使用 regr.svm 任务在 mlr 中使用新因子水平进行预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56553479/

相关文章:

r - 多列移动平均 - 分组数据

r - R 中使用 roll apply 的滚动回归

r - 为向量中的每个值分配一个负号(在 R 中)?

r - mlr - 集成模型

r - 如果特定学习者在特定任务上失败,如何使基准函数不会失败?

r - 加载 Tidycensus 数据时出现问题

r - R中的fread会将一个大的.csv文件作为具有一行的数据框导入

r - 使用 MLR 包调整 rpart 中的参数?

r - makeClassif 与 MLR - ID 列从任务中排除

在 Ubuntu 18.04 上的 mlr R 包中运行随机搜索需要太长时间