r - 为决策树的概率结果设置阈值

标签 r decision-tree threshold confusion-matrix

我在进行决策树模型后尝试计算混淆矩阵

# tree model
tree <- rpart(LoanStatus_B ~.,data=train, method='class')
# confusion matrix
pdata <- predict(tree, newdata = test, type = "class")
confusionMatrix(data = pdata, reference = test$LoanStatus_B, positive = "1")

我如何为我的混淆矩阵设置阈值,比如我希望默认概率高于 0.2,这是二元结果。

最佳答案

这里有几点需要注意。首先,确保在进行预测时获得类别概率。使用预测类型 ="class" 你只是得到离散的类,所以你想要的是不可能的。所以你要让它像下面我的那样 "p"

library(rpart)
data(iris)

iris$Y <- ifelse(iris$Species=="setosa",1,0)

# tree model
tree <- rpart(Y ~Sepal.Width,data=iris, method='class')

# predictions
pdata <- as.data.frame(predict(tree, newdata = iris, type = "p"))
head(pdata)

# confusion matrix
table(iris$Y, pdata$`1` > .5)

接下来请注意,这里的 .5 只是一个任意值——您可以将其更改为任何您想要的值。

我看不出有什么理由使用 confusionMatrix 函数,因为可以通过这种方式简单地创建混淆矩阵并允许您实现轻松更改截止值的目标。

话虽如此,如果您确实想为您的混淆矩阵使用 confusionMatrix 函数,那么只需首先根据您的自定义截止值创建一个离散类预测,如下所示:

pdata$my_custom_predicted_class <- ifelse(pdata$`1` > .5, 1, 0)

同样,.5 是您自定义选择的截止值,可以是您想要的任何值。

caret::confusionMatrix(data = pdata$my_custom_predicted_class, 
                  reference = iris$Y, positive = "1")
Confusion Matrix and Statistics

          Reference
Prediction  0  1
         0 94 19
         1  6 31

               Accuracy : 0.8333          
                 95% CI : (0.7639, 0.8891)
    No Information Rate : 0.6667          
    P-Value [Acc > NIR] : 3.661e-06       

                  Kappa : 0.5989          
 Mcnemar's Test P-Value : 0.0164          

            Sensitivity : 0.6200          
            Specificity : 0.9400          
         Pos Pred Value : 0.8378          
         Neg Pred Value : 0.8319          
             Prevalence : 0.3333          
         Detection Rate : 0.2067          
   Detection Prevalence : 0.2467          
      Balanced Accuracy : 0.7800          

       'Positive' Class : 1

关于r - 为决策树的概率结果设置阈值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46042966/

相关文章:

python - 决策树分类器sklearn中节点的不同颜色表示什么?

检查模拟值是否在阈值内

c++ - 使用 OpenCV 改进文本二值化/OCR 预处理

c++ - 将 Mat 的每个像素设置为特定值,如果它低于某个值?

r - 如何使用代码在 R Shiny 中触发 session$onSessionEnded?

r - 使用开源 Shiny 服务器时,我的图标不会显示在我的应用程序的浏览器选项卡上

python - 如何知道使用 Scikit-learn 构建的树的大小(节点数)?

r - 尝试使用 R 中的 RWeka 包应用决策 C4.5 算法时出错

r - 使用 dplyr 合并来自两个数据帧的信息

r - 如何逐行减去一个向量,保持数据帧(df)列的均值来自 df?