我想用 Matlab 解决分类问题。我有一个由 3 个类和 1900 个样本组成的数据集。每个样本由 10 个特征定义,我有 900 个类别“1”的样本、500 个类别“2”的样本和 500 个类别“3”的样本。
我尝试使用 Matlab 中的标准 patternnet
工具来训练神经网络。我用不同数量的神经元(从 1 到 100)进行了不同的测试,但分类性能总是很差。
所以我查看了混淆矩阵,发现问题在于分类器混淆了类“2”和“3”。接下来我尝试的是创建两个神经网络:
- 第一个神经网络是一个 2 类分类器,具有类“1”和类“23”(类“2”和“3”的并集)。第一个分类对我来说有很好的准确率(大约 90%)
- 第二个神经网络又是一个 2 类分类器,仅将类“2”和“3”的元素作为输入。问题是第二个神经网络的准确率相当差,大约为 55%。
所以我在提高分类精度方面再次遇到一些困难。我想做一些测试,看看是否可以提高准确性。 我的想法是看看每个元素属于特定类别的概率是多少。我想做的是以下其中一项:
- 尝试更改确定样本类别的阈值。例如,如果所有具有 > 70% 概率为类“3”的元素确实是类“3”,但如果概率在 50% 到 70% 之间,则该元素通常是类“2”(I我只是编造数字来尝试解释我想测试的内容)
- 为难以分类的样本创建类别“4”。同样,如果,例如,属于“3”类的概率 > 70% 的元素确实是“3”类,并且我将考虑概率 <70% 的“4”类元素,那么这将起作用。如果这项工作我可以有一些“未知类‘4’”的元素,但分类为“2”或“3”的元素将是正确的,并且具有很高的准确性
因此,首先我想知道是否可以检索每个元素属于特定类的概率,其次,Matlab 中是否有标准方法来实现我想做的两个测试之一。 (当然,如果有人有更好的想法,我很乐意测试它) 抱歉,描述很长,但我希望至少解释了我的问题。
最佳答案
@MeSS83。为了让我提供一个正确的示例(带有代码和所有内容),我必须写出完整的答案。使用 SVM 执行多类分类的最简单方法是使用LibSVM。 LibSVM是一个免费的SVM库(您可以下载here),也可以在Matlab环境中安装和使用。解压该文件,其中有一个 matlab 文件夹,您可以在其中找到安装指南和所有内容。
基本上,您想要做的是一对一的 SVM 方法,即训练 N 个 SVM(其中 N 是类的数量),并且每个 SVM 都经过训练以分离给定的类i 来自所有其他类别(第 i 类将为正类,所有非 i 类将为负类)。假设 TrainingSet
、TrainingLabels
、ValidationSet
、ValidationLabels
是您的数据集(它们的名称相当简单)和 numLabels
是标签的数量(在您的例子中为 3)。
您可以按如下方式训练这些 SVM:
for k=1:numLabels
% k-th class positive, all the other classes are negative
LabelsRecoded(TrainingLabels==k)=1;
LabelsRecoded(TrainingLabels~=k)=-1;
model{k} = svmtrain(LabelsRecoded, TrainingSet, '-c 1 -b 1 -t 0');
end
在此代码中,'-c 1 -b 1 -t 0'
是 SVM 的 LibSVM 参数:c 是调节项(设置为 1),< em>-b 1 表示您还想收集输出概率(也称为决策值),-t 0
表示您正在使用线性内核。更多信息可以在 LibSVM 包内的自述文件中找到。相反,model
是一个元胞数组,其中第 k 个元素包含有关经过训练以将第 k 个类别与所有其他类别分开的 SVM 的结构。
预测阶段具有以下结构:
LabelsRecoded=[]; % get rid of the results stored previously in the training phase
for k=1:numLabels
# same as before, but with validation labels
LabelsRecoded(ValidationLabels==k)=1;
LabelsRecoded(ValidationLabels~=k)=-1;
[~,~,p] = svmpredict(LabelsRecoded, ValidationSet, model{k}, '-b 1');
prob(:,k) = p(:,model{k}.Label==1);
end
在 prob
中,您将有 3 列(3 是类别数),其中包含第 k 个类别为正的概率(注意 model{k}.Label= =1
)。现在您可以根据最大概率值收集预测标签,如下所示:
[~,PredictedLabels] = max(prob,[],2);
现在您已经有了预测标签和验证标签,您可以根据标准公式评估准确性。
关于matlab - matlab中的神经网络分类: get probability of element belonging to i-th class,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35521835/