r - makeClassif with MLR - ID column excluded from task

Tags: r machine-learning gbm mlr

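I have an ID column in my data. I dropped this column from my trainTask because it is not a feature. However, I would like to link the predicted probabilities back to the actual ID numbers in the data.

The column I want to match on is Init_Acct, which is the ID number in the data.frame.

My code is as follows: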

# Make classif tasks
trainTask <- makeClassifTask(
  data = train.df %>% dplyr::select(-Init_Acct) # Init_Acct is the ID I want to match
  , target = "READMIT_FLAG"
  , positive = "Y"
)
testTask <- makeClassifTask(
  data = test.df %>% dplyr::select(-Init_Acct)
  , target = "READMIT_FLAG"
  , positive = "Y"
)

# Oversample the minority class in both tasks with SMOTE
trainTask <- smote(trainTask, rate = 6)
testTask <- smote(testTask, rate = 6)

# GBM ####
getParamSet('classif.gbm')
gbm.learner <- makeLearner(
  'classif.gbm'
  , predict.type = 'prob'
)
plotLearnerPrediction(gbm.learner, trainTask)

# Tune model
gbm.tune.ctl <- makeTuneControlRandom(maxit = 50L)

# Cross validation
gbm.cv <- makeResampleDesc("CV", iters = 3L)

# Hyper-parameter space (sampled by the random search control defined above)
gbm.par <- makeParamSet(
  makeDiscreteParam('distribution', values = 'bernoulli')
  , makeIntegerParam('n.trees', lower = 10, upper = 1000)
  , makeIntegerParam('interaction.depth', lower = 2, upper = 10)
  , makeIntegerParam('n.minobsinnode', lower = 10, upper = 80)
  , makeNumericParam('shrinkage', lower = 0.01, upper = 1)
)

# Tune Hyper-parameters
parallelMap::parallelStartSocket(
  4
  , level = "mlr.tuneParams"
)
gbm.tune <- tuneParams(
  learner = gbm.learner
  , task = trainTask
  , resampling = gbm.cv
  , measures = acc
  , par.set = gbm.par
  , control = gbm.tune.ctl
)

parallelMap::parallelStop()

# Check CV acc
gbm.tune$y
gbm.tune$x

# Set hyper-parameters
gbm.ps <- setHyperPars(
  learner = gbm.learner
  , par.vals = gbm.tune$x
)

# Train gbm
gbm.train <- train(gbm.ps, testTask)
plotLearningCurve(
  generateLearningCurveData(
    gbm.learner
    , testTask
  )
)

# Predict
gbm.pred <- predict(gbm.train, testTask)
plotResiduals(gbm.pred)

# Create submission file
gbm.submit <- data.frame(
  gbm.pred$data
)
head(gbm.submit, 5)
table(gbm.submit$truth, gbm.submit$response)

# Confusion Matrix
calculateConfusionMatrix(gbm.pred)
calculateROCMeasures(gbm.pred)
conf_mat_f1_func(gbm.pred)

perf_plots_func(Model = gbm.pred)

The data looks like this:

glimpse(train.df)
Observations: 33,031
Variables: 17
$ Init_Acct         <chr> "12345678", "87654321", "81734650", "11223344", "1422...
$ Init_LOS          <dbl> 2, 2, 5, 1, 12, 3, 16, 9, 3, 14, 1, 1, 4, 7, 4, 1, 3,...
$ Init_LACE         <dbl> 2, 7, 7, 9, 8, 8, 11, 10, 8, 10, 5, 4, 8, 8, 4, 5, 3,...
$ READMIT_FLAG      <fct> N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, Y,...
$ Init_Hosp_Pvt     <fct> PRIVATE, HOSPITALIST, HOSPITALIST, HOSPITALIST, PRIVA...
$ Age_at_Init_Admit <dbl> 37, 26, 56, 67, 51, 53, 48, 57, 92, 67, 72, 22, 60, 6...
$ Age_Bucket        <fct> 3, 2, 5, 6, 5, 5, 4, 5, 9, 6, 7, 2, 6, 6, 7, 6, 9, 5,...
$ Gender            <fct> F, M, F, M, M, F, M, F, M, M, M, F, M, F, F, F, F, M,...
$ Init_ROM          <dbl> 1, 1, 3, 4, 2, 3, 1, 3, 4, 1, 1, 1, 1, 1, 2, 1, 2, 4,...
$ Init_SOI          <dbl> 1, 1, 3, 4, 3, 3, 3, 3, 3, 2, 1, 1, 2, 2, 2, 2, 2, 4,...
$ Has_Diabetes      <fct> N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,...
$ reduced_dispo     <fct> AHR, AHR, AHR, ATH, ATW, ATW, ATW, AHR, AHR, ATW, AHR...
$ reduced_hsvc      <fct> SUR, MED, MED, Other, MED, MED, MED, MED, MED, MED, M...
$ reduced_abucket   <fct> 3, 2, 5, 6, 5, 5, 4, 5, Other, 6, 7, 2, 6, 6, 7, 6, O...
$ reduced_spclty    <fct> Other, HOSIM, HOSIM, HOSIM, Other, HOSIM, HOSIM, HOSI...
$ reduced_lihn      <fct> Other, Medical, Pneumonia, Medical, Medical, Medical,...
$ discharge_month   <fct> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...

Output:

glimpse(gbm.submit)
Observations: 23,896
Variables: 5
$ id       <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,...
$ truth    <fct> Y, N, N, N, N, N, N, Y, N, N, N, N, N, N, Y, N, Y, N, N, N, N,...
$ prob.N   <dbl> 0.9150623, 0.7914781, 0.9661108, 0.9198683, 0.8502536, 0.94376...
$ prob.Y   <dbl> 0.08493774, 0.20852192, 0.03388919, 0.08013167, 0.14974644, 0....
$ response <fct> N, N, N, N, N, N, N, Y, N, N, N, N, N, N, N, N, Y, N, N, N, N,...

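Best answer

mlr's predict() preserves row names and also adds an id column to its output that indexes the rows of the original data. You can use either one to associate the predictions with their original sample IDs.

Setup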

library(tidyverse)
library(mlr)

## Add a custom sample ID column
iris2 <- iris %>% mutate(Init_Acct = paste0("Acct",1:n()))
lrn <- makeLearner( "classif.gbm", predict.type="prob" )

Option 1: Use the id column to index the original data

## Drop the custom column as in your original post
task <- makeClassifTask( data=select(iris2, -Init_Acct), target="Species" )
mdl <- train( lrn, task )
pred <- predict( mdl, task )

## Join against the original data by the "id" column
iris2 %>% mutate(id=1:n()) %>% select(Init_Acct, id) %>%
    inner_join( pred$data, by="id" ) %>% select(-id)
#   Init_Acct  truth prob.setosa prob.versicolor prob.virginica response
# 1     Acct1 setosa   0.9998775    1.225043e-04   2.836942e-08   setosa
# 2     Acct2 setosa   0.9999652    3.468690e-05   1.118015e-07   setosa
# 3     Acct3 setosa   0.9999538    4.611200e-05   8.389636e-08   setosa
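
Since the question involves separate training and test data, it may help to see Option 1 with a held-out subset. The following is only a sketch under assumptions: the every-third-row split and the names train.idx, test.idx and out are made up for illustration, and classif.rpart is swapped in for classif.gbm purely to keep the example light. It relies on the id column holding row positions within the task data, so it can index the ID vector directly without a join.

```r
library(mlr)

# Same toy data as above: iris plus a made-up account ID column
iris2 <- iris
iris2$Init_Acct <- paste0("Acct", seq_len(nrow(iris2)))

# Build the task without the ID column
task <- makeClassifTask(data = iris2[, names(iris2) != "Init_Acct"],
                        target = "Species")

# Hypothetical split: every third row is held out for prediction
test.idx  <- seq(3, nrow(iris2), by = 3)
train.idx <- setdiff(seq_len(nrow(iris2)), test.idx)

# classif.rpart is an assumption made to keep the sketch small
lrn  <- makeLearner("classif.rpart", predict.type = "prob")
mdl  <- train(lrn, task, subset = train.idx)
pred <- predict(mdl, task, subset = test.idx)

# id refers to row positions in the task data, so index the ID vector with it
out <- cbind(Init_Acct = iris2$Init_Acct[pred$data$id], pred$data)
head(out, 3)
```

Indexing by id avoids adding a temporary column to the original data, but it only works if the data passed to the task has the same row order as the data holding Init_Acct.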

Option 2: Use row names

## Store the sample names into rownames
task <- makeClassifTask( data=column_to_rownames(iris2, "Init_Acct"),
                         target="Species" )
mdl <- train( lrn, task )
pred <- predict( mdl, task )

## Pull the rownames back out into their own column
pred$data %>% rownames_to_column( "Init_Acct" ) %>% select(-id)
#     Init_Acct      truth  prob.setosa prob.versicolor prob.virginica   response
# 1       Acct1     setosa 9.999266e-01    7.331226e-05   6.889259e-08     setosa
# 2       Acct2     setosa 9.999751e-01    2.462816e-05   3.154618e-07     setosa
# 3       Acct3     setosa 9.999656e-01    3.421543e-05   1.449155e-07     setosa
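
One caveat for the code in the question: smote() adds synthetic minority-class rows, and those rows have no account number, so predictions made on an oversampled test task cannot all be traced back to Init_Acct. A possible workaround, sketched below under assumptions, is to oversample only the training task and predict on the untouched test task. The binary flag column, the split, and the use of classif.rpart instead of classif.gbm are illustrative choices, not part of the original post.

```r
library(mlr)
set.seed(1)

# Toy binary problem derived from iris; 'flag' and 'Init_Acct' are made up
df <- iris
df$flag      <- factor(ifelse(df$Species == "virginica", "Y", "N"))
df$Species   <- NULL
df$Init_Acct <- paste0("Acct", seq_len(nrow(df)))

# Hypothetical train/test split
test.idx <- seq(1, nrow(df), by = 3)
train.df <- df[-test.idx, ]
test.df  <- df[test.idx, ]
rownames(train.df) <- NULL
rownames(test.df)  <- NULL

feats <- setdiff(names(df), "Init_Acct")

# Oversample the training task only
trainTask <- makeClassifTask(data = train.df[, feats], target = "flag",
                             positive = "Y")
trainTask <- smote(trainTask, rate = 2)

# Leave the test task as-is so every row still maps to a real account
testTask <- makeClassifTask(data = test.df[, feats], target = "flag",
                            positive = "Y")

lrn  <- makeLearner("classif.rpart", predict.type = "prob")
mdl  <- train(lrn, trainTask)
pred <- predict(mdl, testTask)

# id indexes the rows of test.df, so attach the account numbers directly
out <- cbind(Init_Acct = test.df$Init_Acct[pred$data$id], pred$data)
head(out, 3)
```

With this arrangement the output has exactly one row per row of test.df, which is what a submission file keyed on account number would need.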

Source: the question "r - makeClassif with MLR - ID column excluded from task" on Stack Overflow: https://stackoverflow.com/questions/56135478/