scala - Rewrite LogicalPlan to push down a UDF out of an Aggregate

Tags: scala, apache-spark, catalyst-optimizer

I defined a UDF named "inc" that adds one to its input value. Here is the code for my UDF:

spark.udf.register("inc", (x: Long) => x + 1)

Here is my test SQL:

val df = spark.sql("select sum(inc(vals)) from data")
df.explain(true)
df.show()

Here is the optimized plan for that SQL:

== Optimized Logical Plan ==
Aggregate [sum(inc(vals#4L)) AS sum(inc(vals))#7L]
+- LocalRelation [vals#4L]

I want to rewrite the plan and pull "inc" out of "sum", the way it is done for Python UDFs. So this is the optimized plan I want:

Aggregate [sum(inc_val#6L) AS sum(inc(vals))#7L]
+- Project [inc(vals#4L) AS inc_val#6L]
   +- LocalRelation [vals#4L]

I found that the source file "ExtractPythonUDFs.scala" provides similar functionality for PythonUDF, but it inserts a new node named "ArrowEvalPython". Here is the logical plan for a Python UDF:

== Optimized Logical Plan ==
Aggregate [sum(pythonUDF0#7L) AS sum(inc(vals))#4L]
+- Project [pythonUDF0#7L]
   +- ArrowEvalPython [inc(vals#0L)], [pythonUDF0#7L], 200
      +- Repartition 10, true
         +- RelationV2[vals#0L] parquet file:/tmp/vals.parquet

All I want to insert is a plain "Project" node; I don't want to define a new node.


Here is the test code from my project:

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

object RewritePlanTest {

  case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {

    // Recursively collect every ScalaUDF that appears in an expression tree.
    def collectUDFs(e: Expression): Seq[Expression] = e match {
      case udf: ScalaUDF => Seq(udf)
      case _ => e.children.flatMap(collectUDFs)
    }

    override def apply(plan: LogicalPlan): LogicalPlan = plan match {
      // Only handle a global aggregate with a single aggregate expression.
      case agg @ Aggregate(g, a, _) if g.isEmpty && a.length == 1 =>
        val udfs = agg.expressions.flatMap(collectUDFs)
        println("================")
        udfs.foreach(println)
        val test = udfs.head.isInstanceOf[NamedExpression]
        println(s"cast ScalaUDF to NamedExpression = ${test}")
        println("================")
        agg
      case _ => plan
    }
  }


  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Rewrite plan test")
      .withExtensions(e => e.injectOptimizerRule(UdfRule))
      .getOrCreate()

    val input = Seq(100L, 200L, 300L)
    import spark.implicits._
    input.toDF("vals").createOrReplaceTempView("data")

    spark.udf.register("inc", (x: Long) => x + 1)

    val df = spark.sql("select sum(inc(vals)) from data")
    df.explain(true)
    df.show()
    spark.stop()
  }
}

I have extracted the ScalaUDF from the Aggregate node.

Since the parameter the Project node requires is a Seq[NamedExpression],

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)

but a ScalaUDF cannot be cast to a NamedExpression,

so I don't know how to build the Project node.
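
Roughly, this is what fails (a sketch reusing `udfs` and `agg` from the apply method above, not code from my project):

// Does not compile: Project wants Seq[NamedExpression], but udfs is Seq[Expression]
val proj = Project(udfs, agg.child)
// And forcing it fails at runtime, since ScalaUDF does not extend NamedExpression:
val named = udfs.head.asInstanceOf[NamedExpression]  // ClassCastException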

Could someone give me some advice?

Thanks.

Best Answer

Well, I finally found a way to solve this.

Although a ScalaUDF cannot be cast to a NamedExpression, an Alias can be used as one.

So I created an Alias from each ScalaUDF and then built the Project.
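
In a nutshell (a minimal sketch; `udf` stands for one collected ScalaUDF):

// Alias wraps any Expression and is itself a NamedExpression;
// the empty second parameter list lets Alias generate a fresh ExprId
val named: NamedExpression = Alias(udf, "udf0")()

Here is the complete code: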

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

import scala.collection.mutable

object RewritePlanTest {

  case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {

    // Recursively collect every ScalaUDF that appears in an expression tree.
    def collectUDFs(e: Expression): Seq[Expression] = e match {
      case udf: ScalaUDF => Seq(udf)
      case _ => e.children.flatMap(collectUDFs)
    }

    override def apply(plan: LogicalPlan): LogicalPlan = plan match {
      // Only handle a global aggregate with a single aggregate expression.
      case agg @ Aggregate(g, a, c) if g.isEmpty && a.length == 1 =>
        val udfs = agg.expressions.flatMap(collectUDFs)
        if (udfs.isEmpty) {
          agg
        } else {
          // Wrap each ScalaUDF in an Alias so it can serve as a NamedExpression
          // in the projectList of a Project node.
          val aliasedUdfs = udfs.zipWithIndex.map { case (udf, i) => Alias(udf, s"udf$i")() }
          val proj = Project(aliasedUdfs, c)
          // Map each original UDF call to the attribute the Project outputs for it.
          // ScalaUDF is a case class, so structurally equal calls match the same key.
          val replacements = mutable.HashMap[Expression, Attribute]()
          replacements ++= udfs.zip(proj.output)
          // Put the Project under the Aggregate, then replace every UDF call
          // inside the aggregate expressions with its projected attribute.
          val newAgg = agg.withNewChildren(Seq(proj)).transformExpressionsUp {
            case udf: ScalaUDF if replacements.contains(udf) => replacements(udf)
          }
          println("====== new agg ======")
          println(newAgg)
          newAgg
        }
      case _ => plan
    }
  }


  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Rewrite plan test")
      .withExtensions(e => e.injectOptimizerRule(UdfRule))
      .getOrCreate()

    val input = Seq(100L, 200L, 300L)
    import spark.implicits._
    input.toDF("vals").createOrReplaceTempView("data")

    spark.udf.register("inc", (x: Long) => x + 1)

    val df = spark.sql("select sum(inc(vals)) from data where vals > 100")
    df.explain(true)
    df.show()

    spark.stop()

  }
}

This code outputs the LogicalPlan I want:

====== new agg ======
Aggregate [sum(udf0#9L) AS sum(inc(vals))#7L]
+- Project [inc(vals#4L) AS udf0#9L]
   +- LocalRelation [vals#4L]
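
One possible variation (a sketch, assuming the standard SparkSessionExtensions mechanism; the class and package names below are hypothetical, not from the original answer): the same rule can be enabled through the spark.sql.extensions configuration instead of withExtensions:

import org.apache.spark.sql.SparkSessionExtensions

// Hypothetical wrapper class, so the rule can be enabled with
// spark-submit --conf spark.sql.extensions=my.pkg.UdfExtensions
class UdfExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(e: SparkSessionExtensions): Unit =
    e.injectOptimizerRule(UdfRule)
}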

This question, "scala - Rewrite LogicalPlan to push down a UDF out of an Aggregate", can be found on Stack Overflow: https://stackoverflow.com/questions/59839910/
