在Scala/Spark中,具有一个数据框:
val dfIn = sqlContext.createDataFrame(Seq(
("r0", 0, 2, 3),
("r1", 1, 0, 0),
("r2", 0, 2, 2))).toDF("id", "c0", "c1", "c2")
我想计算一个新列
maxCol
,其中包含与最大值相对应的列的名称(针对每一行)。在此示例中,输出应为:+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0| 0| 2| 3| c2|
| r1| 1| 0| 0| c0|
| r2| 0| 2| 2| c1|
+---+---+---+---+------+
实际上,数据框有60多个列。因此,需要通用的解决方案。
Python Pandas中的等效项(是的,我知道,我应该与pyspark ...进行比较)可能是:
dfOut = pd.concat([dfIn, dfIn.idxmax(axis=1).rename('maxCol')], axis=1)
最佳答案
一个小技巧,您可以使用greatest
函数。所需进口:
import org.apache.spark.sql.functions.{col, greatest, lit, struct}
首先,我们创建一个
structs
列表,其中第一个元素是value,第二个是列名:val structs = dfIn.columns.tail.map(
c => struct(col(c).as("v"), lit(c).as("k"))
)
可以将这样的结构传递给
greatest
,如下所示:dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k"))
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0| 0| 2| 3| c2|
| r1| 1| 0| 0| c0|
| r2| 0| 2| 2| c2|
+---+---+---+---+------+
请注意,在平局的情况下,它将采用序列中稍后出现的元素(按字典顺序
(x, "c2") > (x, "c1")
)。如果由于某种原因这是 Not Acceptable ,则可以使用when
显式地减少:import org.apache.spark.sql.functions.when
val max_col = structs.reduce(
(c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2)
).getItem("k")
dfIn.withColumn("maxCol", max_col)
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0| 0| 2| 3| c2|
| r1| 1| 0| 0| c0|
| r2| 0| 2| 2| c1|
+---+---+---+---+------+
如果是
nullable
列,则必须对此进行调整,例如,通过coalescing
设置为-Inf
的值。
关于scala - Scala/Spark数据帧: find the column name corresponding to the max,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42485108/