R 中多类分类的 ROC 曲线

标签 r machine-learning classification roc

我有一个包含 6 个类别的数据集,我想绘制多类别分类的 ROC 曲线。 Achim Zeileis 给出的第一个答案非常好。

ROC curve in R using rpart package?

但这仅适用于二项式分类。我得到的错误是预测错误,类数不等于2。有人做过多类分类吗?

这是我正在尝试做的事情的一个简单示例。 数据 <- read.csv("colors.csv")

假设 data$cType6 个值(或级别)为(红色、绿色、蓝色、黄色、黑色白色)

有没有办法为这 6 个类别绘制 ROC 曲线?任何超过 2 类的工作示例将不胜感激。

最佳答案

我知道这是一个老问题,但事实上,唯一的答案是使用 Python 编写的,这让我很困扰,因为这个问题专门要求 R 解决方案。

正如您从下面的代码中看到的,我正在使用 pROC::multiclass.roc() 函数。使其工作的唯一要求是预测矩阵的列名称匹配真实类(real_values)。

第一个示例生成随机预测。第二个产生更好的预测。第三个生成完美的预测(即始终将最高概率分配给真实类别。)

library(pROC)
set.seed(42)
head(real_values)
real_values <- matrix( c("class1", "class2", "class3"), nc=1 )

# [,1]    
# [1,] "class1"
# [2,] "class2"
# [3,] "class3"

# Random predictions
random_preds <- matrix(rbeta(3*3,2,2), nc=3)
random_preds <- sweep(random_preds, 1, rowSums(a1), FUN="/")
colnames(random_preds) <- c("class1", "class2", "class3")


head(random_preds)

#       class1    class2    class3
# [1,] 0.3437916 0.6129104 0.4733117
# [2,] 0.6016169 0.4700832 0.9364681
# [3,] 0.6741742 0.8677781 0.4823129

multiclass.roc(real_values, random_preds)
#Multi-class area under the curve: 0.1667



better_preds <- matrix(c(0.75,0.15,0.5,
                         0.15,0.5,0.75,
                         0.15,0.75,0.5), nc=3)
colnames(better_preds) <- c("class1", "class2", "class3")

head(better_preds)

#       class1 class2 class3
# [1,]   0.75   0.15   0.15
# [2,]   0.15   0.50   0.75
# [3,]   0.50   0.75   0.50

multiclass.roc(real_values, better_preds)
#Multi-class area under the curve: 0.6667


perfect_preds <- matrix(c(0.75,0.15,0.5,
                          0.15,0.75,0.5,
                          0.15,0.5,0.75), nc=3)
colnames(perfect_preds) <- c("class1", "class2", "class3")
head(perfect_preds)

multiclass.roc(real_values, perfect_preds)
#Multi-class area under the curve: 1

关于R 中多类分类的 ROC 曲线,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36631054/

相关文章:

database - 复杂的数据转换

r - 如何在 R 中的函数中逆向工作

python - 在 python 中从 R 转换

machine-learning - 无法处理多类和连续的混合

python - Keras 模型的相同输出

r - 在 R 中使用正则表达式从字符串中获取数字

numpy - BallTree现在支持不规则数据的自定义指标吗?

python - Azure ML 中的属性错误 : 'Logger' object has no attribute 'activity_info' during Dataset Registration

python - sklearn中BaggingClassifier默认配置与硬投票的区别

python-3.x - 在 Tensorflow 上训练随机森林