r - How to extract the mlr3 tuned graph step by step?

Tags: r resampling mlr3

My code is as follows:

library(mlr3verse)
library(mlr3pipelines)
library(mlr3filters)
library(paradox)
filter_importance = mlr_pipeops$get(
  "filter",
  filter = FilterImportance$new(learner = lrn("classif.ranger", importance = "impurity")),
  param_vals = list(filter.frac = 0.7)
)

learner_classif = lrn(
  "classif.ranger",
  predict_type = "prob",
  importance = "impurity",
  num.trees = 500
)
polrn_classif = PipeOpLearner$new(learner_classif)

# create learner graph 
glrn_classif = filter_importance %>>%  polrn_classif
glrn_classif = GraphLearner$new(glrn_classif)
glrn_classif$predict_type = "prob"

# task 

task = tsk("german_credit")

# set search_space (parameter IDs are prefixed with the PipeOp ID "classif.ranger")
ps_classif = ps(
  classif.ranger.num.trees = p_int(lower = 300, upper = 500),
  classif.ranger.sample.fraction = p_dbl(lower = 0.7, upper = 0.8)
)

# auto tuning
at = AutoTuner$new(
  learner = glrn_classif, 
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"), 
  search_space = ps_classif, 
  terminator = trm("evals", n_evals = 3), 
  tuner = tnr("random_search")
)

# outer resampling: 2-fold CV around the AutoTuner (nested resampling)
rr = resample(task, at, rsmp("cv", folds = 2))

After I get the rr object by resampling and training the at learner, how can I extract what each of these steps did?

For example:

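  • Given the results from the at object, how can I re-run them manually?
  • Which samples (train_index, test_index) were used in each step?
  • Which variables were selected in the filter_importance step, and what score did each variable get in this step?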

Many thanks!!!

Best answer

To be able to inspect the models after resampling, it is best to call resample with store_models = TRUE.

Using your example:

library(mlr3verse)

set.seed(1)
rr <- resample(task,
               at,
               rsmp("cv", folds = 2),
               store_models = TRUE)

Once resampling is done, you can access the internals of the resulting object like this:

Get the row IDs in each fold:

rr$resampling$instance
#output
      row_id fold
   1:      5    1
   2:      8    1
   3:      9    1
   4:     12    1
   5:     13    1
  ---            
 996:    989    2
 997:    993    2
 998:    994    2
 999:    995    2
1000:    996    2
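For the second question (which rows were used where), the outer train/test row IDs can also be read directly from the resampling object with train_set() and test_set(), and the inner 3-fold CV that the AutoTuner ran can be read from its archive. This is a sketch assuming a recent mlr3verse: the setup rewrites the question's pipeline compactly with po(), flt() and ps(), and it assumes the AutoTuner default store_benchmark_result = TRUE so that the inner resample results are kept.

```r
library(mlr3verse)  # mlr3, mlr3learners, mlr3pipelines, mlr3filters, mlr3tuning, ...
library(paradox)    # ps(), p_int(), p_dbl()

# setup: the pipeline from the question, written compactly
set.seed(1)
task = tsk("german_credit")
graph = po("filter",
  filter = flt("importance", learner = lrn("classif.ranger", importance = "impurity")),
  filter.frac = 0.7) %>>%
  lrn("classif.ranger", predict_type = "prob", importance = "impurity", num.trees = 500)
glrn = GraphLearner$new(graph)
glrn$predict_type = "prob"
at = AutoTuner$new(
  learner = glrn,
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"),
  search_space = ps(
    classif.ranger.num.trees = p_int(300, 500),
    classif.ranger.sample.fraction = p_dbl(0.7, 0.8)
  ),
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search")
)
rr = resample(task, at, rsmp("cv", folds = 2), store_models = TRUE)

# outer loop: train_index / test_index of each outer CV iteration
outer_idx = lapply(seq_len(rr$resampling$iters), function(i) list(
  train_index = rr$resampling$train_set(i),
  test_index = rr$resampling$test_set(i)
))
str(outer_idx)

# inner loop: the 3-fold CV the AutoTuner ran inside outer iteration 1
# (assumes store_benchmark_result = TRUE, the AutoTuner default)
inner_rs = rr$learners[[1]]$archive$benchmark_result$resample_result(1)$resampling
inner_rs$instance  # row_id / fold, drawn from the outer training rows only
```

With 2 outer folds, the training rows of iteration 1 are exactly the test rows of iteration 2, and every inner row ID comes from the outer training set of the same iteration.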

With these row IDs and the trained AutoTuners, we can generate the predictions manually.

Create a list of test row IDs per fold:

rsample <- split(rr$resampling$instance$row_id,
                 rr$resampling$instance$fold)

Loop over the folds and predict with each fold's trained AutoTuner:

lapply(1:2, function(i){
  x <- rsample[[i]] #get the test row ids
  task_test <- task$clone() #clone the task so we don't change the original task
  task_test$filter(x) #filter on the test row ids
  preds <- rr$learners[[i]]$predict(task_test) #use the trained autotuner and above filtered task
  preds
  }) -> preds_manual

Check that these predictions match the output of resample:

all.equal(preds_manual,
          rr$predictions())
#output
TRUE
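The loop above reuses the models stored in rr. To re-run an outer iteration from scratch instead, one can take the best hyperparameters found by the inner tuning (the learner_param_vals column of tuning_result), set them on a fresh copy of the graph learner, train on that iteration's training rows and predict its test rows. This is a sketch assuming a recent mlr3verse, with the question's pipeline rewritten compactly using po(), flt() and ps(). Because ranger is randomized, the manual AUC should be close to, but not necessarily identical with, the resampling result.

```r
library(mlr3verse)  # mlr3, mlr3learners, mlr3pipelines, mlr3filters, mlr3tuning, ...
library(paradox)    # ps(), p_int(), p_dbl()

# setup: the pipeline from the question, written compactly
set.seed(1)
task = tsk("german_credit")
graph = po("filter",
  filter = flt("importance", learner = lrn("classif.ranger", importance = "impurity")),
  filter.frac = 0.7) %>>%
  lrn("classif.ranger", predict_type = "prob", importance = "impurity", num.trees = 500)
glrn = GraphLearner$new(graph)
glrn$predict_type = "prob"
at = AutoTuner$new(
  learner = glrn,
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"),
  search_space = ps(
    classif.ranger.num.trees = p_int(300, 500),
    classif.ranger.sample.fraction = p_dbl(0.7, 0.8)
  ),
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search")
)
rr = resample(task, at, rsmp("cv", folds = 2), store_models = TRUE)

# re-run outer iteration i by hand
i = 1
at_i = rr$learners[[i]]
# complete hyperparameter list of the best configuration from the inner tuning
best_vals = at_i$tuning_result$learner_param_vals[[1]]

learner = glrn$clone(deep = TRUE)  # fresh, untrained copy of the graph learner
learner$param_set$values = best_vals
learner$train(task, row_ids = rr$resampling$train_set(i))
pred = learner$predict(task, row_ids = rr$resampling$test_set(i))

# compare with the prediction stored by resample() for the same iteration
c(manual = pred$score(msr("classif.auc")),
  resample = rr$predictions()[[i]]$score(msr("classif.auc")))
```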

Get the tuning results:

zz <- rr$learners  # the trained AutoTuners, one per outer fold

lapply(zz, function(x) x$tuning_result)
#output
[[1]]
   classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
1:                      342                      0.7931022          <list[7]>
    x_domain classif.auc
1: <list[2]>   0.7981283

[[2]]
   classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
1:                      407                      0.7964164          <list[7]>
    x_domain classif.auc
1: <list[2]>   0.7706533
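tuning_result only holds the winning configuration. To see every configuration that was evaluated inside each outer iteration, the AutoTuner's archive can be inspected; recent mlr3tuning versions also provide extract_inner_tuning_results() and extract_inner_tuning_archives(), which collect the same information across all outer iterations. This is a sketch assuming a recent mlr3verse and the AutoTuner default store_tuning_instance = TRUE; the setup rewrites the question's pipeline compactly with po(), flt() and ps().

```r
library(mlr3verse)  # mlr3, mlr3learners, mlr3pipelines, mlr3filters, mlr3tuning, ...
library(paradox)    # ps(), p_int(), p_dbl()

# setup: the pipeline from the question, written compactly
set.seed(1)
task = tsk("german_credit")
graph = po("filter",
  filter = flt("importance", learner = lrn("classif.ranger", importance = "impurity")),
  filter.frac = 0.7) %>>%
  lrn("classif.ranger", predict_type = "prob", importance = "impurity", num.trees = 500)
glrn = GraphLearner$new(graph)
glrn$predict_type = "prob"
at = AutoTuner$new(
  learner = glrn,
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"),
  search_space = ps(
    classif.ranger.num.trees = p_int(300, 500),
    classif.ranger.sample.fraction = p_dbl(0.7, 0.8)
  ),
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search")
)
rr = resample(task, at, rsmp("cv", folds = 2), store_models = TRUE)

# all configurations evaluated in the inner tuning, one data.table per outer iteration
inner_archives = lapply(rr$learners, function(x) x$archive$data)
print(inner_archives)

# mlr3tuning helpers that gather the best configurations / all evaluations of every outer iteration
extract_inner_tuning_results(rr)
extract_inner_tuning_archives(rr)
```

With trm("evals", n_evals = 3) and random search proposing one configuration at a time, each inner archive should contain three evaluations.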

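The slot

zz[[1]]$learner$state$model$importance

contains the state of the filter_importance step. Here, learner is the final GraphLearner that the AutoTuner trained on the outer training set with the best hyperparameters, and importance is the ID of the filter PipeOp (it defaults to the ID of the filter).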

In particular,

lapply(zz, function(x) x$learner$state$model$importance$scores)
#output
[[1]]
                 amount                  status                     age 
              27.491369               25.776145               22.021369 
               duration                 purpose          credit_history 
              18.732521               16.251643               14.884843 
    employment_duration                 savings                property 
              11.225678               10.796583                9.078619 
    personal_status_sex       present_residence        installment_rate 
               8.914802                7.875384                7.491573 
                    job          number_credits other_installment_plans 
               6.293323                5.662485                5.345666 
                housing               telephone           other_debtors 
               4.869471                3.742213                3.548856 
          people_liable          foreign_worker 
               2.632163                1.054919 

[[2]]
                 amount                duration                     age 
              26.764389               22.139400               20.749865 
                 status                 purpose     employment_duration 
              20.524764               11.793789               10.962301 
         credit_history        installment_rate                 savings 
              10.416572                9.597835                9.491894 
               property       present_residence                     job 
               9.403157                7.877391                6.760945 
    personal_status_sex                 housing other_installment_plans 
               6.699065                5.811131                5.710761 
              telephone           other_debtors          number_credits 
               4.716322                4.318972                3.974793 
          people_liable          foreign_worker 
               3.196563                0.846520 

contains the importance score of each feature (i.e. the feature ranking), while

lapply(zz, function(x) x$learner$state$model$importance$outtasklayout)
#output
[[1]]
                     id    type
 1:                 age integer
 2:              amount integer
 3:      credit_history  factor
 4:            duration integer
 5: employment_duration  factor
 6:    installment_rate ordered
 7:                 job  factor
 8:      number_credits ordered
 9: personal_status_sex  factor
10:   present_residence ordered
11:            property  factor
12:             purpose  factor
13:             savings  factor
14:              status  factor

[[2]]
                     id    type
 1:                 age integer
 2:              amount integer
 3:      credit_history  factor
 4:            duration integer
 5: employment_duration  factor
 6:             housing  factor
 7:    installment_rate ordered
 8:                 job  factor
 9: personal_status_sex  factor
10:   present_residence ordered
11:            property  factor
12:             purpose  factor
13:             savings  factor
14:              status  factor

contains the features that were kept after the filtering step.
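To see the scores and the selection side by side, the filter state can be turned into one table per outer iteration. With filter.frac = 0.7 and the 20 features of german_credit, the filter keeps the 14 highest-scoring features, which is consistent with the two outputs above. This is a sketch that relies on the scores and outtasklayout fields shown above and assumes a recent mlr3verse; the setup rewrites the question's pipeline compactly with po(), flt() and ps().

```r
library(mlr3verse)  # mlr3, mlr3learners, mlr3pipelines, mlr3filters, mlr3tuning, ...
library(paradox)    # ps(), p_int(), p_dbl()

# setup: the pipeline from the question, written compactly
set.seed(1)
task = tsk("german_credit")
graph = po("filter",
  filter = flt("importance", learner = lrn("classif.ranger", importance = "impurity")),
  filter.frac = 0.7) %>>%
  lrn("classif.ranger", predict_type = "prob", importance = "impurity", num.trees = 500)
glrn = GraphLearner$new(graph)
glrn$predict_type = "prob"
at = AutoTuner$new(
  learner = glrn,
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"),
  search_space = ps(
    classif.ranger.num.trees = p_int(300, 500),
    classif.ranger.sample.fraction = p_dbl(0.7, 0.8)
  ),
  terminator = trm("evals", n_evals = 3),
  tuner = tnr("random_search")
)
rr = resample(task, at, rsmp("cv", folds = 2), store_models = TRUE)

# one table per outer iteration: feature, filter score, and whether it was kept
filter_info = lapply(rr$learners, function(x) {
  # state of the PipeOpFilter; its ID defaults to the filter ID "importance"
  st = x$learner$state$model$importance
  scores = sort(st$scores, decreasing = TRUE)
  data.frame(
    feature = names(scores),
    score = unname(scores),
    kept = names(scores) %in% st$outtasklayout$id  # TRUE if the feature was passed on to ranger
  )
})
filter_info
```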

Regarding "r - How to extract the mlr3 tuned graph step by step?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/67869401/
