scala - Parallel cross-validation with Spark's Scala API

Tags: scala apache-spark cross-validation apache-spark-ml

PySpark offers great possibilities for parallel cross-validation of models via https://github.com/databricks/spark-sklearn. Using it as a drop-in replacement for sklearn's GridSearchCV takes only:

from spark_sklearn import GridSearchCV

How can I achieve something similar for Spark's Scala CrossValidator, i.e. parallelize the evaluation of each fold?

Best Answer

Since Spark 2.3:

You can do this with CrossValidator's setParallelism(n) method, or by passing the parallelism parameter at construction time, i.e.:

cv.setParallelism(2) 

cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
                    parallelism=2)  # Evaluate up to 2 parameter settings in parallel
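The snippets above are the PySpark forms; the Scala API exposes the same setter. Below is a minimal sketch of a parallelized CrossValidator in Scala (the estimator, grid values, and evaluator are illustrative choices, and `trainingData` is assumed to be an existing DataFrame):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()

// Two candidate values for regParam -> two models per fold.
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setNumFolds(3)
  .setParallelism(2)  // evaluate up to 2 models in parallel (Spark >= 2.3)

// val cvModel = cv.fit(trainingData)  // trainingData: DataFrame, assumed defined
```

Note that parallelism bounds how many models are trained concurrently; each individual model fit is still distributed across the cluster as usual.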

Before Spark 2.3:

You cannot do this with Spark's Scala API; its cross-validation is not parallelized across parameter settings.

If you read the spark-sklearn documentation carefully, you will see that only the grid search (GridSearchCV) is parallelized; the training of each individual model is not, so it is useless at scale. Moreover, because of the well-known SPARK-5063, you cannot parallelize cross-validation over the Spark Scala API yourself:

RDD transformations and actions can only be invoked by the driver, not inside of other transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because the values transformation and count action cannot be performed inside of the rdd1.map transformation. For more information, see SPARK-5063.
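The restriction quoted above can be illustrated with a short sketch (it assumes a live SparkContext named `sc`); referencing one RDD inside a transformation of another fails, while doing the nested action on the driver first is fine:

```scala
// Assumes `sc: SparkContext` is available (e.g. spark.sparkContext).
val rdd1 = sc.parallelize(1 to 3)
val rdd2 = sc.parallelize(Seq(("a", 1), ("b", 2)))

// Invalid: an action on rdd2 inside a transformation of rdd1.
// Throws SparkException at runtime (SPARK-5063):
// val bad = rdd1.map(x => rdd2.values.count() * x).collect()

// Valid: run the count on the driver first, then close over the plain value.
val n = rdd2.values.count()                // 2, computed on the driver
val ok = rdd1.map(x => n * x).collect()    // Array(2, 4, 6)
```

This is exactly why a naive attempt to run one CrossValidator fold per task cannot work: each fold's training would need to launch Spark jobs from inside another Spark task.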

From the spark-sklearn README.md:

This package contains some tools to integrate the Spark computing framework with the popular scikit-learn machine library. Among other tools:

- train and evaluate multiple scikit-learn models in parallel. It is a distributed analog to the multicore implementation included by default in scikit-learn.
- convert Spark's Dataframes seamlessly into numpy ndarrays or sparse matrices.
- (experimental) distribute Scipy's sparse matrices as a dataset of sparse vectors.

It focuses on problems that have a small amount of data and that can be run in parallel.

- for small datasets, it distributes the search for estimator parameters (GridSearchCV in scikit-learn), using Spark,
- for datasets that do not fit in memory, we recommend using the distributed implementation in Spark MLlib.

NOTE: This package distributes simple tasks like grid-search cross-validation. It does not distribute individual learning algorithms (unlike Spark MLlib).

This answer is based on a similar question on Stack Overflow: https://stackoverflow.com/questions/41479359/
