python - 如何从与 scikit-learn 中的 predict_proba 一起使用的 cross_val_predict 获取类标签

标签 python scikit-learn cross-validation

我需要训练 Random Forest classifier使用 3 折交叉验证。对于每个样本,我需要检索它恰好在测试集中时的预测概率。

我使用的是 scikit-learn 版本 0.18.dev0。

这个新版本增加了使用方法的功能cross_val_predict()使用附加参数 method 来定义估计器需要哪种预测。

在我的例子中,我想使用 predict_proba()方法,返回多类场景中每个类的概率。

但是,当我运行该方法时,我得到的结果是预测概率矩阵,其中每一行代表一个样本,每一列代表特定类别的预测概率。

问题是该方法没有指明每一列对应于哪个类。

我需要的值与属性 classes_ 中返回的值相同(在我的例子中使用 RandomForestClassifier)定义为:

classes_ : array of shape = [n_classes] or a list of such arrays The classes labels (single output problem), or a list of arrays of class labels (multi-output problem).

predict_proba() 需要它,因为在其文档中写道:

The order of the classes corresponds to that in the attribute classes_.

一个最小的例子如下:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict

clf = RandomForestClassifier()

X = np.random.randn(10, 10)
y = y = np.array([1] * 4 + [0] * 3 + [2] * 3)

# how to get classes from here?
proba = cross_val_predict(estimator=clf, X=X, y=y, method="predict_proba")

# using the classifier without cross-validation
# it is possible to get the classes in this way:
clf.fit(X, y)
proba = clf.predict_proba(X)
classes = clf.classes_

最佳答案

是的,它们将按排序顺序排列;这是因为 DecisionTreeClassifier(这是 RandomForestClassifier 的默认 base_estimator)uses np.unique to construct the classes_ attribute它返回输入数组的排序唯一值。

关于python - 如何从与 scikit-learn 中的 predict_proba 一起使用的 cross_val_predict 获取类标签,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39256287/

相关文章:

python - Excel python COM 对象的文档?

python - 向使用 PyYaml 生成的 YAML 添加注释

machine-learning - 机器学习训练和测试数据分割方法

python - RandomForestClassifier .fit 在 ec2 上因内存错误而失败,但在本地运行时没有错误

python - 扩展 Cython 类时,__cinit__() 恰好需要 2 个位置参数

python - ValueError : invalid literal for int() with base 10: '' | Django

python - 如何使用 asyncio 和 concurrent.futures.ProcessPoolExecutor 在 Python 中终止长时间运行的计算(CPU 绑定(bind)任务)?

matlab - plsregress - 谁能解释特征的标准化?

machine-learning - 橙色的混淆矩阵

machine-learning - 为什么对于 10 倍交叉验证,Weka 运行学习算法 11 次?