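I have checked the "Making predictions" WEKA documentation, which contains explicit instructions for making predictions from the command line and the GUI.
I would like to know how to obtain, in my own Java code, prediction values like the ones below, which I got from the GUI using the Agrawal dataset (weka.datagenerators.classifiers.classification.Agrawal):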
inst#, actual, predicted, error, prediction
1, 1:0, 2:1, +, 0.941
2, 1:0, 1:0, , 1
3, 1:0, 1:0, , 1
4, 1:0, 1:0, , 1
5, 1:0, 1:0, , 1
6, 1:0, 1:0, , 1
7, 1:0, 2:1, +, 0.941
8, 2:1, 2:1, , 0.941
9, 2:1, 2:1, , 0.941
10, 2:1, 2:1, , 0.941
1, 1:0, 1:0, , 1
2, 1:0, 1:0, , 1
3, 1:0, 1:0, , 1
I could not replicate this result, even though the documentation says:
Java
If you want to perform the classification within your own code, see the classifying instances section of this article, explaining the Weka API in general.
I followed the link, which says:
Classifying instances
In case you have an unlabeled dataset that you want to classify with your newly trained classifier, you can use the following code snippet. It loads the file /some/where/unlabeled.arff, uses the previously built classifier tree to label the instances, and saves the labeled data as /some/where/labeled.arff.
This is not my case, because I only want the k-fold cross-validation predictions of the model on the current dataset.
Update
predictions
public FastVector predictions()
Returns the predictions that have been collected.
Returns: a reference to the FastVector containing the predictions that have been collected. This should be null if no predictions have been collected.
I found the predictions() method of Evaluation objects and used this code:
    Object[] preds = evaluation.predictions().toArray();
    for (Object pred : preds) {
        System.out.println(pred);
    }
The result is:
    ...
    NOM: 0.0 0.0 1.0 0.9466666666666667 0.05333333333333334
    NOM: 0.0 0.0 1.0 0.8947368421052632 0.10526315789473684
    NOM: 0.0 0.0 1.0 0.9934883720930232 0.0065116279069767444
    NOM: 0.0 0.0 1.0 0.9466666666666667 0.05333333333333334
    NOM: 0.0 0.0 1.0 0.9912575655682583 0.008742434431741762
    NOM: 0.0 0.0 1.0 0.9934883720930232 0.0065116279069767444
    ...
Is this the same as the output above?
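Best answer
After a lot of searching on Google (and because the documentation provides minimal help), I finally found the answer.
I hope this explicit answer helps others in the future.
For sample code, I looked at the question "How to print out the predicted class after cross-validation in WEKA", and I am glad I managed to decode its incomplete answer, parts of which are hard to understand.
Here is my code, which produces output similar to the GUI's: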
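    import weka.classifiers.Evaluation;
    import weka.classifiers.evaluation.output.prediction.PlainText;
    import weka.core.Range;

    // data, j48Model, numberOfFolds and randomNumber (a java.util.Random) are defined elsewhere.
    StringBuffer predictionSB = new StringBuffer();
    Range attributesToShow = null;
    Boolean outputDistributions = Boolean.TRUE; // replaces the deprecated new Boolean(true)
    PlainText predictionOutput = new PlainText();
    predictionOutput.setBuffer(predictionSB); // the predictions are written into this buffer
    predictionOutput.setOutputDistribution(true);
    Evaluation evaluation = new Evaluation(data);
    evaluation.crossValidateModel(j48Model, data, numberOfFolds, randomNumber,
            predictionOutput, attributesToShow, outputDistributions);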
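To help you understand: the StringBuffer has to be wrapped in an AbstractOutput object so that crossValidateModel can recognize it. Passing only a StringBuffer causes a java.lang.ClassCastException similar to the one reported in a related question, while using PlainText without a StringBuffer throws a java.lang.IllegalStateException.
Thanks to ManChon U (Kevin) and their question "How to identify the cross-evaluation result to its corresponding instance in the input data set?", which helped me understand what this means: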
... you just need a single addition argument that is a concrete subclass of weka.classifiers.evaluation.output.prediction.AbstractOutput. weka.classifiers.evaluation.output.prediction.PlainText is probably the one you want to use. (Source)
and
... Try creating a PlainText object, which extends AbstractOutput (called output for example) instance and calling output.setBuffer(forPredictionsPrinting) and passing that in instead of the buffer. (Source)
In practice, these simply mean: create a PlainText object, put a StringBuffer in it, and use it to adjust the output through methods such as setOutputDistribution(boolean). Finally, to get the predictions we want, simply use:
System.out.println(predictionOutput.getBuffer());
where predictionOutput is an object from the AbstractOutput family (PlainText, CSV, XML, etc.).
Additionally, the result of evaluation.predictions() differs from what the WEKA GUI shows. Fortunately, Mark Hall explained this in the question "Print out the predict class after cross-validation":
Evaluation.predictions() returns a FastVector containing either NominalPrediction or NumericPrediction objects from the weka.classifiers.evaluation package. Calling Evaluation.crossValidateModel() with the additional AbstractOutput object results in the evaluation object printing the prediction/distribution information from Nominal/NumericPrediction objects to the StringBuffer in the format that you see in the Explorer or from the command line.
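To make the relationship between the two outputs concrete, here is a minimal sketch in plain Java (no WEKA dependency) that turns one of the "NOM:" lines from the question into a row resembling the GUI listing. It rests on an assumption that should be checked against your WEKA version: that each line is laid out as "NOM: actual predicted weight" followed by the class distribution, with class indices starting at 0. The class NomLineFormatter and the method toGuiRow are hypothetical names for this illustration, not part of WEKA, and the three-decimal formatting only approximates what the GUI prints.

```java
import java.util.Locale;

/** Illustrative helper: converts a "NOM: ..." line into a GUI-style row. */
final class NomLineFormatter {

    /**
     * Assumed line layout: "NOM: actual predicted weight p(class 0) p(class 1) ...".
     * classLabels holds the class attribute's values in index order.
     */
    static String toGuiRow(int instNo, String nomLine, String[] classLabels) {
        String[] f = nomLine.trim().split("\\s+");
        if (f.length < 5 || !f[0].equals("NOM:")) {
            throw new IllegalArgumentException("Not a nominal prediction: " + nomLine);
        }
        int actual = (int) Double.parseDouble(f[1]);
        int predicted = (int) Double.parseDouble(f[2]);
        int numClasses = f.length - 4; // fields after the weight are the distribution
        if (actual < 0 || predicted < 0 || actual >= numClasses || predicted >= numClasses
                || numClasses > classLabels.length) {
            throw new IllegalArgumentException("Class index out of range: " + nomLine);
        }
        // The "prediction" column is taken to be the probability of the predicted class.
        double confidence = Double.parseDouble(f[4 + predicted]);
        String error = (actual == predicted) ? "" : "+";
        // The GUI shows 1-based class indices followed by the label, e.g. "1:0".
        return String.format(Locale.ROOT, "%d, %d:%s, %d:%s, %s, %.3f",
                instNo,
                actual + 1, classLabels[actual],
                predicted + 1, classLabels[predicted],
                error, confidence);
    }

    public static void main(String[] args) {
        String[] labels = {"0", "1"}; // the two class values seen in the question's listing
        System.out.println("inst#, actual, predicted, error, prediction");
        System.out.println(toGuiRow(1,
                "NOM: 0.0 0.0 1.0 0.9466666666666667 0.05333333333333334", labels));
    }
}
```

Under that assumption, the first "NOM:" line from the question becomes "1, 1:0, 1:0, , 0.947", which carries the same information as a GUI row: actual class, predicted class, error flag, and the probability of the predicted class. For anything beyond a quick comparison, letting WEKA format the rows through an AbstractOutput object, as in the answer above, is the more reliable route.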
Regarding "java - Getting risk predictions in WEKA using own Java code", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/21424248/