我正在尝试定义一种从 DF 中的 WrappedArray 中过滤元素的方法。该过滤器基于外部元素列表。
在寻找解决方案时,我发现了这个question 。它非常相似,但似乎不适合我。我正在使用 Spark 2.4.0。这是我的代码:
val df = sc.parallelize(Array((1, Seq("s", "v", "r")),(2, Seq("r", "a", "v")),(3, Seq("s", "r", "t")))).toDF("foo","bar")
def filterItems(flist: Seq[String]) = udf {
(recs: Seq[String]) => recs match {
case null => Seq.empty[String]
case recs => recs.intersect(flist)
}}
df.withColumn("filtercol", filterItems(Seq("s", "v"))(col("bar"))).show(5)
我的预期结果是:
+---+---------+---------+
|foo| bar|filtercol|
+---+---------+---------+
| 1 |[s, v, r]| [s, v]|
| 2 |[r, a, v]| [v]|
| 3| [s, r, t]| [s]|
+---+---------+---------+
但我收到此错误:
java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
最佳答案
实际上,您可以使用 Spark 2.4 中的内置函数,而不需要太多的努力:
import org.apache.spark.sql.functions.{array_intersect, array, lit}
val df = sc.parallelize(Array((1, Seq("s", "v", "r")),(2, Seq("r", "a", "v")),(3, Seq("s", "r", "t")))).toDF("foo","bar")
val ar = Seq("s", "v").map(lit(_))
df.withColumn("filtercol", array_intersect($"bar", array(ar:_*))).show
输出:
+---+---------+---------+
|foo| bar|filtercol|
+---+---------+---------+
| 1|[s, v, r]| [s, v]|
| 2|[r, a, v]| [v]|
| 3|[s, r, t]| [s]|
+---+---------+---------+
唯一棘手的部分是 Seq("s", "v").map(lit(_))
它将每个字符串映射到 lit(i)
。 intersection
函数接受两个数组。第一个是 bar
列的值。第二个是使用 array(ar:_*)
即时创建的,其中将包含 lit(i)
的值。
关于scala - 基于具有交集的外部数组过滤数据框数组项,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56129321/