Java、weka LibSVM 预测不正确

标签 java machine-learning regression weka libsvm

我在 java 代码中使用 LibSVM 和 weka。我正在尝试进行回归。下面是我的代码,

public static void predict() {

    try {
        DataSource sourcePref1 = new DataSource("train_pref2new.arff");
        Instances trainData = sourcePref1.getDataSet();

        DataSource sourcePref2 = new DataSource("testDatanew.arff");
        Instances testData = sourcePref2.getDataSet();

        if (trainData.classIndex() == -1) {
            trainData.setClassIndex(trainData.numAttributes() - 2);
        }

        if (testData.classIndex() == -1) {
            testData.setClassIndex(testData.numAttributes() - 2);
        }

        LibSVM svm1 = new LibSVM();

        String options = ("-S 3 -K 2 -D 3 -G 1000.0 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.001 -P 0.1");
        String[] optionsArray = options.split(" ");
        svm1.setOptions(optionsArray);

        svm1.buildClassifier(trainData);

        for (int i = 0; i < testData.numInstances(); i++) {

            double pref1 = svm1.classifyInstance(testData.instance(i));                
            System.out.println("predicted value : " + pref1);

        }

    } catch (Exception ex) {
        Logger.getLogger(Test.class.getName()).log(Level.SEVERE, null, ex);
    }
}

但是我从这段代码中获得的预测值与我使用 Weka GUI 获得的预测值不同。

示例: 下面是我针对 java 代码和 weka GUI 给出的单个测试数据。

Java 代码预测值为 1.9064516129032265,而 Weka GUI 的预测值为 10.043。我对 Java 代码和 Weka GUI 使用相同的训练数据集和相同的参数。

我希望你能理解我的问题。有人能告诉我我的代码有什么问题吗?

最佳答案

您使用了错误的算法来执行 SVM 回归。 LibSVM 用于分类。你想要的是SMOreg ,这是一个用于回归的特定 SVM。

下面是一个完整的示例,展示了如何使用 Weka Explorer GUI 和 Java API 来使用 SMOreg。对于数据,我将使用 Weka 发行版附带的 cpu.arff 数据文件。请注意,我将使用此文件进行训练和测试,但理想情况下您将拥有单独的数据集。

使用 Weka Explorer GUI

  1. 打开 WEKA Explorer GUI,单击Preprocess 选项卡,单击Open File,然后打开 cpu.arff 文件应该在你的 Weka 发行版中。在我的系统上,该文件位于 weka-3-8-1/data/cpu.arff 下。资源管理器窗口应如下所示:

