python - 使用collect_list(column)将spark数据帧转换回长格式

标签 python apache-spark pyspark apache-spark-sql

假设我们有数据框 iris:

import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv')
df = spark.createDataFrame(df)

我需要按物种对萼片宽度执行一些聚合函数,例如获取每组的 3 个最高值。

import pyspark.sql.functions as F
get_max_3 = F.udf(
    lambda x: sorted(x)[-3:]
)

agged = df.groupBy('species').agg(F.collect_list('sepal_width').alias('sepal_width'))
agged = agged.withColumn('sepal_width', get_max_3('sepal_width'))

+----------+---------------+
|   species|    sepal_width|
+----------+---------------+
| virginica|[3.6, 3.8, 3.8]|
|versicolor|[3.2, 3.3, 3.4]|
|    setosa|[4.1, 4.2, 4.4]|
+----------+---------------+

现在,我如何有效地将其转换回长格式的数据帧(意味着每个物种三行,每行对应一个值)?

有没有办法在不使用collect_list的情况下做到这一点?

最佳答案

要将数据帧转换回长格式,可以使用explode;但是,要使用此方法,您需要首先修复 udf,以便它返回正确的类型:

from pyspark.sql.types import *
import pyspark.sql.functions as F

get_max_3 = F.udf(lambda x: sorted(x)[-3:], ArrayType(DoubleType()))

agged = agged.withColumn('sepal_width', get_max_3('sepal_width'))
agged.withColumn('sepal_width', F.explode(F.col('sepal_width'))).show()

+----------+-----------+
|   species|sepal_width|
+----------+-----------+
| virginica|        3.6|
| virginica|        3.8|
| virginica|        3.8|
|versicolor|        3.2|
|versicolor|        3.3|
|versicolor|        3.4|
|    setosa|        4.1|
|    setosa|        4.2|
|    setosa|        4.4|
+----------+-----------+

或者不收集为列表并展开,您可以先对 sepal_width 列进行排名,然后根据排名进行过滤:

df.selectExpr(
    "species", "sepal_width", 
    "row_number() over (partition by species order by sepal_width desc) as rn"
).where(F.col("rn") <= 3).drop("rn").show()
+----------+-----------+
|   species|sepal_width|
+----------+-----------+
| virginica|        3.8|
| virginica|        3.8|
| virginica|        3.6|
|versicolor|        3.4|
|versicolor|        3.3|
|versicolor|        3.2|
|    setosa|        4.4|
|    setosa|        4.2|
|    setosa|        4.1|
+----------+-----------+

关于python - 使用collect_list(column)将spark数据帧转换回长格式,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47060294/

相关文章:

python - 如何在 Django 中编写自己的装饰器?

apache-spark - org.apache.spark.sql.SQLContext 无法加载文件

python - 将多列与另一列进行比较时,选择立即较小/较大的值

python - 如何根据 pandas 中的标签仅选择某些行?

python - 将 Pandas 中的某些列替换为 `filter(like = "")`

python - 如何不错过 itertools.takewhile() 之后的下一个元素

apache-spark - Spark SaveAsTextFile返回错误-Py4JJaveError

python - Databricks-如何将 accessToken 传递给spark._sc._gateway.jvm.java.sql.DriverManager?

apache-spark - 我应该使用哪个记录器在 Cloud Logging 中获取数据

python - 多索引分类并在 PySpark 中对其进行编码