我有一列是 Spark DataFrame 中的结构数组,例如
|-- sTest: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- value: string (nullable = true)
| | |-- embed: array (nullable = true)
| | | |-- element: integer (containsNull = true)
具有可变数量的嵌套数组(“嵌入”),且长度均相等。
对于每一行,我想取这些嵌入的平均值,并将结果作为一列分配给新的数据帧(现有的+新的列)。
我读到过一些人使用explode
,但这并不是我想要的。我最终想对每一行进行聚合,计算平均嵌入 (array(float)
)。
带有 ArrayType(StructType) 列的数据框的最小示例:
val structureData = Seq(
Row(Seq(Row("value1 ", Seq(1, 2, 3)), Row("value1 ", Seq(4, 5, 6)))),
Row(Seq(Row("value2", Seq(4,5,6))), Row("value1 ", Seq(1, 1, 1)))
)
val structureSchema = new StructType()
.add("sTest", ArrayType(new StructType()
.add("value", StringType)
.add("embed", ArrayType(IntegerType))))
期望的输出是
Row(2.5, 3.5, 4.5)
Row(2.5, 3, 3.5)
最佳答案
您的数据本质上看起来像一个矩阵,并且您尝试按列汇总矩阵,因此很自然地考虑使用 org.apache.spark.ml 中的
包。Summarizer
。 stat
输入数据:
case class sTest(value: String, embed: Seq[Int])
val df = Seq(
Tuple1(Seq(
sTest("value1", Seq(1, 2, 3)),
sTest("value2", Seq(4, 5, 6))
)),
Tuple1(Seq(
sTest("value3", Seq(4, 5, 6)),
sTest("value4", Seq(1, 1, 1))
))
) toDF("nested")
计算平均值:
import org.apache.spark.sql.functions._
import org.apache.spark.ml.linalg.{Vectors, Vector}
import org.apache.spark.ml.stat.Summarizer
val array2vecUdf = udf((array: Seq[Int]) => {
Vectors.dense(array.toArray.map(_.toDouble))
})
val vec2arrayUdf = udf((vec: Vector) => {
vec.toArray
})
val stage1 = df
// Create a rowid so we can explode, extract the embed field as a vector and collect.
.withColumn("rowid", monotonically_increasing_id)
.withColumn("exp", explode($"nested"))
.withColumn("embed", $"exp".getItem("embed"))
.withColumn("embed_vec", array2vecUdf($"embed"))
val avg = Summarizer.metrics("mean").summary($"embed_vec")
val stage2 = stage1
.groupBy("rowid")
.agg(avg.alias("avg_vec"))
// Convert back from vector to array.
.select(vec2arrayUdf($"avg_vec.mean").alias("avgs"))
stage2.show(false)
结果:
+---------------+
|avgs |
+---------------+
|[2.5, 3.5, 4.5]|
|[2.5, 3.0, 3.5]|
+---------------+
关于scala - 取 Struct 中 double 嵌套向量的平均值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67391492/