scala - 使用 Scala 类作为 UDF 与 pyspark

标签 scala apache-spark pyspark apache-spark-sql user-defined-functions

在使用 Apache Spark 时,我尝试将一些计算从 Python 卸载到 Scala。我想使用 Java 的类接口(interface)来使用持久变量,如下所示(这是一个基于我更复杂的用例的无意义的 MWE):

package mwe

import org.apache.spark.sql.api.java.UDF1

class SomeFun extends UDF1[Int, Int] {
  private var prop: Int = 0

  override def call(input: Int): Int = {
    if (prop == 0) {
      prop = input
    }
    prop + input
  }
}

现在我尝试在 pyspark 中使用此类:

import pyspark
from pyspark.sql import SQLContext
from pyspark import SparkContext

conf = pyspark.SparkConf()
conf.set("spark.jars", "mwe.jar")
sc = SparkContext.getOrCreate(conf)

sqlContext = SQLContext.getOrCreate(sc)
sqlContext.registerJavaFunction("fun", "mwe.SomeFun")

df0 = sc.parallelize((i,) for i in range(6)).toDF(["num"])
df1 = df0.selectExpr("fun(num) + 3 as new_num")
df1.show()

并得到以下异常:

pyspark.sql.utils.AnalysisException: u"cannot resolve '(UDF:fun(num) + 3)' due to data type mismatch: differing types in '(UDF:fun(num) + 3)' (struct<> and int).; line 1 pos 0;\n'Project [(UDF:fun(num#0L) + 3) AS new_num#2]\n+- AnalysisBarrier\n      +- LogicalRDD [num#0L], false\n"

实现这个的正确方法是什么?我是否必须使用 Java 本身来上课?我非常感谢提示!

最佳答案

异常的根源是使用了不兼容的类型:

  • 首先,o.a.s.sql.api.java.UDF* 对象需要外部 Java(不是 Scala 类型),因此需要整数的 UDF 应采用盒装Integer (java.lang.Integer) 不是 Int

    class SomeFun extends UDF1[Integer, Integer] {
      ...
      override def call(input: Integer): Integer = {
        ...
    
  • 除非您使用旧版 Python num 列,否则使用 LongType 而不是 IntegerType:

    df0.printSchema()
    root
     |-- num: long (nullable = true)
    

    所以实际的签名应该是

    class SomeFun extends UDF1[java.lang.Long, java.lang.Long] {
      ...
      override def call(input: java.lang.Long): java.lang.Long = {
        ...
    

    或者应该在应用 UDF 之前转换数据

    df0.selectExpr("fun(cast(num as integer)) + 3 as new_num")
    

最后,UDF 中不允许存在可变状态。它不会导致异常,但整体行为将是不确定的。

关于scala - 使用 Scala 类作为 UDF 与 pyspark,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49632117/

相关文章:

java - 如何获得 "sbt universal:packageBin"的分解版本?

scala - 将现有的 sbt 项目导入 IntelliJ

apache-spark - 带聚合的 Spark 流式传输

scala - 将列名添加到从csv文件读取的数据中而没有列名

java - 使用 Spark/java 的 ST_geomfromtext 函数

apache-spark - KafkaConsumer多线程访问pyspark不安全

python - 测试将值插入 mongodb(pyspark、pymongo)

scala - 函数在 Spark 中返回一个空列表

java - Scala 和 Kotlin 的手动 JAR 编译

apache-spark - sparksession.config() 和 spark.conf.set() 有什么区别