scala - Spark Collect_list 并限制结果列表

标签 scala apache-spark dataframe limit

我有以下格式的数据框:

name          merged
key1    (internalKey1, value1)
key1    (internalKey2, value2)
...
key2    (internalKey3, value3)
...

我想要做的是按名称对数据框进行分组,收集列表并限制列表的大小。

这就是我按名称分组并收集列表的方式:

val res = df.groupBy("name")
            .agg(collect_list(col("merged")).as("final"))

结果数据框类似于:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
 key2   [(internalKey3, value3),...]

我想要做的是限制每个键生成的列表的大小。我尝试了多种方法来做到这一点,但没有成功。我已经看到一些建议第三方解决方案的帖子,但我想避免这种情况。有办法吗?

最佳答案

因此,虽然 UDF 可以满足您的需要,但如果您正在寻找一种性能更高且对内存敏感的方法,则可以编写 UDAF。不幸的是,UDAF API 实际上不如 Spark 附带的聚合函数那么可扩展。但是,您可以使用其内部 API 来构建内部函数来完成您需要的操作。

这是一个 collect_list_limit 的实现,它主要是 Spark 内部 CollectList AggregateFunction 的复制。我只想扩展它,但它是一个案例类。实际上,所需要的只是重写更新和合并方法以尊重传入的限制:

case class CollectListLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
    if( buffer.size >= limit ) buffer
    else if( other.size >= limit ) other
    else ( buffer ++= other ).take( limit )
  }

  override def prettyName: String = "collect_list_limit"
}

要实际注册它,我们可以通过 Spark 的内部 FunctionRegistry 来完成,它接收名称和构建器,该构建器实际上是一个使用以下函数创建 CollectListLimit 的函数:提供的表达式:

val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )

编辑:

事实证明,仅当您尚未创建 SparkContext 时才将其添加到内置中才有效,因为它会在启动时创建不可变的克隆。如果您有现有的上下文,那么这应该可以通过反射添加它:

val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )

关于scala - Spark Collect_list 并限制结果列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52467555/

相关文章:

apache-spark - ShuffledRDD、MapPartitionsRDD 和 ParallelCollectionRDD 之间有什么区别?

python - 值错误 : Length mismatch: Expected axis has 0 elements while creating hierarchical columns in pandas dataframe

Scala : why can't I filter my Int List properly with placeholder ? 例如 : myList. 过滤器(_:Int => _ % 5 == 0)

Scala 向量标量乘法

apache-spark - 如何修复 oozie spark yarn 提交中的 '' java.lang.NoSuchMethodError"?

python - 由于 Spark 的惰性求值导致结果不一致

python - 值错误: arrays must all be same length

r - 基于数据帧 R 中的组的条件聚合

scala - 在 Scala 中应该只接受 A 或 B 或 C 实例的方法

scala - Scala 中使用关联运算符的并行聚合