java - Libsvm java训练测试示例(也是实时的)

标签 java machine-learning computer-vision libsvm

任何人都可以通过提供 libsvm java 示例来帮助我进行培训和测试。我是机器学习的新手,需要相关帮助。 @machine learner 之前提供的示例仅给出一类结果时出错。我不想使用 weka 作为之前帖子中给出的建议。

或者您可以纠正此代码中的错误吗?它总是预测结果中的一类。(我想执行多重分类)。

这个例子由“机器学习者”给出

import java.io.*;
import java.util.*;
import libsvm.*;

public class Test{
    public static void main(String[] args) throws Exception{

        // Preparing the SVM param
        svm_parameter param=new svm_parameter();
        param.svm_type=svm_parameter.C_SVC;
        param.kernel_type=svm_parameter.RBF;
        param.gamma=0.5;
        param.nu=0.5;
        param.cache_size=20000;
        param.C=1;
        param.eps=0.001;
        param.p=0.1;

        HashMap<Integer, HashMap<Integer, Double>> featuresTraining=new HashMap<Integer, HashMap<Integer, Double>>();
        HashMap<Integer, Integer> labelTraining=new HashMap<Integer, Integer>();
        HashMap<Integer, HashMap<Integer, Double>> featuresTesting=new HashMap<Integer, HashMap<Integer, Double>>();

        HashSet<Integer> features=new HashSet<Integer>();

        //Read in training data
        BufferedReader reader=null;
        try{
            reader=new BufferedReader(new FileReader("a1a.train"));
            String line=null;
            int lineNum=0;
            while((line=reader.readLine())!=null){
                featuresTraining.put(lineNum, new HashMap<Integer,Double>());
                String[] tokens=line.split("\\s+");
                int label=Integer.parseInt(tokens[0]);
                labelTraining.put(lineNum, label);
                for(int i=1;i<tokens.length;i++){
                    String[] fields=tokens[i].split(":");
                    int featureId=Integer.parseInt(fields[0]);
                    double featureValue=Double.parseDouble(fields[1]);
                    features.add(featureId);
                    featuresTraining.get(lineNum).put(featureId, featureValue);
                }
            lineNum++;
            }

            reader.close();
        }catch (Exception e){

        }

        //Read in test data
        try{
            reader=new BufferedReader(new FileReader("a1a.t"));
            String line=null;
            int lineNum=0;
            while((line=reader.readLine())!=null){

                featuresTesting.put(lineNum, new HashMap<Integer,Double>());
                String[] tokens=line.split("\\s+");
                for(int i=1; i<tokens.length;i++){
                    String[] fields=tokens[i].split(":");
                    int featureId=Integer.parseInt(fields[0]);
                    double featureValue=Double.parseDouble(fields[1]);
                    featuresTesting.get(lineNum).put(featureId, featureValue);
                }
            lineNum++;
            }
            reader.close();
        }catch (Exception e){

        }

        //Train the SVM model
        svm_problem prob=new svm_problem();
        int numTrainingInstances=featuresTraining.keySet().size();
        prob.l=numTrainingInstances;
        prob.y=new double[prob.l];
        prob.x=new svm_node[prob.l][];

        for(int i=0;i<numTrainingInstances;i++){
            HashMap<Integer,Double> tmp=featuresTraining.get(i);
            prob.x[i]=new svm_node[tmp.keySet().size()];
            int indx=0;
            for(Integer id:tmp.keySet()){
                svm_node node=new svm_node();
                node.index=id;
                node.value=tmp.get(id);
                prob.x[i][indx]=node;
                indx++;
            }

            prob.y[i]=labelTraining.get(i);
        }

        svm_model model=svm.svm_train(prob,param);

        for(Integer testInstance:featuresTesting.keySet()){
            HashMap<Integer, Double> tmp=new HashMap<Integer, Double>();
            int numFeatures=tmp.keySet().size();
            svm_node[] x=new svm_node[numFeatures];
            int featureIndx=0;
            for(Integer feature:tmp.keySet()){
                x[featureIndx]=new svm_node();
                x[featureIndx].index=feature;
                x[featureIndx].value=tmp.get(feature);
                featureIndx++;
            }

            double d=svm.svm_predict(model, x);

            System.out.println(testInstance+"\t"+d);
        }

    }
}

最佳答案

这是因为你的featuresTesting从来没有被使用过,HashMap<Integer, Double> tmp=new HashMap<Integer, Double>();应该是HashMap<Integer, Double> tmp=featuresTesting.get(testInstance);

关于java - Libsvm java训练测试示例(也是实时的),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/20458005/

相关文章:

java - 如何在 javafx 浏览器中操作链接

java - 如何在本地无线网络中的两部手机之间传输数据进行测试?

machine-learning - 在 Scikit 中使用 K Mean 选择特征并恢复特征

machine-learning - PyTorch 模型仅在鸟类靠近相机时才能识别鸟类

python - 如何使用 python 从图像中去除边界边缘噪声?

python - Pytorch 数据生成器,用于从许多 3D 立方体中提取 2D 图像

java - 如何查询 blob?

java - 使用模拟测试时无法注入(inject)application.properties中定义的@value

python - 在 Keras 中使用有状态 LSTM 训练多变量多序列回归问题

apache - 使用 apache mahout 算法的开源应用程序