java - Spark MLlib classification input format with Java

Tags: java apache-spark apache-spark-mllib apache-spark-ml

How do I convert a list of DTOs into Spark ML's input dataset format?

I have this DTO:

public class MachineLearningDTO implements Serializable {
    private double label;
    private double[] features;

    public MachineLearningDTO() {
    }

    public MachineLearningDTO(double label, double[] features) {
        this.label = label;
        this.features = features;
    }

    public double getLabel() {
        return label;
    }

    public void setLabel(double label) {
        this.label = label;
    }

    public double[] getFeatures() {
        return features;
    }

    public void setFeatures(double[] features) {
        this.features = features;
    }
}

and this code:

Dataset<MachineLearningDTO> mlInputDataSet = spark.createDataset(mlInputData, Encoders.bean(MachineLearningDTO.class));
LogisticRegression logisticRegression = new LogisticRegression();
LogisticRegressionModel model = logisticRegression.fit(MLUtils.convertMatrixColumnsToML(mlInputDataSet));

After executing the code I get:

java.lang.IllegalArgumentException: requirement failed: Column features must be of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 but was actually ArrayType(DoubleType,false).
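One way to resolve exactly this mismatch, if you are on Spark 3.1 or later, is `array_to_vector` from `org.apache.spark.ml.functions`, which converts an `ArrayType(DoubleType)` column into the `VectorUDT` column the estimator expects. This is a sketch, not part of the original answer; the stand-in DataFrame below plays the role of the bean-encoded dataset from the question:

```java
import static org.apache.spark.ml.functions.array_to_vector;
import static org.apache.spark.sql.functions.col;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class ArrayToVectorExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .master("local[1]").appName("array-to-vector").getOrCreate();

        // Stand-in for the bean-encoded dataset from the question:
        // a double label column plus an ArrayType(DoubleType) features column.
        StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features",
                        DataTypes.createArrayType(DataTypes.DoubleType, false),
                        false, Metadata.empty()),
        });
        List<Row> rows = Arrays.asList(
                RowFactory.create(1.0, new double[]{0.0, 1.1, 0.1}),
                RowFactory.create(0.0, new double[]{2.0, 1.0, -1.0}));
        Dataset<Row> arrayDf = spark.createDataFrame(rows, schema);

        // array_to_vector (Spark 3.1+) rewrites the array column as the
        // VectorUDT column that LogisticRegression.fit requires.
        Dataset<Row> mlReady = arrayDf
                .withColumn("features", array_to_vector(col("features")));

        LogisticRegressionModel model = new LogisticRegression().fit(mlReady);
        System.out.println("coefficients: " + model.coefficients());
        spark.stop();
    }
}
```

On older Spark versions this function is not available, and the converter from the accepted answer below is the way to go.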

If I change it to org.apache.spark.ml.linalg.VectorUDT, with code like:

VectorUDT vectorUDT = new VectorUDT();
vectorUDT.serialize(Vectors.dense(......));

then I get:

java.lang.UnsupportedOperationException: Cannot infer type for class org.apache.spark.ml.linalg.VectorUDT because it is not bean-compliant

at org.apache.spark.sql.catalyst.JavaTypeInference$.org$apache$spark$sql$catalyst$JavaTypeInference$$serializerFor(JavaTypeInference.scala:437)
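A different route worth trying (not from the original answer, so verify against your Spark version): `Encoders.bean` cannot infer a type for `VectorUDT` itself, but Spark registers a UDT for `org.apache.spark.ml.linalg.Vector`, so a bean whose features field is typed as `Vector` rather than `double[]` may encode directly. A sketch of that variant of the DTO:

```java
import java.io.Serializable;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

// Variant of MachineLearningDTO whose features field is a Spark ml Vector
// instead of double[]; the bean encoder can map Vector through its
// registered UDT, so no manual VectorUDT serialization is needed.
public class VectorMachineLearningDTO implements Serializable {
    private double label;
    private Vector features;

    public VectorMachineLearningDTO() {
    }

    public VectorMachineLearningDTO(double label, double[] features) {
        this.label = label;
        this.features = Vectors.dense(features);
    }

    public double getLabel() {
        return label;
    }

    public void setLabel(double label) {
        this.label = label;
    }

    public Vector getFeatures() {
        return features;
    }

    public void setFeatures(Vector features) {
        this.features = features;
    }
}
```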

Best Answer

I have figured it out. In case anyone else gets stuck on this, I wrote a simple converter and it works:

private Dataset<Row> convertToMlInputFormat(List<MachineLearningDTO> data) {
    // Wrap each DTO in a Row whose features column holds an ml Vector,
    // not a double[], so the schema can declare it as VectorUDT.
    List<Row> rowData = data.stream()
            .map(dto ->
                    RowFactory.create(dto.getLabel(), Vectors.dense(dto.getFeatures())))
            .collect(Collectors.toList());
    StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty()),
    });

    return spark.createDataFrame(rowData, schema);
}
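Wired together, the converter's output feeds `fit` directly. A minimal usage sketch, assuming it runs in the same class as convertToMlInputFormat, where `spark` is the SparkSession field the method uses:

```java
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Build a small input list, convert it, and train on the result.
List<MachineLearningDTO> mlInputData = Arrays.asList(
        new MachineLearningDTO(1.0, new double[]{0.0, 1.1, 0.1}),
        new MachineLearningDTO(0.0, new double[]{2.0, 1.0, -1.0}));

Dataset<Row> mlInput = convertToMlInputFormat(mlInputData);
LogisticRegressionModel model = new LogisticRegression().fit(mlInput);
```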

For java - Spark MLlib classification input format with Java, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/44524501/
