不久前遇到了这个问题,我认为应该有一个更好/更有效的方法来做到这一点:
我有一个大约有 70k
列和大约 10k
行的 DF。我基本上想根据行的值获取每列的计数。
df.columns.map( c => df.where(column(c)===1).count )
这适用于少量列,但在这种情况下,大量列会导致该过程需要几个小时,并且似乎要迭代每一列并查询数据。
我可以做哪些优化来更快地获得结果?
最佳答案
您可以将每一列的值替换为 1
或 0
,具体取决于列先前的值是否匹配条件,然后在一次聚合中对每一列求和。之后,您可以收集结果数据帧的唯一行并将其设为数组。
所以代码如下:
import org.apache.spark.sql.functions.{col, lit, sum, when}
val aggregation_columns = df.columns.map(c => sum(col(c)))
df
.columns
.foldLeft(df)((acc, elem) => acc.withColumn(elem, when(col(elem) === 1, lit(1)).otherwise(lit(0))))
.agg(aggregation_columns.head, aggregation_columns.tail: _*)
.collect()
.flatMap(row => df.columns.indices.map(i => row.getLong(i))
关于scala - Spark 计数大量列,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69304509/