scala - How to compute the sum of orders over 12 months, sliding by 1 month, per customer in Spark

Tags: scala apache-spark apache-spark-sql aggregation

I'm relatively new to Spark with Scala. I'm currently trying to aggregate order data in Spark over a 12-month period that slides forward by one month.


import spark.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._

var sample = Seq(("C1","01/01/2016", 20), ("C1","02/01/2016", 5), 
 ("C1","03/01/2016", 2),  ("C1","04/01/2016", 3), ("C1","05/01/2017", 5),
 ("C1","08/01/2017", 5), ("C1","01/02/2017", 10), ("C1","01/02/2017", 10),  
 ("C1","01/03/2017", 10)).toDF("id","order_date", "orders")

sample = sample.withColumn("order_date",
  to_date(unix_timestamp($"order_date", "dd/MM/yyyy").cast("timestamp")))
sample.show()
+---+----------+------+
| id|order_date|orders|
+---+----------+------+
| C1|2016-01-01|    20|
| C1|2016-01-02|     5|
| C1|2016-01-03|     2|
| C1|2016-01-04|     3|
| C1|2017-01-05|     5|
| C1|2017-01-08|     5|
| C1|2017-02-01|    10|
| C1|2017-02-01|    10|
| C1|2017-03-01|    10|
+---+----------+------+

The result I would like to get:

id  period_start  period_end  rolling
C1  2015-01-01    2016-01-01  30
C1  2016-01-01    2017-01-01  40
C1  2016-02-01    2017-02-01  30
C1  2016-03-01    2017-03-01  40
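To make the requirement concrete, here is a plain-Scala brute-force reference (no Spark involved) that reproduces the table above from the sample rows. Judging by the numbers, each period covers every month from period_start to period_end with both ends included, i.e. 13 calendar months; that reading is inferred from the expected values, and the helper name `rollingReference` is hypothetical.

```scala
import java.time.LocalDate

// The question's sample rows as (id, order_date, orders), dates already parsed.
val orders: Seq[(String, LocalDate, Long)] = Seq(
  ("C1", LocalDate.of(2016, 1, 1), 20L), ("C1", LocalDate.of(2016, 1, 2), 5L),
  ("C1", LocalDate.of(2016, 1, 3), 2L), ("C1", LocalDate.of(2016, 1, 4), 3L),
  ("C1", LocalDate.of(2017, 1, 5), 5L), ("C1", LocalDate.of(2017, 1, 8), 5L),
  ("C1", LocalDate.of(2017, 2, 1), 10L), ("C1", LocalDate.of(2017, 2, 1), 10L),
  ("C1", LocalDate.of(2017, 3, 1), 10L))

// Hypothetical brute-force reference: sum every order of `id` whose month lies
// in [periodEnd - 12 months, periodEnd], both ends inclusive.
def rollingReference(id: String, periodEnd: LocalDate): Long = {
  val periodStart = periodEnd.minusMonths(12)
  orders
    .filter { case (i, date, _) =>
      val month = date.withDayOfMonth(1)
      i == id && !month.isBefore(periodStart) && !month.isAfter(periodEnd)
    }
    .map(_._3)
    .sum
}

// One period per month that actually has orders, as in the expected table.
val periodEnds = orders.map(_._2.withDayOfMonth(1)).distinct.sortBy(_.toEpochDay)
val expected = periodEnds.map(end => (end.minusMonths(12), end, rollingReference("C1", end)))

expected.foreach { case (start, end, total) => println(s"C1  $start  $end  $total") }
```

Being quadratic, this is only useful for checking a Spark job against small test data.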



As a first step, I collapse every date to the first day of its month (i.e. 2016-01-[1..31] to 2016-01-01) and then aggregate per month:

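import org.joda.time._

// Map (month, year) to the first day of that month, formatted as yyyy-MM-dd
val collapse_month = (month: Integer, year: Integer) => {
  val dt = new DateTime()
    .withYear(year)
    .withMonthOfYear(month)
    .withDayOfMonth(1)
  dt.toString("yyyy-MM-dd")
}

val collapse_month_udf = udf(collapse_month)

sample = sample.withColumn("period_end",
  collapse_month_udf(month($"order_date"), year($"order_date")))

sample.groupBy($"id", $"period_end")
  .agg(sum($"orders").as("orders"))
  .orderBy($"period_end")
  .show()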
+---+----------+------+
| id|period_end|orders|
+---+----------+------+
| C1|2016-01-01|    30|
| C1|2017-01-01|    10|
| C1|2017-02-01|    20|
| C1|2017-03-01|    10|
+---+----------+------+

I tried the provided window function, but I was not able to get a 12-month window that slides by one month.

I'm really not sure what the best approach is from here that wouldn't take 5 hours to run, given how much data I have to process.



"I tried the provided window function but I was not able to use 12 months sliding by one option."

You can still use window with longer intervals, but all parameters have to be expressed in days or weeks:
window($"order_date", "365 days", "28 days")
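A plain-Scala sketch of why this is only approximate. Spark documents `startTime` as an offset from 1970-01-01 00:00:00 UTC, so with the default of zero the window starts fall on multiples of the slide duration counted from the epoch. The helper `windowStartsContaining` is hypothetical and assumes a UTC session time zone and whole-day dates.

```scala
import java.time.LocalDate

// Hypothetical model of sliding time windows of `windowDays` that move by
// `slideDays`, assuming starts aligned to 1970-01-01 (startTime = 0) and that
// every date is midnight UTC, so whole-day arithmetic is exact.
def windowStartsContaining(date: LocalDate, windowDays: Long, slideDays: Long): List[LocalDate] = {
  val day = date.toEpochDay
  // Latest window start on or before `day`: a multiple of slideDays since the epoch.
  val lastStart = Math.floorDiv(day, slideDays) * slideDays
  // Walk backwards while the half-open window [start, start + windowDays) still covers `day`.
  Iterator.iterate(lastStart)(_ - slideDays)
    .takeWhile(start => start + windowDays > day)
    .map(start => LocalDate.ofEpochDay(start))
    .toList
    .reverse
}

val starts = windowStartsContaining(LocalDate.of(2016, 1, 1), windowDays = 365, slideDays = 28)
starts.foreach(s => println(s"$s .. ${s.plusDays(364)}"))
```

Under these assumptions, 2016-01-01 falls into 13 overlapping 365-day windows, the latest starting on 2015-12-31, so none of them lines up with "the 12 calendar months up to this month".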


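Keep in mind, though, that 365 days is not exactly 12 months and 28 days is not exactly one month, so these windows drift away from calendar-month boundaries. To work with exact months, first aggregate the data by month:

val byMonth = sample
  .groupBy($"id", trunc($"order_date", "month").alias("order_month"))
  .agg(sum($"orders").alias("orders"))

+---+-----------+------+
| id|order_month|orders|
+---+-----------+------+
| C1| 2017-01-01|    10|
| C1| 2016-01-01|    30|
| C1| 2017-02-01|    20|
| C1| 2017-03-01|    10|
+---+-----------+------+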
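For comparison, the same monthly aggregation in plain Scala with java.time, as a sketch that assumes the sample rows are already parsed into `LocalDate`. Here `LocalDate.withDayOfMonth(1)` plays the role of both the question's Joda UDF and Spark's `trunc(..., "month")`.

```scala
import java.time.LocalDate

// The question's sample rows as (id, order_date, orders).
val orders: Seq[(String, LocalDate, Long)] = Seq(
  ("C1", LocalDate.of(2016, 1, 1), 20L), ("C1", LocalDate.of(2016, 1, 2), 5L),
  ("C1", LocalDate.of(2016, 1, 3), 2L), ("C1", LocalDate.of(2016, 1, 4), 3L),
  ("C1", LocalDate.of(2017, 1, 5), 5L), ("C1", LocalDate.of(2017, 1, 8), 5L),
  ("C1", LocalDate.of(2017, 2, 1), 10L), ("C1", LocalDate.of(2017, 2, 1), 10L),
  ("C1", LocalDate.of(2017, 3, 1), 10L))

// Collapse each date to the first day of its month and sum orders per (id, month).
val byMonth: Map[(String, LocalDate), Long] =
  orders
    .groupBy { case (id, date, _) => (id, date.withDayOfMonth(1)) }
    .map { case (key, rows) => key -> rows.map(_._3).sum }

byMonth.toSeq
  .sortBy { case ((id, month), _) => (id, month.toEpochDay) }
  .foreach(println)
```

In Spark itself the built-in `trunc` is preferable to a UDF, since Catalyst can optimize it and no Scala function has to be called per row.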

Then generate a reference range containing every month between the earliest and the latest month:

import java.time.temporal.ChronoUnit

val Row(start: java.sql.Date, end: java.sql.Date) = byMonth
  .select(min($"order_month"), max($"order_month"))
  .first

val months = (0L to ChronoUnit.MONTHS.between(
    start.toLocalDate, end.toLocalDate))
  .map(i => java.sql.Date.valueOf(start.toLocalDate.plusMonths(i)))
  .toDF("order_month")

and combine it with the unique ids:

val ref = byMonth.select($"id").distinct.crossJoin(months)

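Then left outer join this reference grid with the monthly totals, so that months without orders get null:

val expanded = ref.join(byMonth, Seq("id", "order_month"), "leftouter")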

+---+-----------+------+
| id|order_month|orders|
+---+-----------+------+
| C1| 2016-01-01|    30|
| C1| 2016-02-01|  null|
| C1| 2016-03-01|  null|
| C1| 2016-04-01|  null|
| C1| 2016-05-01|  null|
| C1| 2016-06-01|  null|
| C1| 2016-07-01|  null|
| C1| 2016-08-01|  null|
| C1| 2016-09-01|  null|
| C1| 2016-10-01|  null|
| C1| 2016-11-01|  null|
| C1| 2016-12-01|  null|
| C1| 2017-01-01|    10|
| C1| 2017-02-01|    20|
| C1| 2017-03-01|    10|
+---+-----------+------+
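The cross join followed by the left outer join can be modelled with plain Scala collections. This sketch adds a hypothetical second customer, `C2`, that is not in the question's data, only to show that every id receives the full global month range (computed from the minimum and maximum over all customers); `None` stands in for Spark's `null`.

```scala
import java.time.LocalDate
import java.time.temporal.ChronoUnit

// Monthly totals per (id, month). C1 matches byMonth above; "C2" is hypothetical.
val byMonth: Map[(String, LocalDate), Long] = Map(
  ("C1", LocalDate.of(2016, 1, 1)) -> 30L,
  ("C1", LocalDate.of(2017, 1, 1)) -> 10L,
  ("C1", LocalDate.of(2017, 2, 1)) -> 20L,
  ("C1", LocalDate.of(2017, 3, 1)) -> 10L,
  ("C2", LocalDate.of(2016, 6, 1)) -> 7L)

// Global reference range of months, like `months` in the answer.
val allMonths = byMonth.keys.map(_._2).toSeq
val start = allMonths.minBy(_.toEpochDay)
val end = allMonths.maxBy(_.toEpochDay)
val months = (0L to ChronoUnit.MONTHS.between(start, end)).map(i => start.plusMonths(i))

// crossJoin: every distinct id paired with every month ...
val ids = byMonth.keys.map(_._1).toSeq.distinct.sorted
val ref: Seq[(String, LocalDate)] = for {
  id <- ids
  m <- months
} yield (id, m)

// ... then the left outer join: months without orders become None.
val expanded: Seq[(String, LocalDate, Option[Long])] =
  ref.map { case (id, m) => (id, m, byMonth.get((id, m))) }

expanded.foreach(println)
```

The grid has (number of ids) × (number of months) rows, which is one reason this approach gets expensive on large, long-lived customer bases.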

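Finally, sum over a frame of the current month plus the 12 preceding rows (months), drop the filler months again and derive the period bounds:

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"id")
  .orderBy($"order_month")
  .rowsBetween(-12, Window.currentRow)

expanded.withColumn("rolling", sum("orders").over(w))
  .na.drop(Seq("orders"))
  .select(
    $"order_month" - expr("INTERVAL 12 MONTHS") as "period_start",
    $"order_month" as "period_end",
    $"rolling")
  .show()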

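+------------+----------+-------+
|period_start|period_end|rolling|
+------------+----------+-------+
|  2015-01-01|2016-01-01|     30|
|  2016-01-01|2017-01-01|     40|
|  2016-02-01|2017-02-01|     30|
|  2016-03-01|2017-03-01|     40|
+------------+----------+-------+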
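Why the densification matters: a ROWS frame counts rows, not months. The plain-Scala model below (hypothetical helper `rowsRolling`, no Spark) applies a "current row plus 12 preceding rows" frame to the dense C1 series and to the sparse one. Like SQL `SUM`, it skips nulls and returns `None` for an all-null frame.

```scala
import java.time.LocalDate

// C1 after the left outer join: one slot per month from 2016-01 to 2017-03,
// None where the month had no orders.
val firstMonth = LocalDate.of(2016, 1, 1)
val totals = Map(0 -> 30L, 12 -> 10L, 13 -> 20L, 14 -> 10L) // month index -> orders
val dense: Seq[(LocalDate, Option[Long])] =
  (0 to 14).map(i => (firstMonth.plusMonths(i), totals.get(i)))

// Model of sum(...).over(... rowsBetween(-12, Window.currentRow)):
// the frame is the current row plus at most 12 rows before it.
def rowsRolling[A](rows: Seq[(A, Option[Long])]): Seq[Option[Long]] =
  rows.indices.map { i =>
    val present = rows.slice(math.max(0, i - 12), i + 1).flatMap(_._2)
    if (present.isEmpty) None else Some(present.sum)
  }

// Dense input, then keep only months that had orders (the .na.drop step).
val result = dense.zip(rowsRolling(dense)).collect {
  case ((end, Some(_)), Some(total)) => (end.minusMonths(12), end, total)
}

// Same frame on the sparse rows: 12 rows back is no longer 12 months back.
val sparse = dense.filter(_._2.isDefined)
val wrong = rowsRolling(sparse)

result.foreach(println)
println(wrong)
```

On the dense series the totals are 30, 40, 30, 40, matching the table above; on the sparse series the same frame gives 30, 40, 60, 70, because 2016-01 is still within 12 rows of every later month.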

Note that this is a very expensive operation, requiring at least two shuffles:
== Physical Plan ==
*Project [cast(cast(order_month#104 as timestamp) - interval 1 years as date) AS period_start#1387, order_month#104 AS period_end#1388, rolling#1375L]
+- *Filter AtLeastNNulls(n, orders#55L)
   +- Window [sum(orders#55L) windowspecdefinition(id#7, order_month#104 ASC NULLS FIRST, ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS rolling#1375L], [id#7], [order_month#104 ASC NULLS FIRST]
      +- *Sort [id#7 ASC NULLS FIRST, order_month#104 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#7, 200)
            +- *Project [id#7, order_month#104, orders#55L]
               +- *BroadcastHashJoin [id#7, order_month#104], [id#181, order_month#49], LeftOuter, BuildRight
                  :- BroadcastNestedLoopJoin BuildRight, Cross
                  :  :- *HashAggregate(keys=[id#7], functions=[])
                  :  :  +- Exchange hashpartitioning(id#7, 200)
                  :  :     +- *HashAggregate(keys=[id#7], functions=[])
                  :  :        +- *HashAggregate(keys=[id#7, trunc(order_date#14, month)#1394], functions=[])
                  :  :           +- Exchange hashpartitioning(id#7, trunc(order_date#14, month)#1394, 200)
                  :  :              +- *HashAggregate(keys=[id#7, trunc(order_date#14, month) AS trunc(order_date#14, month)#1394], functions=[])
                  :  :                 +- LocalTableScan [id#7, order_date#14]
                  :  +- BroadcastExchange IdentityBroadcastMode
                  :     +- LocalTableScan [order_month#104]
                  +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true], input[1, date, true]))
                     +- *HashAggregate(keys=[id#181, trunc(order_date#14, month)#1395], functions=[sum(cast(orders#183 as bigint))])
                        +- Exchange hashpartitioning(id#181, trunc(order_date#14, month)#1395, 200)
                           +- *HashAggregate(keys=[id#181, trunc(order_date#14, month) AS trunc(order_date#14, month)#1395], functions=[partial_sum(cast(orders#183 as bigint))])
                              +- LocalTableScan [id#181, order_date#14, orders#183]

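It is also possible to express this with a rangeBetween frame, but you have to encode the data first:

val encoded = byMonth
  .withColumn("order_month_offset",
    // Any fixed "zero" date works; choose one appropriate for your data
    months_between($"order_month", to_date(lit("1970-01-01"))))

val w = Window.partitionBy($"id")
  .orderBy($"order_month_offset")
  .rangeBetween(-12, Window.currentRow)

encoded.withColumn("rolling", sum($"orders").over(w)).show()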

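+---+-----------+------+------------------+-------+
| id|order_month|orders|order_month_offset|rolling|
+---+-----------+------+------------------+-------+
| C1| 2016-01-01|    30|             552.0|     30|
| C1| 2017-01-01|    10|             564.0|     40|
| C1| 2017-02-01|    20|             565.0|     30|
| C1| 2017-03-01|    10|             566.0|     40|
+---+-----------+------+------------------+-------+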
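A RANGE frame selects rows by the value of the ordering column rather than by row position, which is why the offset encoding removes the need for the cross join. A plain-Scala model, assuming first-of-month dates so that the offset is a whole number of months (Spark's `months_between` returns it as a double, e.g. 552.0):

```scala
import java.time.LocalDate
import java.time.temporal.ChronoUnit

// byMonth for C1: sparse, only months that have orders.
val byMonth: Seq[(LocalDate, Long)] = Seq(
  LocalDate.of(2016, 1, 1) -> 30L,
  LocalDate.of(2017, 1, 1) -> 10L,
  LocalDate.of(2017, 2, 1) -> 20L,
  LocalDate.of(2017, 3, 1) -> 10L)

// Encode each month as whole months since 1970-01-01.
val epoch = LocalDate.of(1970, 1, 1)
val encoded: Seq[(LocalDate, Long, Long)] =
  byMonth.map { case (m, n) => (m, n, ChronoUnit.MONTHS.between(epoch, m)) }

// Model of rangeBetween(-12, Window.currentRow): the frame holds every row whose
// offset lies in [offset - 12, offset], however many rows that happens to be.
val rolling: Seq[(LocalDate, Long, Long)] = encoded.map { case (m, _, off) =>
  (m, off, encoded.collect { case (_, n, o) if o >= off - 12 && o <= off => n }.sum)
}

rolling.foreach(println)
```

The trade-off is that the window function now has to order by a derived numeric column, but the data stays as small as the monthly aggregate.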


The original question, "scala - How to compute the sum of orders over 12 months, sliding by 1 month, per customer in Spark", can be found on Stack Overflow.

