我有以下格式的数据框:
name merged
key1 (internalKey1, value1)
key1 (internalKey2, value2)
...
key2 (internalKey3, value3)
...
我想要做的是按名称
对数据框进行分组,收集列表并限制列表的大小。
这就是我按名称
分组并收集列表的方式:
val res = df.groupBy("name")
.agg(collect_list(col("merged")).as("final"))
结果数据框类似于:
key1 [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list
key2 [(internalKey3, value3),...]
我想要做的是限制每个键生成的列表的大小。我尝试了多种方法来做到这一点,但没有成功。我已经看到一些建议第三方解决方案的帖子,但我想避免这种情况。有办法吗?
最佳答案
因此,虽然 UDF 可以满足您的需要,但如果您正在寻找一种性能更高且对内存敏感的方法,则可以编写 UDAF。不幸的是,UDAF API 实际上不如 Spark 附带的聚合函数那么可扩展。但是,您可以使用其内部 API 来构建内部函数来完成您需要的操作。
这是一个 collect_list_limit
的实现,它主要是 Spark 内部 CollectList
AggregateFunction 的复制。我只想扩展它,但它是一个案例类。实际上,所需要的只是重写更新和合并方法以尊重传入的限制:
case class CollectListLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
val limit = limitExp.eval( null ).asInstanceOf[Int]
def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}
override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
if( buffer.size >= limit ) buffer
else if( other.size >= limit ) other
else ( buffer ++= other ).take( limit )
}
override def prettyName: String = "collect_list_limit"
}
要实际注册它,我们可以通过 Spark 的内部 FunctionRegistry
来完成,它接收名称和构建器,该构建器实际上是一个使用以下函数创建 CollectListLimit
的函数:提供的表达式:
val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )
编辑:
事实证明,仅当您尚未创建 SparkContext 时才将其添加到内置中才有效,因为它会在启动时创建不可变的克隆。如果您有现有的上下文,那么这应该可以通过反射添加它:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
关于scala - Spark Collect_list 并限制结果列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52467555/