我有以下示例数据框
fruit_list = ['apple', 'apple', 'orange', 'apple']
qty_list = [16, 2, 3, 1]
spark_df = spark.createDataFrame([(101, 'Mark', fruit_list, qty_list)], ['ID', 'name', 'fruit', 'qty'])
我想创建另一个列,其中包含类似于我使用 pandas groupby('fruit').sum()
qty
fruits
apple 19
orange 3
上述结果可以任何形式(字符串、字典、元组列表...)存储在新列中。
我尝试了一种类似于以下方法但不起作用的方法
sum_cols = udf(lambda x: pd.DataFrame({'fruits': x[0], 'qty': x[1]}).groupby('fruits').sum())
spark_df.withColumn('Result', sum_cols(F.struct('fruit', 'qty'))).show()
结果数据框的一个示例可能是
+---+----+--------------------+-------------+-------------------------+
| ID|name| fruit| qty| Result|
+---+----+--------------------+-------------+-------------------------+
|101|Mark|[apple, apple, or...|[16, 2, 3, 1]|[(apple,19), (orange,3)] |
+---+----+--------------------+-------------+-------------------------+
您对我如何实现这一点有什么建议吗?
谢谢
编辑:在 Spark 2.4.3 上运行
最佳答案
正如@pault 提到的,从 Spark 2.4+ 开始,您可以使用 Spark SQL 内置函数来处理您的任务,这是 array_distinct + < 的一种方法强>变换 + 聚合:
from pyspark.sql.functions import expr
# set up data
spark_df = spark.createDataFrame([
(101, 'Mark', ['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1])
, (102, 'Twin', ['apple', 'banana', 'avocado', 'banana', 'avocado'], [5, 2, 11, 3, 1])
, (103, 'Smith', ['avocado'], [10])
], ['ID', 'name', 'fruit', 'qty']
)
>>> spark_df.show(5,0)
+---+-----+-----------------------------------------+----------------+
|ID |name |fruit |qty |
+---+-----+-----------------------------------------+----------------+
|101|Mark |[apple, apple, orange, apple] |[16, 2, 3, 1] |
|102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|
|103|Smith|[avocado] |[10] |
+---+-----+-----------------------------------------+----------------+
>>> spark_df.printSchema()
root
|-- ID: long (nullable = true)
|-- name: string (nullable = true)
|-- fruit: array (nullable = true)
| |-- element: string (containsNull = true)
|-- qty: array (nullable = true)
| |-- element: long (containsNull = true)
设置 SQL 语句:
stmt = '''
transform(array_distinct(fruit), x -> (x, aggregate(
transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
, 0
, (y,z) -> int(y + z)
))) AS sum_fruit
'''
>>> spark_df.withColumn('sum_fruit', expr(stmt)).show(10,0)
+---+-----+-----------------------------------------+----------------+----------------------------------------+
|ID |name |fruit |qty |sum_fruit |
+---+-----+-----------------------------------------+----------------+----------------------------------------+
|101|Mark |[apple, apple, orange, apple] |[16, 2, 3, 1] |[[apple, 19], [orange, 3]] |
|102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|[[apple, 5], [banana, 5], [avocado, 12]]|
|103|Smith|[avocado] |[10] |[[avocado, 10]] |
+---+-----+-----------------------------------------+----------------+----------------------------------------+
解释:
- 使用
array_distinct(fruit)
查找数组fruit
中的所有不同条目
- 将这个新数组(带有元素
x
)从x
转换为(x, aggregate(..x..))
< 上述函数aggregate(..x..)采用简单的形式,将array_T中的所有元素相加
aggregate(array_T, 0, (y,z) -> y + z)
array_T 来自以下转换:
transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
遍历数组
fruit
,如果fruit[i] = x,则返回对应的qty[i],否则返回0。例如对于ID=101,当x = 'orange',返回一个数组[0, 0, 3, 0]
关于apache-spark - Pyspark UDF 在两列之间返回类似于 groupby().sum() 的结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57291746/