scala - 如何从交叉验证器获得经过训练的最佳模型

标签 scala apache-spark machine-learning decision-tree cross-validation

我构建了一个包含像这样的 DecisionTreeClassifier(dt) 的管道

val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

然后我使用这个管道作为 CrossValidator 中的估计器,以获得具有最佳超参数集的模型,如下所示

val c_v = new CrossValidator().setEstimator(pipeline).setEvaluator(new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")).setEstimatorParamMaps(paramGrid).setNumFolds(5)

最后,我可以使用此交叉验证器在训练测试中训练模型

val model = c_v.fit(train)

但问题是,我想用DecisionTreeClassificationModel的参数.toDebugTree来查看训练最好的决策树模型。但模型是一个 CrossValidatorModel。是的,您可以使用 model.bestModel,但它仍然是 Model 类型,您不能将 .toDebugTree 应用于它。而且我还假设 bestModel 仍然是一个管道,包括 labelIndexerfeatureIndexerdtlabelConverter

那么有谁知道如何从交叉验证器拟合的模型中获取决策树模型,我可以通过toDebugString查看实际模型?或者有什么解决方法可以让我查看决策树模型吗?

最佳答案

嗯,在 cases like this one答案总是相同的 - 具体说明类型。

首先提取管道模型,因为您要训练的是管道:

import org.apache.spark.ml.PipelineModel

val bestModel: Option[PipelineModel] = model.bestModel match {
  case p: PipelineModel => Some(p)
  case _ => None
}

然后您需要从底层提取模型。在你的例子中,它是一个决策树分类模型:

import org.apache.spark.ml.classification.DecisionTreeClassificationModel

val treeModel: Option[DecisionTreeClassificationModel] = bestModel
  flatMap {
    _.stages.collect {
      case t: DecisionTreeClassificationModel => t
    }.headOption
  }

打印树,例如:

treeModel.foreach(_.toDebugString)

关于scala - 如何从交叉验证器获得经过训练的最佳模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36347875/

相关文章:

java - 使用 Spark 检测大型数据集中的重复连续值

machine-learning - Keras 在 1000 多个类上进行迁移学习

xml - 从 Scala 中的 StructType 中提取行标记架构以解析嵌套 XML

apache-spark - Spark升级到2.4.5时出现NoSuchMethodError

scala - HList/KList 是否适合作为方法参数?如何引用?类型列表?

scala - 将 DataFrame 写入 Parquet 或 Delta 似乎没有被并行化 - 耗时太长

python - 像素 RNN Pytorch 实现

audio - 我可以将扬声器与音调,音色和音量匹配吗?

Scalaz 无法解析符号 |+|,未使用的导入语句

Scala:带元组的 flatMap