python - 如何检查 Pyspark Dataframe 中列表是否存在交集

标签 python pandas apache-spark pyspark apache-spark-sql

我有一个 pyspark 数据帧,如下:

import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.functions import udf

schema = T.StructType([  # schema
    T.StructField("id", T.StringType(), True),
    T.StructField("code", T.ArrayType(T.StringType()), True)])
df = spark.createDataFrame([{"id": "1", "code": ["a1", "a2","a3","a4"]},
                            {"id": "2", "code": ["b1","b2"]},
                            {"id": "3", "code": ["c1","c2","c3"]},
                            {"id": "4", "code": ["d1", "b3"]}],
                           schema=schema)

给出输出

df.show()


| id|            code|
|---|----------------|
|  1|[a1, a2, a3, a4]|
|  2|        [b1, b2]|
|  3|    [c1, c2, c3]|
|  4|        [d1, b3]|

我希望能够通过向函数提供列和列表来过滤行,如果有任何交集,则返回 true(使用与 here 不相交,因为会有很多非命中)

def lst_intersect(data_lst,query_lst):
    return not set(data_lst).isdisjoint(query_lst) 
lst_intersect_udf = F.udf(lambda x,y: lst_intersect(x,y), T.BooleanType())

当我尝试应用这个时

query_lst = ['a1','b3']
df = df.withColumn("code_found", lst_intersect_udf(F.col('code'),F.lit(query_lst)))

出现以下错误

Unsupported literal type class java.util.ArrayList [a1, b3]

我可以通过更改函数等来解决它 - 但想知道我在 F.lit(query_lst) 上做错了什么基本的事情吗?

最佳答案

lit 仅接受单个值,而不接受 Python 列表。例如,您需要使用列表理解传递包含列表中文字值的数组列。

df2 = df.withColumn(
    "code_found", 
    lst_intersect_udf(
        F.col('code'),
        F.array(*[F.lit(i) for i in query_lst])
    )
)

df2.show()
+---+----------------+----------+
| id|            code|code_found|
+---+----------------+----------+
|  1|[a1, a2, a3, a4]|      true|
|  2|        [b1, b2]|     false|
|  3|    [c1, c2, c3]|     false|
|  4|        [d1, b3]|      true|
+---+----------------+----------+

也就是说,如果您的 Spark >= 2.4,您还可以使用 Spark SQL 函数 arrays_overlap提供更好的性能:

df2 = df.withColumn(
    "code_found", 
    F.arrays_overlap(
        F.col('code'),
        F.array(*[F.lit(i) for i in query_lst])
    )
)

关于python - 如何检查 Pyspark Dataframe 中列表是否存在交集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66852467/

相关文章:

python - 将 Django 连接到 Google App Engine 中的 Google CloudSQL Postgres 数据库时出错

python - 如何创建字典来查找掉落的零?

python - Pandas - 比较正值/负值

python - Pandas:查找特定列不为 NA 但所有其他列为 NA 的行

python - 如何从 Jupyter 在 HDInsight Spark 集群上提交 python wordcount

python - 在 Odoo/OpenERP 中获取当前没有 ids 或 uid 的公司

python - 警告 : pip is configured with locations that require TLS/SSL, 但是 Python 中的 ssl 模块不可用

python - OpenCV:如何只绘制视频文件的第一帧,然后继续显示整个视频

python - pyspark 中的每月聚合

python - 更改 Spark Web UI 的根路径?