java - 使用 weka 的文本分类器 : how to correctly train a classifier issue

标签 java weka text-classification categorization

我正在尝试使用 Weka 构建文本分类器,但在一种情况下,类的 distributionForInstance 的概率为 1.0,在所有其他情况下为 0.0,因此 classifyInstance 始终返回与预测相同的类。训练中的某些内容无法正常工作。

ARFF 培训

@relation test1

@attribute tweetmsg    String
@attribute classValues {politica,sport,musicatvcinema,infogeneriche,fattidelgiorno,statopersonale,checkin,conversazione}

@DATA

"Renzi Berlusconi Salvini Bersani",politica
"Allegri insulta la terna arbitrale",sport
"Bravo Garcia",sport

训练方法

public void trainClassifier(final String INPUT_FILENAME) throws Exception
{
    getTrainingDataset(INPUT_FILENAME);

    //trainingInstances consists of feature vector of every input

    for(Instance currentInstance : inputDataset)
    {           
        Instance currentFeatureVector = extractFeature(currentInstance);

        currentFeatureVector.setDataset(trainingInstances);
        trainingInstances.add(currentFeatureVector);                
    }

    classifier = new NaiveBayes();

    try {
        //classifier training code
        classifier.buildClassifier(trainingInstances);

        //storing the trained classifier to a file for future use
        weka.core.SerializationHelper.write("NaiveBayes.model",classifier);
    } catch (Exception ex) {
        System.out.println("Exception in training the classifier."+ex);
    }
}

private Instance extractFeature(Instance inputInstance) throws Exception
{       
    String tweet = inputInstance.stringValue(0);
    StringTokenizer defaultTokenizer = new StringTokenizer(tweet);
    List<String> tokens=new ArrayList<String>();
    while (defaultTokenizer.hasMoreTokens())
    {
        String t= defaultTokenizer.nextToken();
        tokens.add(t);
    }

    Iterator<String> a = tokens.iterator();
    while(a.hasNext())
    {
                String token=(String) a.next();
                String word = token.replaceAll("#","");
                if(featureWords.contains(word))
                {                                              
                    double cont=featureMap.get(featureWords.indexOf(word))+1;
                    featureMap.put(featureWords.indexOf(word),cont);
                }
                else{
                    featureWords.add(word);
                    featureMap.put(featureWords.indexOf(word), 1.0);
                }

    }
    attributeList.clear();
    for(String featureWord : featureWords)
    {
        attributeList.add(new Attribute(featureWord));   
    }
    attributeList.add(new Attribute("Class", classValues));
    int indices[] = new int[featureMap.size()+1];
    double values[] = new double[featureMap.size()+1];
    int i=0;
    for(Map.Entry<Integer,Double> entry : featureMap.entrySet())
    {
        indices[i] = entry.getKey();
        values[i] = entry.getValue();
        i++;
    }
    indices[i] = featureWords.size();
    values[i] = (double)classValues.indexOf(inputInstance.stringValue(1));
    trainingInstances = createInstances("TRAINING_INSTANCES");

    return new SparseInstance(1.0,values,indices,1000000);
}


private void getTrainingDataset(final String INPUT_FILENAME)
{
    try{
        ArffLoader trainingLoader = new ArffLoader();
        trainingLoader.setSource(new File(INPUT_FILENAME));
        inputDataset = trainingLoader.getDataSet();
    }catch(IOException ex)
    {
        System.out.println("Exception in getTrainingDataset Method");
    }
    System.out.println("dataset "+inputDataset.numAttributes());
}

private Instances createInstances(final String INSTANCES_NAME)
{
    //create an Instances object with initial capacity as zero 
    Instances instances = new Instances(INSTANCES_NAME,attributeList,0);
    //sets the class index as the last attribute
    instances.setClassIndex(instances.numAttributes()-1);

    return instances;
}

public static void main(String[] args) throws Exception
{
      Classificatore wekaTutorial = new Classificatore();
      wekaTutorial.trainClassifier("training_set_prova_tent.arff");
      wekaTutorial.testClassifier("testing.arff");
}

public Classificatore()
{
    attributeList = new ArrayList<Attribute>();
    initialize();
}    

private void initialize()
{

    featureWords= new ArrayList<String>(); 

    featureMap = new TreeMap<>();

    classValues= new ArrayList<String>();
    classValues.add("politica");
    classValues.add("sport");
    classValues.add("musicatvcinema");
    classValues.add("infogeneriche");
    classValues.add("fattidelgiorno");
    classValues.add("statopersonale");
    classValues.add("checkin");
    classValues.add("conversazione");
}

测试方法

