scala - 在 Apache Spark 中的 groupBy 之后聚合 Map 中的所有列值

标签 scala apache-spark apache-spark-sql

我一整天都在用 Dataframe 尝试这个,但到目前为止还没有运气。已经用 RDD 做到了,但它并不是真正可读,所以这种方法在代码可读性方面会好得多。

采用这个初始 DF 和结果 DF,包括起始 DF 和我希​​望在执行 .groupBy() 后获得的结果。

case class SampleRow(name:String, surname:String, age:Int, city:String)
case class ResultRow(name: String, surnamesAndAges: Map[String, (Int, String)])

val df = List(
  SampleRow("Rick", "Fake", 17, "NY"),
  SampleRow("Rick", "Jordan", 18, "NY"),
  SampleRow("Sandy", "Sample", 19, "NY")
).toDF()

val resultDf = List(
  ResultRow("Rick", Map("Fake" -> (17, "NY"), "Jordan" -> (18, "NY"))),
  ResultRow("Sandy", Map("Sample" -> (19, "NY")))
).toDF()

到目前为止我尝试过执行以下.groupBy...

val resultDf = df
  .groupBy(
    Name
  )
  .agg(
    functions.map(
      selectColumn(Surname),
      functions.array(
        selectColumn(Age),
        selectColumn(City)
      )
    )
  )

但是,控制台会提示以下内容。

Exception in thread "main" org.apache.spark.sql.AnalysisException: expression '`surname`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;;

但是,这样做会导致每个姓氏只有一个条目,我希望将这些条目累积在单个 map 中,正如您在 resultDf 中看到的那样。有没有一种简单的方法可以使用 DF 来实现这一目标?

最佳答案

您可以使用单个 UDF 将数据转换为 map 来实现:

 val toMap = udf((keys: Seq[String], values1: Seq[String], values2: Seq[String]) => {
    keys.zip(values1.zip(values2)).toMap
  })



   val myResultDF = df.groupBy("name").agg(collect_list("surname") as "surname", collect_list("age") as "age", collect_list("city") as "city").withColumn("surnamesAndAges", toMap($"surname", $"age", $"city")).drop("age", "city", "surname").show(false)
+-----+--------------------------------------+
|name |surnamesAndAges                       |
+-----+--------------------------------------+
|Sandy|[Sample -> [19, NY]]                  |
|Rick |[Fake -> [17, NY], Jordan -> [18, NY]]|
+-----+--------------------------------------+

关于scala - 在 Apache Spark 中的 groupBy 之后聚合 Map 中的所有列值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57787239/

相关文章:

python - 当我远未达到最大连接数时,为什么 PostgreSQL 会说 FATAL : sorry, 客户端数量过多?

amazon-web-services - 将数据从 Amazon Redshift 导出为 JSON

python - 使用 Pyspark-sql 将 unix 时间转换为日期时间的结果不正确

python - 如何计算pyspark中每行某些列的最大值

apache-spark-sql - 来自 sbt scala 的 google dataproc 上的 Spark-SQL

scala - 将 Scala 数组转换为 Java 数组的最快方法

scala - 我如何解释 fold 和 foldK 之间的区别?

Scala 通过表达式向数据框添加新列

Scala:什么是 CompactBuffer?

scala - 如何在Scala(Spark)中比较两个数据框中的列