java - "No output instance format defined" error when trying to classify a new instance with Weka in Java

Tags: java, classification, weka

I am trying to use Weka in my project to classify text documents with a Naive Bayes classifier. I found the following two classes on this site.

The first class, MyFilteredLearner, builds, trains, and evaluates the classifier and saves it to disk; all of that works fine.

The second class, MyFilteredClassifier, loads a single text string from a text file and successfully turns it into an instance. It also restores the classifier from disk. However, it fails to classify the instance in the classify() method and instead returns the exception message "No output instance format defined".

I have spent a long time searching for an answer and have tried installing both the developer and stable versions of Weka, but I still get the same problem.

Does anyone know what is incorrect in the code, or what needs to be added or done differently? The file details and code are below:

The ARFF file used to train the classifier (spam.ARFF):

@relation sms_test

@attribute spamclass {spam,ham}
@attribute text String

@data
ham,'Go until jurong point, crazy.. Available only in bugis n great world la e buffet...Cine there got amore wat...'
etc……………………………………………………………………

The single-line text file containing the new instance (toClassify.txt):

this is spam or not, who knows?

The code for MyFilteredLearner:

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instances;
import weka.core.converters.ArffLoader.ArffReader;
import weka.filters.unsupervised.attribute.StringToWordVector;

public class MyFilteredLearner {
    Instances trainData;
    StringToWordVector filter;
    FilteredClassifier classifier;

    public void loadDataset(String fileName) {
        try {
            BufferedReader reader = new BufferedReader(new FileReader(fileName));
            ArffReader arff = new ArffReader(reader);
            trainData = arff.getData();
            System.out.println("===== Loaded dataset: " + fileName + " =====");
            reader.close();
        }
        catch (IOException e) {
            System.out.println("Problem found when reading: " + fileName);
        }
    }

    public void learn() {
        try {
            trainData.setClassIndex(0);
            classifier = new FilteredClassifier();
            filter = new StringToWordVector();
            filter.setAttributeIndices("last");
            classifier.setFilter(filter);
            classifier.setClassifier(new NaiveBayes());
            classifier.buildClassifier(trainData);
            System.out.println("===== Training on filtered (training) dataset done =====");
        }
        catch (Exception e) {
            System.out.println("Problem found when training");
        }
    }

    public void evaluate() {
        try {
            trainData.setClassIndex(0);
            filter = new StringToWordVector();
            filter.setAttributeIndices("last");
            classifier = new FilteredClassifier();
            classifier.setFilter(filter);
            classifier.setClassifier(new NaiveBayes());
            Evaluation eval = new Evaluation(trainData);
            eval.crossValidateModel(classifier, trainData, 4, new Random(1));
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toClassDetailsString());
            System.out.println("===== Evaluating on filtered (training) dataset done =====");
        }
        catch (Exception e) {
            System.out.println("Problem found when evaluating");
        }
    }

    public void saveModel(String fileName) {
        try {
            ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName));
            out.writeObject(classifier);
            System.out.println("Saved model: " + out.toString());
            out.close();
            System.out.println("===== Saved model: " + fileName + "=====");
        }
        catch (IOException e) {
            System.out.println("Problem found when writing: " + fileName);
        }
    }
}

The code for MyFilteredClassifier:

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;

import weka.classifiers.meta.FilteredClassifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instances;
import weka.filters.unsupervised.attribute.StringToWordVector;

public class MyFilteredClassifier {
    String text;
    Instances instances;
    FilteredClassifier classifier;  
    StringToWordVector filter;

    public void load(String fileName) {
        try {
            BufferedReader reader = new BufferedReader(new FileReader(fileName));
            String line;
            text = "";
            while ((line = reader.readLine()) != null) {
                text = text + " " + line;
            }
            System.out.println("===== Loaded text data: " + fileName + " =====");
            reader.close();
            System.out.println(text);
        }
        catch (IOException e) {
            System.out.println("Problem found when reading: " + fileName);
        }
    }

    public void makeInstance() {
        FastVector fvNominalVal = new FastVector(2);
        fvNominalVal.addElement("spam");
        fvNominalVal.addElement("ham");
        Attribute attribute1 = new Attribute("class", fvNominalVal);
        Attribute attribute2 = new Attribute("text",(FastVector) null);
        FastVector fvWekaAttributes = new FastVector(2);
        fvWekaAttributes.addElement(attribute1);
        fvWekaAttributes.addElement(attribute2);
        instances = new Instances("Test relation", fvWekaAttributes, 1);
        instances.setClassIndex(0);
        DenseInstance instance = new DenseInstance(2);
        instance.setValue(attribute2, text);
        instances.add(instance);
        System.out.println("===== Instance created with reference dataset =====");
        System.out.println(instances);
    }

    public void loadModel(String fileName) {
        try {
            ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName));
            Object tmp = in.readObject();
            classifier = (FilteredClassifier) tmp;
            in.close();
            System.out.println("===== Loaded model: " + fileName + "=====");
        } 
        catch (Exception e) {
        System.out.println("Problem found when reading: " + fileName);
        }
    }

    public void classify() {
        try {
            double pred = classifier.classifyInstance(instances.instance(0));
            System.out.println("===== Classified instance =====");
            System.out.println("Class predicted: " + instances.classAttribute().value((int) pred));
        }
        catch (Exception e) {
            System.out.println("Error: " + e.getMessage());
        }       
    }

    public static void main(String args[]) {
        MyFilteredLearner c = new MyFilteredLearner();
        c.loadDataset("spam.ARFF");
        c.learn();
        c.evaluate();
        c.saveModel("spamClassifier.binary");
        MyFilteredClassifier c1 = new MyFilteredClassifier();
        c1.load("toClassify.txt");
        c1.loadModel("spamClassifier.binary");
        c1.makeInstance();
        c1.classify();
    }

}

Best Answer

It looks like you changed one detail of the code from the blog's GitHub repository, and that is what causes the error:

c.learn();
c.evaluate();

versus

c.evaluate();
c.learn();

The evaluate() method resets the classifier with the following line:

classifier = new FilteredClassifier();

but it does not build the model. The actual evaluation is performed on a copy of the classifier you pass in, so the original classifier (the one stored in your class) remains untrained:

// weka/classifiers/Evaluation.java (method: crossValidateModel)
Classifier copiedClassifier = Classifier.makeCopy(classifier);
copiedClassifier.buildClassifier(train);

So you first build the model, then overwrite it during evaluation, and finally save the uninitialized model. Swap the two calls so that the classifier is trained right before it is saved to file, and it will work.
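For reference, this is how the main() method from the question looks with the two calls swapped (a minimal sketch; only the order of evaluate() and learn() changes, everything else is identical to the code above):

public static void main(String args[]) {
    MyFilteredLearner c = new MyFilteredLearner();
    c.loadDataset("spam.ARFF");
    c.evaluate();                          // cross-validates copies of the classifier; c.classifier stays untrained
    c.learn();                             // builds the model that will actually be saved
    c.saveModel("spamClassifier.binary");  // now writes a trained FilteredClassifier to disk

    MyFilteredClassifier c1 = new MyFilteredClassifier();
    c1.load("toClassify.txt");
    c1.loadModel("spamClassifier.binary");
    c1.makeInstance();
    c1.classify();
}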

Regarding this question, a similar question was found on Stack Overflow: https://stackoverflow.com/questions/27194865/
