apache-spark - Pyspark --- 添加新列,其中包含每组的值

标签 apache-spark dataframe group-by pyspark

假设我有以下数据集:

a | b   
1 | 0.4 
1 | 0.8 
1 | 0.5 
2 | 0.4
2 | 0.1

我想添加一个名为“label”的新列,其中的值是在本地为 a 中的每组值确定的。组 ab 的最高值标记为 1,所有其他值标记为 0。

输出如下:

a | b   | label
1 | 0.4 | 0
1 | 0.8 | 1
1 | 0.5 | 0
2 | 0.4 | 1
2 | 0.1 | 0

如何使用 PySpark 高效地完成此操作?

最佳答案

您可以使用窗口函数来做到这一点。首先,您需要一些导入:

from pyspark.sql.functions import desc, row_number, when
from pyspark.sql.window import Window

和窗口定义:

w = Window().partitionBy("a").orderBy(desc("b"))

最后你使用这些:

df.withColumn("label", when(row_number().over(w) == 1, 1).otherwise(0))

例如数据:

df = sc.parallelize([
    (1, 0.4), (1, 0.8), (1, 0.5), (2, 0.4), (2, 0.1)
]).toDF(["a", "b"])

结果是:

+---+---+-----+
|  a|  b|label|
+---+---+-----+
|  1|0.8|    1|
|  1|0.5|    0|
|  1|0.4|    0|
|  2|0.4|    1|
|  2|0.1|    0|
+---+---+-----+

关于apache-spark - Pyspark --- 添加新列,其中包含每组的值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41050991/

相关文章:

scala - Scala 中 TypedDataset 和类型界限的隐式编码器

python - 无法将大型 Spark 数据帧从 databricks 写入 csv

PHP 使用 LEFT JOIN 时如何按结果数组进行分组

python - Pandas 数据框计数矩阵

mysql - Rails + MySQL GROUP BY 并从 GROUP 中找到 MAX

apache-spark - Spark 流、Kafka 和多个主题的性能不佳

python - pyspark 中处理大数据的优化

python - 扁平化 python 数据框中的条目,例如 Apache PIG 包

python - 从分类数据类型列中提取平均值

r - 如何从数据框中的关联因子水平中减去值