scala - 是否可以为 Spark ML 中的随机森林创建通用训练管道?

标签 scala apache-spark machine-learning apache-spark-mllib

我刚刚开始使用 Spark 和 Spark ML,我发现它比 Python 和 sklearn 困难得多。

开发时间要长得多,所以我想知道是否可以制作一个适用于任何(足够小的)数据集并训练随机森林分类器的通用管道。理想情况下,我会创建一个类似的函数

def trainClassifier(df: DataFrame, labelColumn: String) {
  ...
}

Spark 中的大量开发时间都花在将列编码为数字列,然后根据特征形成向量,以便 Spark ML 的随机森林实际上可以使用它。所以最终会写出像

这样的行
val indexer = new StringIndexer()
                   .setInputCol("category")
                   .setOutputCol("categoryIndex")
                   .fit(df)

val indexed = indexer.transform(df)

val encoder = new OneHotEncoder()
                   .setInputCol("categoryIndex")
                   .setOutputCol("categoryVec")

val encoded = encoder.transform(indexed)

所以我的问题更多的是一个设计问题(如果合适,请引导我到不同的站点),关于如何编写适用于任何 DataFrame 的分类通用训练函数。 ,但这也是一个关于 Spark 的问题,因为我问这种事情在 Spark 中是否可行(所以这是一个 API 问题,所以它更适合 stackoverflow)?

编辑:我的意思是我不指定列并为每个新数据帧手动转换列。我想要一个函数trainClassifier它将接受具有不同列和不同列类型的各种数据框。迭代除 labelColumn 之外的所有列并将它们一起编译成分类器可以使用的特征向量的东西。

最佳答案

您可以创建自定义管道:

val start = "category"; // can be parameter of method or function
val indexer = new StringIndexer()
               .setInputCol(start )
               .setOutputCol(start + "Index")
               .fit(df)

val encoder = new OneHotEncoder()
               .setInputCol(encoder.outputCol)
               .setOutputCol(start  + "encoded") 

这些步骤可以在返回 Array[Stage] - Array(indexer, encoder) 的函数中。现在你可以写成here一些函数来连接所有数组并创建管道:

val randomForest = ... 

val pipeline = new Pipeline()
    .setStages(allStepsArray(indexer , encoder , randomForest))

然后你可以在 Pipeline 上调用 fit ,甚至可以像 link 一样构建 CrossValidator :

val model = pipeline.fit(testData)

关于scala - 是否可以为 Spark ML 中的随机森林创建通用训练管道?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45720960/

相关文章:

scala - 案例类中的产品继承

scala - 使用 scalaz' |> 切换函数和对象

multithreading - 如何使并发与写入 hive 表的数据帧一起工作?

machine-learning - 如何在 python 中加载 Node2Vec 嵌入生成的 .model 和 .emb 文件?

json - 特征类型参数的隐式编码器

具有自定义数据格式的 JavaFX DragAndDrop

java - Spark saveAsNewAPIHadoopFile java.io.IOException : Could not find a serializer for the Value class

apache-spark - Spark中LDA模型的在线学习

python - 我想用 z 矩阵所有值的平均值填充 z 矩阵中的缺失值

python - 如何在 Keras 中使用 fit_generator() 来平衡数据集?