scala - How to write a Spark UDF that takes Array[StructType] and StructType as input and returns Array[StructType]

Tags: scala apache-spark user-defined-functions

I have a DataFrame with the following schema:

root
 |-- user_id: string (nullable = true)
 |-- user_loans_arr: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- loan_date: string (nullable = true)
 |    |    |-- loan_amount: string (nullable = true)
 |-- new_loan: struct (nullable = true)
 |    |-- loan_date: string (nullable = true)
 |    |-- loan_amount: string (nullable = true)

I want a UDF that takes user_loans_arr and new_loan as inputs and appends the new_loan struct to the existing user_loans_arr. Then it should remove from user_loans_arr every element whose loan_date is older than 12 months.

Thanks in advance.

Best Answer

If you are on Spark >= 2.4 you don't need a UDF at all; see the example below.

Load the input data

 val df = spark.sql(
      """
        |select user_id, user_loans_arr, new_loan
        |from values
        | ('u1', array(named_struct('loan_date', '2019-01-01', 'loan_amount', 100)), named_struct('loan_date',
        | '2020-01-01', 'loan_amount', 100)),
        | ('u2', array(named_struct('loan_date', '2020-01-01', 'loan_amount', 200)), named_struct('loan_date',
        | '2020-01-01', 'loan_amount', 100))
        | T(user_id, user_loans_arr, new_loan)
      """.stripMargin)
    df.show(false)
    df.printSchema()

    /**
      * +-------+-------------------+-----------------+
      * |user_id|user_loans_arr     |new_loan         |
      * +-------+-------------------+-----------------+
      * |u1     |[[2019-01-01, 100]]|[2020-01-01, 100]|
      * |u2     |[[2020-01-01, 200]]|[2020-01-01, 100]|
      * +-------+-------------------+-----------------+
      *
      * root
      * |-- user_id: string (nullable = false)
      * |-- user_loans_arr: array (nullable = false)
      * |    |-- element: struct (containsNull = false)
      * |    |    |-- loan_date: string (nullable = false)
      * |    |    |-- loan_amount: integer (nullable = false)
      * |-- new_loan: struct (nullable = false)
      * |    |-- loan_date: string (nullable = false)
      * |    |-- loan_amount: integer (nullable = false)
      */
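
If you prefer building the same sample input with the Scala API instead of SQL, here is a minimal equivalent sketch (the Loan case class and dfScala name are illustrative only, not part of the original answer):

    // Build the same array<struct> / struct columns from Scala collections.
    import spark.implicits._

    case class Loan(loan_date: String, loan_amount: Int)

    val dfScala = Seq(
      ("u1", Seq(Loan("2019-01-01", 100)), Loan("2020-01-01", 100)),
      ("u2", Seq(Loan("2020-01-01", 200)), Loan("2020-01-01", 100))
    ).toDF("user_id", "user_loans_arr", "new_loan")

    dfScala.printSchema() // same schema as the SQL-built DataFrame above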

Process it according to the stated requirement:

user_loans_arr and new_loan as inputs and add the new_loan struct to the existing user_loans_arr. Then, from user_loans_arr delete all the elements whose loan_date is older than 12 months.

spark >= 2.4

    // spark >= 2.4: append new_loan, then keep only loans from the last 12 months
    import org.apache.spark.sql.functions.expr

    df.withColumn("user_loans_arr",
      expr(
        """
          |FILTER(array_union(user_loans_arr, array(new_loan)),
          | x -> months_between(current_date(), to_date(x.loan_date)) < 12)
        """.stripMargin))
      .show(false)

    /**
      * +-------+--------------------------------------+-----------------+
      * |user_id|user_loans_arr                        |new_loan         |
      * +-------+--------------------------------------+-----------------+
      * |u1     |[[2020-01-01, 100]]                   |[2020-01-01, 100]|
      * |u2     |[[2020-01-01, 200], [2020-01-01, 100]]|[2020-01-01, 100]|
      * +-------+--------------------------------------+-----------------+
      */
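
If you are on Spark 3.0 or later, the same logic can also be expressed with the column-based API instead of an expr() string, since filter is available as a column function there; a minimal sketch under that assumption:

    // Spark 3.0+ only: filter() and months_between() as column functions, no SQL string needed
    import org.apache.spark.sql.functions.{array, array_union, current_date, filter, months_between, to_date}

    df.withColumn("user_loans_arr",
      filter(
        array_union($"user_loans_arr", array($"new_loan")),
        x => months_between(current_date(), to_date(x.getField("loan_date"))) < 12))
      .show(false)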

spark < 2.4

    // spark < 2.4: fall back to a UDF whose return type matches user_loans_arr
    import scala.collection.mutable
    import java.time._
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.udf

    // reuse the column's existing array<struct<loan_date, loan_amount>> type as the UDF return type
    val outputSchema = df.schema("user_loans_arr").dataType

    val add_and_filter = udf((userLoansArr: mutable.WrappedArray[Row], loan: Row) => {
      // append the new loan, then keep only loans taken within the last 12 months
      (userLoansArr :+ loan).filter(row => {
        val loanDate = LocalDate.parse(row.getAs[String]("loan_date"))
        val period = Period.between(loanDate, LocalDate.now())
        period.getYears * 12 + period.getMonths < 12
      })
    }, outputSchema)

    df.withColumn("user_loans_arr", add_and_filter($"user_loans_arr", $"new_loan"))
      .show(false)

    /**
      * +-------+--------------------------------------+-----------------+
      * |user_id|user_loans_arr                        |new_loan         |
      * +-------+--------------------------------------+-----------------+
      * |u1     |[[2020-01-01, 100]]                   |[2020-01-01, 100]|
      * |u2     |[[2020-01-01, 200], [2020-01-01, 100]]|[2020-01-01, 100]|
      * +-------+--------------------------------------+-----------------+
      */
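
If the pre-2.4 UDF also needs to be callable from SQL, it can be registered by name first; a small sketch (the loans view name is just for illustration):

    // Register the same UDF so it can be used inside SQL strings as well
    spark.udf.register("add_and_filter", add_and_filter)

    df.createOrReplaceTempView("loans")
    spark.sql(
      """
        |select user_id, add_and_filter(user_loans_arr, new_loan) as user_loans_arr
        |from loans
      """.stripMargin).show(false)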

Regarding "scala - How to write a Spark UDF that takes Array[StructType] and StructType as input and returns Array[StructType]", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/62618949/
