python - pyspark 中的 Pandas UDF

标签 python pandas apache-spark pyspark

我正在尝试在 Spark 数据帧上填充一系列观察结果。基本上我有一个日期列表,我应该为每个组创建缺少的日期。
pandas 中有 reindex 函数,pyspark 中没有该函数。
我尝试实现 pandas UDF:

@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def reindex_by_date(df):
    df = df.set_index('dates')
    dates = pd.date_range(df.index.min(), df.index.max())
    return df.reindex(dates, fill_value=0).ffill()

这看起来应该可以满足我的需要,但是它失败并显示此消息 AttributeError:只能将 .dt 访问器与类似日期时间的值一起使用 。我在这里做错了什么?
完整代码如下:

data = spark.createDataFrame(
        [(1, "2020-01-01", 0), 
        (1, "2020-01-03", 42), 
        (2, "2020-01-01", -1), 
        (2, "2020-01-03", -2)],
        ('id', 'dates', 'value'))

data = data.withColumn('dates', col('dates').cast("date"))

schema = StructType([
     StructField('id', IntegerType()),
     StructField('dates', DateType()),
     StructField('value', DoubleType())])

@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def reindex_by_date(df):
     df = df.set_index('dates')
     dates = pd.date_range(df.index.min(), df.index.max())
     return df.reindex(dates, fill_value=0).ffill()

data = data.groupby('id').apply(reindex_by_date)

理想情况下我想要这样的东西:

+---+----------+-----+                                                          
| id|     dates|value|
+---+----------+-----+
|  1|2020-01-01|    0|
|  1|2020-01-02|    0|
|  1|2020-01-03|   42|
|  2|2020-01-01|   -1|
|  2|2020-01-02|    0|
|  2|2020-01-03|   -2|
+---+----------+-----+

最佳答案

情况 1:每个 ID 都有单独的日期范围。

我会尽量减少udf的内容。在这种情况下,我只会计算 udf 中每个 ID 的日期范围。对于其他部分,我将使用 Spark native 函数。

from pyspark.sql import types as T
from pyspark.sql import functions as F

# Get min and max date per ID
date_ranges = data.groupby('id').agg(F.min('dates').alias('date_min'), F.max('dates').alias('date_max'))

# Calculate the date range for each ID
@F.udf(returnType=T.ArrayType(T.DateType()))
def get_date_range(date_min, date_max):
  return [t.date() for t in list(pd.date_range(date_min, date_max))]

# To get one row per potential date, we need to explode the UDF output
date_ranges = date_ranges.withColumn(
  'dates',
  F.explode(get_date_range(F.col('date_min'), F.col('date_max')))
)

date_ranges = date_ranges.drop('date_min', 'date_max')

# Add the value for existing entries and add 0 for others
result = date_ranges.join(
  data,
  ['id', 'dates'],
  'left'
)

result = result.fillna({'value': 0})

情况 2:所有 ID 具有相同的日期范围

我认为这里没有必要使用UDF。您想要的内容可以通过不同的方式存档:首先,您获得所有可能的 ID 和所有必要的日期。其次,你交叉加入它们,这将为你提供所有可能的组合。第三,将原始数据左连接到组合上。第四,将出现的空值替换为0。

# Get all unique ids
ids_df = data.select('id').distinct()

# Get the date series
date_min, date_max = data.agg(F.min('dates'), F.max('dates')).collect()[0]
dates = [[t.date()] for t in list(pd.date_range(date_min, date_max))]
dates_df = spark.createDataFrame(data=dates, schema="dates:date")

# Calculate all combinations
all_comdinations = ids_df.crossJoin(dates_df)

# Add the value column
result = all_comdinations.join(
  data,
  ['id', 'dates'],
  'left'
)

# Replace all null values with 0
result = result.fillna({'value': 0})

请注意此解决方案的以下限制:

  1. 交叉连接的成本可能相当高。可以在 this related question 中找到解决该问题的一种潜在解决方案。 .
  2. collect 语句和 Pandas 的使用导致 Spark 转换不完全并行。

[编辑] 分为两种情况,因为我首先认为所有 ID 都具有相同的日期范围。

关于python - pyspark 中的 Pandas UDF,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64395846/

相关文章:

python - 如何调用 json 字典中列表中的值

pandas.read_csv 在循环内给出 FileNotFound 错误

python - 在 python 中将多个 excel '.xlsx' 文件转换为 '.csv' 文件时,我得到了额外的列?

scala - 无法从Sqoop创建的Spark中的序列文件创建数据框

apache-spark - 来自 Kafka 源的 Spark Streaming 返回检查点或倒带

java - 集成 Spark 和 Spring Boot

python - 如何在Python中生成随机稀疏埃尔米特矩阵?

python - 如何计算 Pandas 每个时期的增加量

python - 解压后重新打包 PyInstaller exe

python - 如何在pandas dataframe Python中找到GPS坐标之间的角度