scala - CrossValidator 不支持 VectorUDT 作为 spark-ml 中的标签

标签 scala apache-spark apache-spark-mllib apache-spark-ml

我在使用一个热编码器时遇到 scala spark 中的 ml.crossvalidator 问题。

这是我的代码

val tokenizer = new Tokenizer().
                    setInputCol("subjects").
                    setOutputCol("subject")

//CountVectorizer / TF
val countVectorizer = new CountVectorizer().
                        setInputCol("subject").
                        setOutputCol("features")

// convert string into numerical values
val labelIndexer = new StringIndexer().
                        setInputCol("labelss").
                        setOutputCol("labelsss")

// convert numerical to one hot encoder
val labelEncoder = new OneHotEncoder().
                   setInputCol("labelsss").
                   setOutputCol("label")

val logisticRegression = new LogisticRegression()

val pipeline = new Pipeline().setStages(Array(tokenizer,countVectorizer,labelIndexer,labelEncoder,logisticRegression))

然后给我这样的错误

cv: org.apache.spark.ml.tuning.CrossValidator = cv_8cc1ae985e39
java.lang.IllegalArgumentException: requirement failed: Column label must be of type NumericType but was actually of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7.

我不知道如何解决。

我需要一个热编码器,因为我的标签是绝对的。

谢谢你帮助我:)

最佳答案

实际上没有必要对标签 (目标变量) 使用 OneHotEncoder/OneHotEncoderEstimator,您实际上也不应该这样做。这将创建一个向量(type org.apache.spark.ml.linalg.VectorUDT)。

StringIndexer 足以定义您的标签是分类的。

让我们用一个小例子来验证一下:

val df = Seq((0, "a"),(1, "b"),(2, "c"),(3, "a"),(4, "a"),(5, "c")).toDF("category", "text")
// df: org.apache.spark.sql.DataFrame = [category: int, text: string]

val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex").fit(df)
// indexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_cf691c087e1d

val indexed = indexer.transform(df)
// indexed: org.apache.spark.sql.DataFrame = [category: int, text: string ... 1 more field]

indexed.schema.map(_.metadata).foreach(println)
// {}
// {}
// {"ml_attr":{"vals":["4","5","1","0","2","3"],"type":"nominal","name":"categoryIndex"}}

正如您所注意到的,StringIndexer 实际上将元数据附加到该列 (categoryIndex) 并将其标记为nominal a.k.a categorical.

您还可以注意到,在列的属性中,您有类别列表。

在我关于 How to handle categorical features with spark-ml? 的其他回答中有更多相关信息

关于使用 spark-ml 准备数据元数据,我强烈建议您阅读以下条目:

https://github.com/awesome-spark/spark-gotchas/blob/5ad4c399ffd2821875f608be8aff9f1338478444/06_data_preparation.md

免责声明:我是链接中条目的合著者。

注意:(文档摘录)

Because this existing OneHotEncoder is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new OneHotEncoderEstimator was created that produces an OneHotEncoderModel when fitting. For more detail, please see SPARK-13030.

OneHotEncoder has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use OneHotEncoderEstimator instead.

关于scala - CrossValidator 不支持 VectorUDT 作为 spark-ml 中的标签,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50598738/

相关文章:

parsing - 如何使用 Scala Parser Combinators 更改代码以考虑运算符优先级?

scala - 如何在我的 Spark 应用程序中使用 OpenHashSet?

java - Spark-maven 全新安装 : How to compile both Java and Scala classes

scala - 在Apache Spark中删除空的DataFrame分区

Scala 要么带单位

scala - 如何在 UDAF 的 MutableAggregationBuffer 中添加/改变 Map 对象?

scala - 如何从 UDF 创建自定义 Transformer?

python - SparkSession 与上下文混淆

java - 为什么在 IntelliJ IDEA 中运行 MLlib 项目失败并显示 "AssertionError: assertion failed: unsafe symbol CompatContext"?

scala - 将元数据附加到 Spark 中的向量列