python - 多类分类和概率预测

标签 python machine-learning scikit-learn naivebayes multiclass-classification

import pandas as pd
import numpy
from sklearn import cross_validation
from sklearn.naive_bayes import GaussianNB

fi = "df.csv"
# Open the file for reading and read in data
file_handler = open(fi, "r")
data = pd.read_csv(file_handler, sep=",")
file_handler.close()

# split the data into training and test data
train, test = cross_validation.train_test_split(data,test_size=0.6, random_state=0)
# initialise Gaussian Naive Bayes
naive_b = GaussianNB()


train_features = train.ix[:,0:127]
train_label = train.iloc[:,127]

test_features = test.ix[:,0:127]
test_label = test.iloc[:,127]

naive_b.fit(train_features, train_label)
test_data = pd.concat([test_features, test_label], axis=1)
test_data["p_malw"] = naive_b.predict_proba(test_features)

print "test_data\n",test_data["p_malw"]
print "Accuracy:", naive_b.score(test_features,test_label)

我编写此代码是为了接受来自 128 列的 csv 文件的输入,其中 127 列是特征,第 128 列是类标签。

我想预测样本属于每个类别的概率(有5个类别(1-5))并将其打印在矩阵的for中,并根据预测确定样本的类别。 Predict_proba() 没有给出所需的输出。请提出所需的更改建议。

最佳答案

GaussianNB.predict_proba 返回模型中每个类的样本概率。在您的情况下,它应该返回一个包含五列的结果,其行数与测试数据中的行数相同。您可以使用 naive_b.classes_ 验证哪一列对应于哪个类。因此,不清楚为什么你说这不是所需的输出。也许,您的问题来自于您将预测概率的输出分配给数据框列这一事实。尝试:

pred_prob = naive_b.predict_proba(test_features)

而不是

test_data["p_malw"] = naive_b.predict_proba(test_features)

并使用 pred_prob.shape 验证其形状。第二个维度应为 5。

如果您想要每个样本的预测标签,您可以使用预测方法,然后使用混淆矩阵来查看有多少标签被正确预测。

from sklearn.metrics import confusion_matrix

naive_B.fit(train_features, train_label)

pred_label = naive_B.predict(test_features)

confusion_m = confusion_matrix(test_label, pred_label)
confusion_m

这里有一些有用的读物​​。

sklearn GaussianNB - http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html#sklearn.naive_bayes.GaussianNB.predict_proba

sklearn fusion_matrix - http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html

关于python - 多类分类和概率预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50131032/

相关文章:

python - 机器学习教程中的类型错误,numpy

python - R 稀疏矩阵的内部处理

python - 在 Spyder 中使用 input() 时 Matplotlib 卡住

python - 对列表中的 Python 字典对象进行排序

Python @property 与 @property.getter

python - 开始制作游戏,一切正常,但移动功能停止工作

machine-learning - 识别分类中最弱的特征

machine-learning - 使用 "Bag of Words"方法进行主题检测的朴素贝叶斯

r - 获取和排序 R 上 GBM 对象上使用的数据

python - 在 Win10 机器上更新 scikit-learn 时出现“_remove_dead_weakref”错误