Weka Explorer - Choosing the file

  • 单击分类选项卡。它确实应该被称为“预测”,因为您可以在这里进行分类和回归。在Classifier下,单击Choose,然后选择weka --> classifiers --> functions --> SMOreg,如下所示。
  • Weka Explorer - Choosing the regression algorithm

  • 现在构建回归模型并对其进行评估。在测试选项下选择使用训练集,以便我们的训练集也用于测试(正如我上面提到的,这不是理想的方法)。现在按开始,结果应如下所示:
  • Weka Explorer - Results from testing

    记下 RMSE 值 (74.5996)。我们将在 Java 代码实现中重新审视这一点。

    使用 Java API

    下面是一个完整的 Java 程序,它使用 Weka API 来复制之前在 Weka Explorer GUI 中显示的结果。

    import weka.classifiers.functions.SMOreg;
    import weka.classifiers.Evaluation;
    import weka.core.Instance;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils.DataSource;
    
    public class Tester {
    
        /**
         * Builds a regression model using SMOreg, the SVM for regression, and 
         * evaluates it with the Evalution framework.
         */
        public void buildAndEvaluate(String trainingArff, String testArff) throws Exception {
    
            System.out.printf("buildAndEvaluate() called.\n");
    
            // Load the training and test instances.
            Instances trainingInstances = DataSource.read(trainingArff);
            Instances testInstances = DataSource.read(testArff);
    
            // Set the true value to be the last field in each instance.
            trainingInstances.setClassIndex(trainingInstances.numAttributes()-1);
            testInstances.setClassIndex(testInstances.numAttributes()-1);
    
            // Build the SMOregression model.
            SMOreg smo = new SMOreg();
            smo.buildClassifier(trainingInstances);
    
            // Use Weka's evaluation framework.
            Evaluation eval = new Evaluation(trainingInstances);
            eval.evaluateModel(smo, testInstances);
    
            // Print the options that were used in the ML algorithm.
            String[] options = smo.getOptions();
            System.out.printf("Options used:\n");
            for (String option : options) {
                System.out.printf("%s ", option);
            }
            System.out.printf("\n\n");
    
            // Print the algorithm details.
            System.out.printf("Algorithm:\n %s\n", smo.toString());
    
            // Print the evaluation results.
            System.out.printf("%s\n", eval.toSummaryString("\nResults\n=====\n", false));
        }
    
        /**
         * Builds a regression model using SMOreg, the SVM for regression, and 
         * tests each data instance individually to compute RMSE.
         */
        public void buildAndTestEachInstance(String trainingArff, String testArff) throws Exception {
    
            System.out.printf("buildAndTestEachInstance() called.\n");
    
            // Load the training and test instances.
            Instances trainingInstances = DataSource.read(trainingArff);
            Instances testInstances = DataSource.read(testArff);
    
            // Set the true value to be the last field in each instance.
            trainingInstances.setClassIndex(trainingInstances.numAttributes()-1);
            testInstances.setClassIndex(testInstances.numAttributes()-1);
    
            // Build the SMOregression model.
            SMOreg smo = new SMOreg();
            smo.buildClassifier(trainingInstances);
    
            int numTestInstances = testInstances.numInstances();
    
            // This variable accumulates the squared error from each test instance.
            double sumOfSquaredError = 0.0;
    
            // Loop over each test instance.
            for (int i = 0; i < numTestInstances; i++) {
    
                Instance instance = testInstances.instance(i);
    
                double trueValue = instance.value(testInstances.classIndex());
                double predictedValue = smo.classifyInstance(instance);
    
                // Uncomment the next line to see every prediction on the test instances.
                //System.out.printf("true=%10.5f, predicted=%10.5f\n", trueValue, predictedValue);
    
                double error = trueValue - predictedValue;
                sumOfSquaredError += (error * error);
            }
    
            // Print the RMSE results.
            double rmse = Math.sqrt(sumOfSquaredError / numTestInstances);
            System.out.printf("RMSE = %10.5f\n", rmse);
        }
    
        public static void main(String argv[]) throws Exception {
    
            Tester classify = new Tester();
            classify.buildAndEvaluate("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff");
            classify.buildAndTestEachInstance("../weka-3-8-1/data/cpu.arff", "../weka-3-8-1/data/cpu.arff");
        }
    }
    

    我编写了两个函数来训练 SMOreg 模型并通过对训练数据运行预测来评估模型。

    • buildAndEvaluate() 使用 Weka 评估模型 Evaluation 框架运行一套测试以获得完全相同的结果 结果作为资源管理器 GUI。值得注意的是,它产生 RMSE 值。

    • buildAndTestEachInstance() 通过显式评估模型 循环每个测试实例,进行预测,计算 误差,并计算总体 RMSE。请注意,此 RMSE 匹配 来自 buildAndEvaluate() 的那个,它又与那个匹配 从资源管理器 GUI。

    下面是程序的编译和运行结果。

    prompt> javac -cp weka.jar Tester.java
    
    prompt> java -cp .:weka.jar Tester
    
    buildAndEvaluate() called.
    Options used:
    -C 1.0 -N 0 -I weka.classifiers.functions.supportVector.RegSMOImproved -T 0.001 -V -P 1.0E-12 -L 0.001 -W 1 -K weka.classifiers.functions.supportVector.PolyKernel -E 1.0 -C 250007 
    
    Algorithm:
     SMOreg
    
    weights (not support vectors):
     +       0.01   * (normalized) MYCT
     +       0.4321 * (normalized) MMIN
     +       0.1847 * (normalized) MMAX
     +       0.1175 * (normalized) CACH
     +       0.0973 * (normalized) CHMIN
     +       0.0235 * (normalized) CHMAX
     -       0.0168
    
    
    
    Number of kernel evaluations: 21945 (93.081% cached)
    
    Results
    =====
    
    Correlation coefficient                  0.9044
    Mean absolute error                     31.7392
    Root mean squared error                 74.5996
    Relative absolute error                 33.0908 %
    Root relative squared error             46.4953 %
    Total Number of Instances              209     
    
    buildAndTestEachInstance() called.
    RMSE =   74.59964
    

    关于Java、weka LibSVM 预测不正确,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43948295/

    相关文章:

    java - TreeMap 对于某些键不起作用

    machine-learning - 在逻辑回归中混合二元和非二元特征

    r - 使用 mlogit R 函数时出错 : the two indexes don't define unique observations

    python - 多变量高斯过程回归: adaptation of kernels

    java - 将 View 高度设置为屏幕尺寸的百分比(以编程方式)

    java - 用于 Hibernate 映射的自定义类加载器

    java - 如何获取 Thymeleaf 模板中环境变量的值?

    python - 如何使用 python 处理测试数据集中看不见的分类值?

    machine-learning - 使 Harmonic 中的记忆意义加倍重要

    r - 以对数刻度(半)R 绘制数据的置信区间