java - 如何从 Java 中的 TrainValidationSplitModel 中提取最佳参数集?

标签 java apache-spark random-forest apache-spark-mllib

我正在使用 ParamGridBuilder 构建参数网格以进行搜索,并使用 TrainValidationSplit 来确定 Java 中的最佳模型 (RandomForestClassifier)。现在,我想知道 ParamGridBuilder 中生成最佳模型的参数(maxDepth、numTrees)是什么。

      Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{
          new VectorAssembler()
         .setInputCols(new String[]{"a", "b"}).setOutputCol("features"), 
          new RandomForestClassifier()
         .setLabelCol("label")
         .setFeaturesCol("features")});

      ParamMap[] paramGrid = new ParamGridBuilder()
            .addGrid(rf.maxDepth(), new int[]{10, 15})
            .addGrid(rf.numTrees(), new int[]{5, 10})
            .build();

      BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator().setLabelCol("label");

      TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
            .setEstimator(pipeline)
            .setEstimatorParamMaps(paramGrid)
            .setEvaluator(evaluator)
            .setTrainRatio(0.85);
      TrainValidationSplitModel model = trainValidationSplit.fit(dataLog);

      System.out.println("paramMap size: " + model.bestModel().paramMap().size());
      System.out.println("defaultParamMap size: " + model.bestModel().defaultParamMap().size());
      System.out.println("extractParamMap: " + model.bestModel().extractParamMap());
      System.out.println("explainParams: " + model.bestModel().explainParams());
      System.out.println("numTrees: " + model.bestModel().getParam("numTrees"))//NoSuchElementException: Param numTrees does not exist.

这些尝试没有帮助...

paramMap size: 0
defaultParamMap size: 0
extractParamMap: {

}
explainParams: 

最佳答案

我找到了一种方法:

Pipeline bestModelPipeline = (Pipeline) model.bestModel().parent();
RandomForestClassifier bestRf = (RandomForestClassifier) bestModelPipeline.getStages()[1];

System.out.println("maxDepth : " + bestRf.getMaxDepth());
System.out.println("numTrees : " + bestRf.getNumTrees());
System.out.println("maxBins : " + bestRf.getMaxBins());

关于java - 如何从 Java 中的 TrainValidationSplitModel 中提取最佳参数集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60322875/

相关文章:

java - Java 的错误答案

java - Ubuntu 内核杀死 java 进程,即使它没有内存不足

python - 使用 PySpark 展平嵌套 json 响应结构的最有效方法是什么?

apache-spark - Apache Spark : Get the first and last row of each partition

r - 为什么重要性参数会影响 R 中随机森林的性能?

r - 是否可以更改 randomForest 中使用的引导和/或子采样方案?

javascript - 调试 Java Stream() NullPointerException 原因的最佳方法

java - 通过hssf/xssf向现有excel单元格写入数据。

amazon-web-services - 如何在 EMR 中设置自定义环境变量以供 spark 应用程序使用

tensorflow - 带有 CART 树的 TensorFlow 随机森林使用什么杂质指数(基尼系数、熵?)?