java - Reloaded Spark model does not seem to work

Tags: java apache-spark apache-spark-sql apache-spark-mllib

I am training a model from a CSV file and saving it. That first step works fine. After saving the model, I try to load it and apply it to new data, but it does not work.

What is the problem?

Training Java file:

SparkConf sconf = new SparkConf().setMaster("local[*]").setAppName("Test")
        .set("spark.sql.warehouse.dir", "D:/Temp/wh");
SparkSession spark = SparkSession.builder().appName("Java Spark").config(sconf).getOrCreate();

JavaRDD<Cobj> cRDD = spark.read().textFile("file:///C:/Temp/classifications1.csv").javaRDD()
        .map(new Function<String, Cobj>() {
            @Override
            public Cobj call(String line) throws Exception {
                String[] parts = line.split(",");
                Cobj c = new Cobj();
                c.setClassName(parts[1].trim());
                c.setProductName(parts[0].trim());
                return c;
            }
        });

Dataset<Row> mainDataset = spark.createDataFrame(cRDD, Cobj.class);

// StringIndexer: map class names to numeric labels
StringIndexer classIndexer = new StringIndexer()
        .setHandleInvalid("skip")
        .setInputCol("className")
        .setOutputCol("label");
StringIndexerModel classIndexerModel = classIndexer.fit(mainDataset);

// Tokenizer: split product names into words
Tokenizer tokenizer = new Tokenizer()
        .setInputCol("productName")
        .setOutputCol("words");

// HashingTF: hash words into a term-frequency feature vector
HashingTF hashingTF = new HashingTF()
        .setInputCol(tokenizer.getOutputCol())
        .setOutputCol("features");

DecisionTreeClassifier decisionClassifier = new DecisionTreeClassifier()
        .setLabelCol("label")
        .setFeaturesCol("features");

Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[] {classIndexer, tokenizer, hashingTF, decisionClassifier});

Dataset<Row>[] splits = mainDataset.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> train = splits[0];
Dataset<Row> test = splits[1];

PipelineModel pipelineModel = pipeline.fit(train);

Dataset<Row> result = pipelineModel.transform(test);
pipelineModel.write().overwrite().save(savePath + "DecisionTreeClassificationModel");

// Convert numeric predictions back to class names
IndexToString labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("PredictedClassName")
        .setLabels(classIndexerModel.labels());
result = labelConverter.transform(result);
result.show(num, false);

Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
        .setMetricName("accuracy");
System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));

Output:

+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
|className                 |productName                                  |label|words                                                 |features                                                                                         |rawPrediction        |probability          |prediction|PredictedClassName        |
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
|Apple iPhone 6S 16GB      |Apple IPHONE 6S 16GB SGAY Telefon            |2.0  |[apple, iphone, 6s, 16gb, sgay, telefon]              |(262144,[27536,56559,169565,200223,210029,242621],[1.0,1.0,1.0,1.0,1.0,1.0])                     |[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S 16GB      |Apple iPhone 6S 16 GB Space Gray MKQJ2TU/A   |2.0  |[apple, iphone, 6s, 16, gb, space, gray, mkqj2tu/a]   |(262144,[10879,56559,95900,139131,175329,175778,200223,210029],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S 16GB      |iPhone 6s 16GB                               |2.0  |[iphone, 6s, 16gb]                                    |(262144,[27536,56559,210029],[1.0,1.0,1.0])                                                      |[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S Plus 128GB|Apple IPHONE 6S PLUS 128GB SG Telefon        |4.0  |[apple, iphone, 6s, plus, 128gb, sg, telefon]         |(262144,[56559,99916,137263,175839,200223,210029,242621],[1.0,1.0,1.0,1.0,1.0,1.0,1.0])          |[0.0,0.0,0.0,0.0,2.0]|[0.0,0.0,0.0,0.0,1.0]|4.0       |Apple iPhone 6S Plus 128GB|
|Apple iPhone 6S Plus 16GB |Iphone 6S Plus 16GB SpaceGray - Apple Türkiye|1.0  |[iphone, 6s, plus, 16gb, spacegray, -, apple, türkiye]|(262144,[27536,45531,46750,56559,59104,99916,200223,210029],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])   |[0.0,5.0,0.0,0.0,0.0]|[0.0,1.0,0.0,0.0,0.0]|1.0       |Apple iPhone 6S Plus 16GB |
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
Accuracy = 1.0

Loading Java file:

SparkConf sconf = new SparkConf().setMaster("local[*]").setAppName("Test")
        .set("spark.sql.warehouse.dir", "D:/Temp/wh");
SparkSession spark = SparkSession.builder().appName("Java Spark").config(sconf).getOrCreate();

JavaRDD<Cobj> cRDD = spark.read().textFile("file:///C:/Temp/classificationsTest.csv").javaRDD()
        .map(new Function<String, Cobj>() {
            @Override
            public Cobj call(String line) throws Exception {
                String[] parts = line.split(",");
                Cobj c = new Cobj();
                c.setClassName("?"); // class is unknown for new data
                c.setProductName(parts[0].trim());
                return c;
            }
        });

Dataset<Row> mainDataset = spark.createDataFrame(cRDD, Cobj.class);
mainDataset.show(100, false);

PipelineModel pipelineModel = PipelineModel.load(savePath + "DecisionTreeClassificationModel");

Dataset<Row> result = pipelineModel.transform(mainDataset);
result.show(100, false);

Output:

+---------+-----------+-----+-----+--------+-------------+-----------+----------+
|className|productName|label|words|features|rawPrediction|probability|prediction|
+---------+-----------+-----+-----+--------+-------------+-----------+----------+
+---------+-----------+-----+-----+--------+-------------+-----------+----------+

Best Answer

The output is empty because the StringIndexer is part of the saved pipeline and was fitted with setHandleInvalid("skip"): on the new data every className is "?", a label the indexer has never seen, so every row gets dropped during transform. I removed the StringIndexer from the pipeline and saved it separately as "StringIndexer". In the second file, after loading the pipeline, I load the StringIndexer to convert the predictions back to label names.
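A minimal sketch of that arrangement, assuming a Spark 2.x+ runtime on the classpath; the in-memory training rows, the `savePath` value, and the class name `SeparateIndexerSketch` are illustrative stand-ins for the question's CSV files and paths, and the prediction input deliberately has no `className` column at all:

```java
import java.util.Arrays;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SeparateIndexerSketch {
    public static void main(String[] args) throws Exception {
        SparkSession spark = SparkSession.builder()
                .master("local[*]").appName("Sketch").getOrCreate();
        String savePath = "/tmp/spark-models/"; // hypothetical path

        StructType schema = new StructType()
                .add("productName", DataTypes.StringType)
                .add("className", DataTypes.StringType);
        Dataset<Row> train = spark.createDataFrame(Arrays.asList(
                RowFactory.create("Apple iPhone 6S 16GB space gray", "Apple iPhone 6S 16GB"),
                RowFactory.create("iPhone 6s 16GB", "Apple iPhone 6S 16GB"),
                RowFactory.create("Apple iPhone 6S Plus 128GB telefon", "Apple iPhone 6S Plus 128GB")),
                schema);

        // 1. Fit and save the StringIndexer OUTSIDE the pipeline.
        StringIndexerModel classIndexerModel = new StringIndexer()
                .setInputCol("className").setOutputCol("label")
                .fit(train);
        classIndexerModel.write().overwrite().save(savePath + "StringIndexer");

        // 2. The pipeline no longer needs a className column:
        //    it is trained on the already-indexed data.
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{
                new Tokenizer().setInputCol("productName").setOutputCol("words"),
                new HashingTF().setInputCol("words").setOutputCol("features"),
                new DecisionTreeClassifier().setLabelCol("label").setFeaturesCol("features")});
        PipelineModel pipelineModel = pipeline.fit(classIndexerModel.transform(train));
        pipelineModel.write().overwrite().save(savePath + "DecisionTreeClassificationModel");

        // --- Prediction side: only productName is known for new rows ---
        Dataset<Row> newData = spark.createDataFrame(
                Arrays.asList(RowFactory.create("iphone 6s plus 128gb")),
                new StructType().add("productName", DataTypes.StringType));

        PipelineModel loadedPipeline = PipelineModel.load(savePath + "DecisionTreeClassificationModel");
        StringIndexerModel loadedIndexer = StringIndexerModel.load(savePath + "StringIndexer");

        // Use the separately loaded indexer's labels to decode predictions.
        Dataset<Row> result = new IndexToString()
                .setInputCol("prediction").setOutputCol("PredictedClassName")
                .setLabels(loadedIndexer.labels())
                .transform(loadedPipeline.transform(newData));
        result.select("productName", "PredictedClassName").show(false);
        spark.stop();
    }
}
```

Since no stage in the saved pipeline references `className` anymore, transform no longer drops unseen rows and the filler value "?" is no longer needed.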

Regarding "java - Reloaded Spark model does not seem to work", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/38937622/
