scala - 如何定义自定义聚合函数来对向量列求和?

标签 scala apache-spark apache-spark-sql aggregate-functions apache-spark-ml

我有一个包含两列的 DataFrame,ID 类型为 IntVec 类型为 Vector (org.apache.spark.mllib.linalg.Vector)。

DataFrame 如下所示:

ID,Vec
1,[0,0,5]
1,[4,0,1]
1,[1,2,1]
2,[7,5,0]
2,[3,3,4]
3,[0,8,1]
3,[0,0,1]
3,[7,7,7]
....

我想做一个groupBy($"ID"),然后通过对向量求和来对每个组内的行应用聚合。

上述示例的所需输出为:

ID,SumOfVectors
1,[5,2,7]
2,[10,8,4]
3,[7,15,9]
...

可用的聚合函数将不起作用,例如df.groupBy($"ID").agg(sum($"Vec") 将导致 ClassCastException。

如何实现自定义聚合函数,使我能够对向量或数组求和或任何其他自定义操作?

最佳答案

Spark >= 3.0

您可以将Summarizersum结合使用

import org.apache.spark.ml.stat.Summarizer

df
  .groupBy($"id")
  .agg(Summarizer.sum($"vec").alias("vec"))

Spark <= 3.0

就我个人而言,我不会为 UDAF 烦恼。有更多的冗长而且不完全快( Spark UDAF with ArrayType as bufferSchema performance issues )相反,我会简单地使用 reduceByKey/foldByKey:

import org.apache.spark.sql.Row
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.ml.linalg.{Vector, Vectors}

def dv(values: Double*): Vector = Vectors.dense(values.toArray)

val df = spark.createDataFrame(Seq(
    (1, dv(0,0,5)), (1, dv(4,0,1)), (1, dv(1,2,1)),
    (2, dv(7,5,0)), (2, dv(3,3,4)), 
    (3, dv(0,8,1)), (3, dv(0,0,1)), (3, dv(7,7,7)))
  ).toDF("id", "vec")

val aggregated = df
  .rdd
  .map{ case Row(k: Int, v: Vector) => (k, BDV(v.toDense.values)) }
  .foldByKey(BDV.zeros[Double](3))(_ += _)
  .mapValues(v => Vectors.dense(v.toArray))
  .toDF("id", "vec")

aggregated.show

// +---+--------------+
// | id|           vec|
// +---+--------------+
// |  1| [5.0,2.0,7.0]|
// |  2|[10.0,8.0,4.0]|
// |  3|[7.0,15.0,9.0]|
// +---+--------------+

只是为了比较“简单”的 UDAF。所需导入:

import org.apache.spark.sql.expressions.{MutableAggregationBuffer,
  UserDefinedAggregateFunction}
import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes}
import org.apache.spark.sql.types.{StructType, ArrayType, DoubleType}
import org.apache.spark.sql.Row
import scala.collection.mutable.WrappedArray

类定义:

class VectorSum (n: Int) extends UserDefinedAggregateFunction {
    def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
    def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
    def dataType = SQLDataTypes.VectorType
    def deterministic = true 

    def initialize(buffer: MutableAggregationBuffer) = {
      buffer.update(0, Array.fill(n)(0.0))
    }

    def update(buffer: MutableAggregationBuffer, input: Row) = {
      if (!input.isNullAt(0)) {
        val buff = buffer.getAs[WrappedArray[Double]](0) 
        val v = input.getAs[Vector](0).toSparse
        for (i <- v.indices) {
          buff(i) += v(i)
        }
        buffer.update(0, buff)
      }
    }

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      val buff1 = buffer1.getAs[WrappedArray[Double]](0) 
      val buff2 = buffer2.getAs[WrappedArray[Double]](0) 
      for ((x, i) <- buff2.zipWithIndex) {
        buff1(i) += x
      }
      buffer1.update(0, buff1)
    }

    def evaluate(buffer: Row) =  Vectors.dense(
      buffer.getAs[Seq[Double]](0).toArray)
} 

以及示例用法:

df.groupBy($"id").agg(new VectorSum(3)($"vec") alias "vec").show

// +---+--------------+
// | id|           vec|
// +---+--------------+
// |  1| [5.0,2.0,7.0]|
// |  2|[10.0,8.0,4.0]|
// |  3|[7.0,15.0,9.0]|
// +---+--------------+

另请参阅:How to find mean of grouped Vector columns in Spark SQL? .

关于scala - 如何定义自定义聚合函数来对向量列求和?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44561255/

相关文章:

scala 文本中单词之间的绝对最小距离

scala - Spark - Scala - saveAsHadoopFile 抛出错误

apache-spark - 错误 SparkContext 无法在 Apache Spark 2.1.1 中添加文件

apache-spark - Spark Dataframes 中的分区和集群

java - Java 类访问时的 Scala 可见性

mysql - 如何使用 Spark 将数据插入 RDB (MySQL)?

scala - 在 shapeless 2.0 中动态创建可扩展记录

hadoop - Spark 上的 Hive 2.1.1 - 我应该使用哪个版本的 Spark

pandas - 有没有办法强制 spark worker 使用分布式 numpy 版本而不是安装在他们身上的版本?

Scala - 当文件路径不存在时读取 DataFrame