scala - Adding a nested column to a Spark DataFrame

Tags: scala apache-spark apache-spark-sql

How can I add or replace a field in a struct at any nesting level?

This input:

val rdd = sc.parallelize(Seq(
  """{"a": {"xX": 1,"XX": 2},"b": {"z": 0}}""",
  """{"a": {"xX": 3},"b": {"z": 0}}""",
  """{"a": {"XX": 3},"b": {"z": 0}}""",
  """{"a": {"xx": 4},"b": {"z": 0}}"""))
var df = sqlContext.read.json(rdd)

produces the following schema:

root
 |-- a: struct (nullable = true)
 |    |-- XX: long (nullable = true)
 |    |-- xX: long (nullable = true)
 |    |-- xx: long (nullable = true)
 |-- b: struct (nullable = true)
 |    |-- z: long (nullable = true)

Then I can do this:

import org.apache.spark.sql.functions._
val overlappingNames = Seq(col("a.xx"), col("a.xX"), col("a.XX"))
df = df
  .withColumn("a_xx",
    coalesce(overlappingNames:_*))
  .dropNestedColumn("a.xX")
  .dropNestedColumn("a.XX")
  .dropNestedColumn("a.xx")

(dropNestedColumn is borrowed from this answer: https://stackoverflow.com/a/39943812/1068385. I am essentially looking for its inverse.)

The schema becomes:

root
 |-- a: struct (nullable = false)
 |-- b: struct (nullable = true)
 |    |-- z: long (nullable = true)
 |-- a_xx: long (nullable = true)

Obviously this does not replace (or add) a.xx; it adds a new field a_xx at the root level.
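A plain withColumn with the dotted name itself does not help either; as a quick illustration (mine, not from the original post), Spark treats the dotted string as a literal top-level column name rather than a path into the struct:

// Dotted names are not interpreted as struct paths by withColumn, so this adds
// yet another top-level column, literally named "a.xx", instead of a field of "a".
val tried = df.withColumn("a.xx", coalesce(overlappingNames: _*))
tried.select(col("`a.xx`"))   // the dotted name must be backtick-quoted to reference it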

What I would like to be able to do is this:

val overlappingNames = Seq(col("a.xx"), col("a.xX"), col("a.XX"))
df = df
  .withNestedColumn("a.xx",
    coalesce(overlappingNames:_*))
  .dropNestedColumn("a.xX")
  .dropNestedColumn("a.XX")

so that it results in this schema:

root
 |-- a: struct (nullable = false)
 |    |-- xx: long (nullable = true)
 |-- b: struct (nullable = true)
 |    |-- z: long (nullable = true)

How can I achieve this?

The actual goal here is to treat the column names of the input JSON case-insensitively. The last step is then easy: collect all overlapping column names and apply a coalesce to each group, as sketched below.
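A minimal sketch of that last step (mine; it assumes the desired withNestedColumn from above and the dropNestedColumn from the linked answer are both available) would group the fields of a case-insensitively and coalesce each group:

import org.apache.spark.sql.functions.{coalesce, col}
import org.apache.spark.sql.types.StructType

// Group the fields of struct "a" by lower-cased name, merge each group into a
// single lower-cased field, then drop the other spellings.
val aFields = df.schema("a").dataType.asInstanceOf[StructType].fieldNames
aFields.groupBy(_.toLowerCase).foreach { case (lower, variants) =>
  df = df.withNestedColumn(s"a.$lower", coalesce(variants.map(v => col(s"a.$v")): _*))
  variants.filterNot(_ == lower).foreach(v => df = df.dropNestedColumn(s"a.$v"))
}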

Best Answer

It may not be as elegant or efficient as it could be, but here is what I came up with:

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}

object DataFrameUtils {
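  // Keep the rebuilt value null whenever the parent struct itself is null.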
  private def nullableCol(parentCol: Column, c: Column): Column = {
    when(parentCol.isNotNull, c)
  }

  private def nullableCol(c: Column): Column = {
    nullableCol(c, c)
  }

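  // Wrap newCol in one struct() per remaining path segment, innermost segment first.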
  private def createNestedStructs(splitted: Seq[String], newCol: Column): Column = {
    splitted
      .foldRight(newCol) {
        case (colName, nestedStruct) => nullableCol(struct(nestedStruct as colName))
      }
  }

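  // Rebuild the struct at this level: recurse into the field matching the next path
  // segment, or append the missing segment as a brand-new nested struct.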
  private def recursiveAddNestedColumn(splitted: Seq[String], col: Column, colType: DataType, nullable: Boolean, newCol: Column): Column = {
    colType match {
      case colType: StructType if splitted.nonEmpty => {
        var modifiedFields: Seq[(String, Column)] = colType.fields
          .map(f => {
            var curCol = col.getField(f.name)
            if (f.name == splitted.head) {
              curCol = recursiveAddNestedColumn(splitted.tail, curCol, f.dataType, f.nullable, newCol)
            }
            (f.name, curCol as f.name)
          })

        if (!modifiedFields.exists(_._1 == splitted.head)) {
          modifiedFields :+= (splitted.head, nullableCol(col, createNestedStructs(splitted.tail, newCol)) as splitted.head)
        }

        var modifiedStruct: Column = struct(modifiedFields.map(_._2): _*)
        if (nullable) {
          modifiedStruct = nullableCol(col, modifiedStruct)
        }
        modifiedStruct
      }
      case _  => createNestedStructs(splitted, newCol)
    }
  }

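  // Split the dotted name and rebuild (or create) the top-level struct it starts with.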
  private def addNestedColumn(df: DataFrame, newColName: String, newCol: Column): DataFrame = {
    if (newColName.contains('.')) {
      val splitted = newColName.split('.')

      val modifiedOrAdded: (String, Column) = df.schema.fields
        .find(_.name == splitted.head)
        .map(f => (f.name, recursiveAddNestedColumn(splitted.tail, col(f.name), f.dataType, f.nullable, newCol)))
        .getOrElse {
          (splitted.head, createNestedStructs(splitted.tail, newCol) as splitted.head)
        }

      df.withColumn(modifiedOrAdded._1, modifiedOrAdded._2)

    } else {
      // Top level addition, use spark method as-is
      df.withColumn(newColName, newCol)
    }
  }

  implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
    /**
      * Add nested field to DataFrame
      *
      * @param newColName Dot-separated nested field name
      * @param newCol New column value
      */
    def withNestedColumn(newColName: String, newCol: Column): DataFrame = {
      DataFrameUtils.addNestedColumn(df, newColName, newCol)
    }
  }
}

Feel free to improve it.

import spark.implicits._
import DataFrameUtils._

val data = spark.sparkContext.parallelize(List("""{ "a1": 1, "a3": { "b1": 3, "b2": { "c1": 5, "c2": 6 } } }"""))
val df: DataFrame = spark.read.json(data)

val df2 = df.withNestedColumn("a3.b2.c3.d1", $"a3.b2")

It should produce:

assertResult("struct<a1:bigint,a3:struct<b1:bigint,b2:struct<c1:bigint,c2:bigint,c3:struct<d1:struct<c1:bigint,c2:bigint>>>>>")(df2.schema.simpleString)

For more on scala - adding a nested column to a Spark DataFrame, see the similar question on Stack Overflow: https://stackoverflow.com/questions/41731000/
