scala - Scala/Spark数据帧: find the column name corresponding to the max

标签 scala apache-spark dataframe apache-spark-sql argmax

在Scala/Spark中,具有一个数据框:

val dfIn = sqlContext.createDataFrame(Seq(
  ("r0", 0, 2, 3),
  ("r1", 1, 0, 0),
  ("r2", 0, 2, 2))).toDF("id", "c0", "c1", "c2")

我想计算一个新列maxCol,其中包含与最大值相对应的列的名称(针对每一行)。在此示例中,输出应为:

+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c1|
+---+---+---+---+------+

实际上,数据框有60多个列。因此,需要通用的解决方案。

Python Pandas中的等效项(是的,我知道,我应该与pyspark ...进行比较)可能是:
dfOut = pd.concat([dfIn, dfIn.idxmax(axis=1).rename('maxCol')], axis=1) 

最佳答案

一个小技巧,您可以使用greatest函数。所需进口:

import org.apache.spark.sql.functions.{col, greatest, lit, struct}

首先,我们创建一个structs列表,其中第一个元素是value,第二个是列名:
val structs = dfIn.columns.tail.map(
  c => struct(col(c).as("v"), lit(c).as("k"))
)

可以将这样的结构传递给greatest,如下所示:
dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k"))

+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c2|
+---+---+---+---+------+

请注意,在平局的情况下,它将采用序列中稍后出现的元素(按字典顺序(x, "c2") > (x, "c1"))。如果由于某种原因这是 Not Acceptable ,则可以使用when显式地减少:
import org.apache.spark.sql.functions.when

val max_col = structs.reduce(
  (c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2)
).getItem("k")

dfIn.withColumn("maxCol", max_col)

+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c1|
+---+---+---+---+------+

如果是nullable列,则必须对此进行调整,例如,通过coalescing设置为-Inf的值。

关于scala - Scala/Spark数据帧: find the column name corresponding to the max,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42485108/

相关文章:

python - 在Python中将给定数据框的列的选定值分配给另一个数据框

scala - 如何将字符串写入 Scala Process?

scala - 部署静态文件的方法是什么,以便 Spray 可以为它们提供服务?

r - Spark Dataframe 中的重复列

python - 如何在数据帧上使用 UserDefinedFunction 解决错误 "Method __getnewargs__([]) does not exist"?

python - 如何通过另一个文件中的值更改 csv 列中的值

syntax - java代码中的Scala : $colon

scala - specs2 验收测试中的案例类上下文 : "must is not a member of Int"

apache-spark - spark Dataframe 执行更新语句

scala - 如何使用 Map[String,Long] 列作为 DataFrame 的头部并保留类型?