public void testClassifier(final String INPUT_FILENAME) throws Exception
{
    getTrainingDataset(INPUT_FILENAME);

    //trainingInstances consists of feature vector of every input
    Instances testingInstances = createInstances("TESTING_INSTANCES");

    for(Instance currentInstance : inputDataset)
    {

        //extractFeature method returns the feature vector for the current input
        Instance currentFeatureVector = extractFeature(currentInstance);
        //Make the currentFeatureVector to be added to the trainingInstances
        currentFeatureVector.setDataset(testingInstances);
        testingInstances.add(currentFeatureVector);

    }


    try {
        //Classifier deserialization
        classifier = (Classifier) weka.core.SerializationHelper.read("NaiveBayes.model");

        //classifier testing code
        for(Instance testInstance : testingInstances)
        {

            double score = classifier.classifyInstance(testInstance);
            double[] vv= classifier.distributionForInstance(testInstance);
            for(int k=0;k<vv.length;k++){
            System.out.println("distribution "+vv[k]); //this are the probabilities of the classes and as result i get 1.0 in one and 0.0 in all the others
            }
            System.out.println(testingInstances.attribute("Class").value((int)score));
        }
    } catch (Exception ex) {
        System.out.println("Exception in testing the classifier."+ex);
    }
}

我想创建一个短信文本分类器,这段代码基于本教程http://preciselyconcise.com/apis_and_installations/training_a_weka_classifier_in_java.php 。问题在于,分类器对testing.arff 中的几乎每条消息都预测了错误的类别,因为类别的概率不正确。 Training_set_prova_tent.arff 每个类别具有相同数量的消息。 我下面的示例使用 featureWords.dat 并将 1.0 与该单词关联(如果该单词出现在消息中),而不是我想使用 Training_set_prova_tent 中存在的单词以及测试中存在的单词创建自己的字典,并将出现次数与每个单词相关联。

附注 我知道这正是我可以使用过滤器 StringToWordVector 做的事情,但我还没有找到任何示例来解释如何使用此过滤器与两个文件:一个用于训练集,一个用于测试集。所以改编我找到的代码似乎更容易。

非常感谢

最佳答案

看来您更改了 website you referenced 中的代码在一些关键点上,但不是以一种好的方式。我会尝试起草您想要做的事情以及我发现的错误。

您(可能)想要在 extractFeature 中执行的操作是

  • 将每条推文拆分为单词(标记化)
  • 计算这些单词出现的次数
  • 创建一个代表这些字数和类别的特征向量

您在该方法中忽略的是

  1. 您从未重置您的featureMap。线路

    Map<Integer,Double> featureMap = new TreeMap<>();
    

    最初位于extractFeatures开头,但您将其移至initialize。这意味着您始终会累加字数,但永远不会重置它们。对于每条新推文,您的字数统计还包括之前所有推文的字数统计。我确信这不是您想要的。

  2. 您没有使用想要作为特征的单词来初始化 featureWords。是的,您创建了一个空列表,但您用每条推文迭代地填充它。原始代码在 initialize 方法中初始化了一次,之后就再也没有改变过。这有两个问题:

    • 每条新推文都会添加新特征(单词),因此您的特征向量会随着每条推文而增长。这不会是一个大问题(SparseInstance),但这意味着
    • 您的 class 属性始终位于另一个位置。这两行适用于原始代码,因为 featureWords.size() 基本上是一个常量,但在您的代码中,类标签将位于索引 5、然后是 8、然后是 12,依此类推,但每个实例必须相同。
    indices[i] = featureWords.size();
    values[i] = (double) classValues.indexOf(inputInstance.stringValue(1));
    
  3. 这也体现在以下事实中:您为每条新推文构建一个新的 attributeList,而不是在 initialize 中仅构建一次,这由于已经解释过的原因而很糟糕。

可能还有更多的东西,但是 - 事实上 - 你的代码是相当不可修复的。您想要的比您的版本更接近您修改的教程源代码。

此外,您应该查看 StringToWordVector因为这似乎正是您想要做的:

Converts String attributes into a set of attributes representing word occurrence (depending on the tokenizer) information from the text contained in the strings. The set of words (attributes) is determined by the first batch filtered (typically training data).

关于java - 使用 weka 的文本分类器 : how to correctly train a classifier issue,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/27712040/

相关文章:

java - 获取 Observable 的最新值并立即发出

java - Weka 创建实例对象报错

machine-learning - 解释以下混淆矩阵的方法

python - MultinomialNB - 理论与实践

python - 用于文本分类的预训练模型

weka - 为什么weka中的分类模型将所有实例预测为一个类?

java - 让 Spring Boot 重新创建测试数据库

具有 SSL 客户端身份验证的 Java RMI : obtain client x509Certificate

java - 如何保存Java中检测到的峰值位置?

machine-learning - Weka:10 倍 CV 中每次折叠的结果