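java - How to get a confidence score from Spark MLlib Logistic Regression in Java

Tags: java apache-spark-mllib logistic-regression

Update: I tried the following approach to generate a confidence score, but it throws an exception. This is the snippet I am using: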

double point = BLAS.dot(logisticregressionmodel.weights(), datavector);
double confScore = 1.0 / (1.0 + Math.exp(-point)); 

The exception I get is:

Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18
    at scala.Predef$.require(Predef.scala:233)
    at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99)
    at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala)

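Can you help? It looks like the weights vector has more elements (198) than the data vector (I am generating 18 features). The two vectors passed to dot() must have the same length.

I am trying to implement a Java program that trains on an existing dataset and predicts on a new dataset, using the logistic regression algorithm provided in Spark MLlib (1.5.0). My training and prediction programs are below; I am using the multiclass implementation. The problem is that when I call model.predict(vector) (note lrmodel.predict() in the prediction program), I only get the predicted label. What if I need a confidence score? How can I get it? I have gone through the API but could not find any specific method that returns a confidence score. Can anyone help me?

Training program (generates the .model file)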

public static void main(final String[] args) throws Exception {
        JavaSparkContext jsc = null;
        int salesIndex = 1;

        try {
            ...
            SparkConf sparkConf =
                    new SparkConf().setAppName("Hackathon Train").setMaster(
                            sparkMaster);
            jsc = new JavaSparkContext(sparkConf);
            ...

            JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache();

            final String firstRdd = trainRDD.first().trim();
            JavaRDD<String> tempRddFilter =
                    trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                        private static final long serialVersionUID =
                                11111111111111111L;

                        public Boolean call(final String arg0) {
                            return !arg0.trim().equalsIgnoreCase(firstRdd);
                        }
                    });

           ...
            JavaRDD<String> featureRDD =
                    tempRddFilter
                            .map(new org.apache.spark.api.java.function.Function() {
                                private static final long serialVersionUID =
                                        6948900080648474074L;

                                public Object call(final Object arg0)
                                        throws Exception {
                                   ...
                                    StringBuilder featureSet =
                                            new StringBuilder();
                                   ...
                                        featureSet.append(i - 2);
                                        featureSet.append(COLON);
                                        featureSet.append(strVal);
                                        featureSet.append(SPACE);
                                    }

                                    return featureSet.toString().trim();
                                }
                            });

            List<String> featureList = featureRDD.collect();
            String featureOutput = StringUtils.join(featureList, NEW_LINE);
            String filePath = basePath + "lr.arff";
            FileUtils.writeStringToFile(new File(filePath), featureOutput,
                    "UTF-8");

            JavaRDD<LabeledPoint> trainingData =
                    MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache();

            final LogisticRegressionModel model =
                    new LogisticRegressionWithLBFGS().setNumClasses(18).run(
                            trainingData.rdd());
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            oos.writeObject(model);
            oos.flush();
            oos.close();
            FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"),
                    baos.toByteArray());
            baos.close();

        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (jsc != null) {
                jsc.close();
            }
        }
    }

Prediction program (uses the lr.model generated by the training program)

    public static void main(final String[] args) throws Exception {
        JavaSparkContext jsc = null;
        int salesIndex = 1;
        try {
            ...
            SparkConf sparkConf =
                    new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster);
            jsc = new JavaSparkContext(sparkConf);

            ObjectInputStream objectInputStream =
                    new ObjectInputStream(new FileInputStream(basePath
                            + "lr.model"));
            LogisticRegressionModel lrmodel =
                    (LogisticRegressionModel) objectInputStream.readObject();
            objectInputStream.close();

            ...

            JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache();

            final String firstRdd = trainRDD.first().trim();
            JavaRDD<String> tempRddFilter =
                    trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                        private static final long serialVersionUID =
                                11111111111111111L;

                        public Boolean call(final String arg0) {
                            return !arg0.trim().equalsIgnoreCase(firstRdd);
                        }
                    });

            ...
            final Broadcast<LogisticRegressionModel> broadcastModel =
                    jsc.broadcast(lrmodel);

            JavaRDD<String> featureRDD =
                    tempRddFilter
                            .map(new org.apache.spark.api.java.function.Function() {
                                private static final long serialVersionUID =
                                        6948900080648474074L;

                                public Object call(final Object arg0)
                                        throws Exception {
                                   ...
                                    LogisticRegressionModel lrModel =
                                            broadcastModel.value();
                                    String row = ((String) arg0);
                                    String[] featureSetArray =
                                            row.split(CSV_SPLITTER);
                                   ...
                                    final Vector vector =
                                            Vectors.dense(doubleArr);
                                    double score = lrModel.predict(vector);
                                   ...
                                    return csvString;
                                }
                            });

            String outputContent =
                    featureRDD
                            .reduce(new org.apache.spark.api.java.function.Function2() {

                                private static final long serialVersionUID =
                                        1212970144641935082L;

                                public Object call(Object arg0, Object arg1)
                                        throws Exception {
                                    ...
                                }

                            });
            ...
            FileUtils.writeStringToFile(new File(basePath
                    + "predicted-sales-data.csv"), sb.toString());
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (jsc != null) {
                jsc.close();
            }
        }
    }
}

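Best answer

After many attempts, I finally managed to write a custom function that generates a confidence score. It is far from perfect, but it works for me for now!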

private static double getConfidenceScore(
            final LogisticRegressionModel lrModel, final Vector vector) {
        // The multiclass weights are flattened into (numClasses - 1) rows,
        // one per non-pivot class. Each row holds the feature weights,
        // followed by an intercept when the model was trained with one.
        // Convert both vectors to arrays once instead of inside the loops.
        final double[] weights = lrModel.weights().toArray();
        final double[] features = vector.toArray();
        final int numClasses = lrModel.numClasses();
        final int dataWithBiasSize = weights.length / (numClasses - 1);
        final boolean withBias = (features.length + 1) == dataWithBiasSize;
        // The pivot class (label 0) has margin 0, so the maximum starts at 0.
        double maxMargin = 0.0;
        for (int j = 0; j < (numClasses - 1); j++) {
            double margin = 0.0;
            for (int k = 0; k < features.length; k++) {
                if (features[k] != 0.0) {
                    margin += features[k]
                            * weights[(j * dataWithBiasSize) + k];
                }
            }
            if (withBias) {
                margin += weights[(j * dataWithBiasSize) + features.length];
            }
            if (margin > maxMargin) {
                maxMargin = margin;
            }
        }
        // Sigmoid of the largest margin. This is a heuristic score, not a
        // probability normalized over all classes.
        final double conf = 1.0 / (1.0 + Math.exp(-maxMargin));
        // Return a percentage rounded to two decimals. Rounding numerically
        // avoids parsing a locale-formatted string back into a double.
        return Math.round(conf * 10000.0) / 100.0;
    }
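
The function above applies a sigmoid to the single largest margin, so the score never drops below 50 and does not depend on how close the other classes are. As a possible alternative, here is a minimal sketch that turns all the margins into per-class probabilities with a softmax. It assumes the same weight layout the answer relies on: (numClasses - 1) flattened rows, an optional trailing intercept per row, and class 0 as the pivot with margin 0. The class name SoftmaxConfidence and the method classProbabilities are hypothetical, and the sketch works on plain arrays so it runs without Spark.

```java
import java.util.Arrays;

// Hypothetical helper, not part of Spark MLlib.
final class SoftmaxConfidence {
    private SoftmaxConfidence() {}

    /**
     * Returns one probability per class (index 0 is the pivot class).
     * Assumes weights holds (numClasses - 1) rows of equal length, each row
     * being the feature weights plus an optional trailing intercept.
     */
    static double[] classProbabilities(double[] weights, int numClasses, double[] features) {
        if (numClasses < 2) {
            throw new IllegalArgumentException("numClasses must be at least 2");
        }
        int rows = numClasses - 1;
        if (weights.length % rows != 0) {
            throw new IllegalArgumentException("weights length is not a multiple of numClasses - 1");
        }
        int rowSize = weights.length / rows;
        boolean withIntercept = rowSize == features.length + 1;
        if (!withIntercept && rowSize != features.length) {
            throw new IllegalArgumentException("feature count does not match the weight layout");
        }

        // margins[0] stays 0.0 for the pivot class.
        double[] margins = new double[numClasses];
        for (int j = 0; j < rows; j++) {
            double margin = 0.0;
            for (int k = 0; k < features.length; k++) {
                margin += features[k] * weights[j * rowSize + k];
            }
            if (withIntercept) {
                margin += weights[j * rowSize + features.length];
            }
            margins[j + 1] = margin;
        }

        // Subtract the largest margin before exponentiating to avoid overflow.
        double max = margins[0];
        for (double m : margins) {
            max = Math.max(max, m);
        }
        double sum = 0.0;
        double[] probs = new double[numClasses];
        for (int i = 0; i < numClasses; i++) {
            probs[i] = Math.exp(margins[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < numClasses; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    public static void main(String[] args) {
        // Toy model: 3 classes, 2 features, no intercept.
        double[] weights = {1.0, 0.0, 0.0, 1.0};
        double[] features = {0.2, 1.5};
        System.out.println(Arrays.toString(classProbabilities(weights, 3, features)));
    }
}
```

In the prediction program this could be called with lrModel.weights().toArray(), lrModel.numClasses() and vector.toArray(). Under the layout assumption above, the entry at index (int) lrModel.predict(vector) would be the confidence for the predicted label.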

Source: this question, "java - How to get a confidence score from Spark MLlib Logistic Regression in Java", comes from a similar question on Stack Overflow: https://stackoverflow.com/questions/39328601/
