apache-spark - Pyspark UDF 在两列之间返回类似于 groupby().sum() 的结果

标签 apache-spark pyspark apache-spark-sql

我有以下示例数据框

fruit_list = ['apple', 'apple', 'orange', 'apple']
qty_list = [16, 2, 3, 1]
spark_df = spark.createDataFrame([(101, 'Mark', fruit_list, qty_list)], ['ID', 'name', 'fruit', 'qty'])

我想创建另一个列,其中包含类似于我使用 pandas groupby('fruit').sum()

实现的结果
        qty
fruits     
apple    19
orange    3

上述结果可以任何形式(字符串、字典、元组列表...)存储在新列中。

我尝试了一种类似于以下方法但不起作用的方法

sum_cols = udf(lambda x: pd.DataFrame({'fruits': x[0], 'qty': x[1]}).groupby('fruits').sum())
spark_df.withColumn('Result', sum_cols(F.struct('fruit', 'qty'))).show()

结果数据框的一个示例可能是

+---+----+--------------------+-------------+-------------------------+
| ID|name|               fruit|          qty|                   Result|
+---+----+--------------------+-------------+-------------------------+
|101|Mark|[apple, apple, or...|[16, 2, 3, 1]|[(apple,19), (orange,3)] |
+---+----+--------------------+-------------+-------------------------+

您对我如何实现这一点有什么建议吗?

谢谢

编辑:在 Spark 2.4.3 上运行

最佳答案

正如@pault 提到的,从 Spark 2.4+ 开始,您可以使用 Spark SQL 内置函数来处理您的任务,这是 array_distinct + < 的一种方法强>变换 + 聚合:

from pyspark.sql.functions import expr

# set up data
spark_df = spark.createDataFrame([
        (101, 'Mark', ['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1])
      , (102, 'Twin', ['apple', 'banana', 'avocado', 'banana', 'avocado'], [5, 2, 11, 3, 1])
      , (103, 'Smith', ['avocado'], [10])
    ], ['ID', 'name', 'fruit', 'qty']
)

>>> spark_df.show(5,0)
+---+-----+-----------------------------------------+----------------+
|ID |name |fruit                                    |qty             |
+---+-----+-----------------------------------------+----------------+
|101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |
|102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|
|103|Smith|[avocado]                                |[10]            |
+---+-----+-----------------------------------------+----------------+

>>> spark_df.printSchema()
root
 |-- ID: long (nullable = true)
 |-- name: string (nullable = true)
 |-- fruit: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- qty: array (nullable = true)
 |    |-- element: long (containsNull = true)

设置 SQL 语句:

stmt = '''
    transform(array_distinct(fruit), x -> (x, aggregate(
          transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
        , 0
        , (y,z) -> int(y + z)
    ))) AS sum_fruit
'''

>>> spark_df.withColumn('sum_fruit', expr(stmt)).show(10,0)
+---+-----+-----------------------------------------+----------------+----------------------------------------+
|ID |name |fruit                                    |qty             |sum_fruit                               |
+---+-----+-----------------------------------------+----------------+----------------------------------------+
|101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |[[apple, 19], [orange, 3]]              |
|102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|[[apple, 5], [banana, 5], [avocado, 12]]|
|103|Smith|[avocado]                                |[10]            |[[avocado, 10]]                         |
+---+-----+-----------------------------------------+----------------+----------------------------------------+

解释:

  1. 使用 array_distinct(fruit) 查找数组 fruit
  2. 中的所有不同条目
  3. 将这个新数组(带有元素 x)从 x 转换为 (x, aggregate(..x..))<
  4. 上述函数aggregate(..x..)采用简单的形式,将array_T中的所有元素相加

    aggregate(array_T, 0, (y,z) -> y + z)
    

    array_T 来自以下转换:

    transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
    

    遍历数组fruit,如果fruit[i] = x,则返回对应的qty[i],否则返回0。例如对于ID=101,当x = 'orange',返回一个数组[0, 0, 3, 0]

关于apache-spark - Pyspark UDF 在两列之间返回类似于 groupby().sum() 的结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57291746/

相关文章:

scala - 从spark中的json模式动态生成df.select语句

apache-spark - 来自 Spark 安装的 Pyspark VS Pyspark python 包

apache-spark - 在 map 列的 Spark 数据框中如何使用所有键的常量更新值

apache-spark - 执行pyspark.sql.DataFrame.take(4)需要1个多小时

python - 如何在 Hadoop 环境中重新训练 Inception 图像分类器

apache-spark - 在 pyspark 中分解 Maptype 列

python - pyspark 3.0 KMeansModel 误差平方和

python - 如何在 PySpark 中将列从字符串转换为数组

java - 将函数传递给 JavaPairRDD<K,V> 中的 KEY

python - PySpark:字典类型 RDD 的迭代