python - 根据spark中的移动和将批号添加到DataFrame

标签 python dataframe apache-spark pyspark

我有一个需要分批处理的数据集(由于 API 限制)。

一个batch的text_lenth列之和不能超过1000。并且批处理中的最大行数不能大于 5

为此,我想将批号添加到单个批处理中,以便稍后根据 batch_numbers 处理数据。

我如何在 pyspark(在 Databricks 中)中实现它。我对所有这一切都很陌生,我什至不知道要在网上寻找什么。

非常感谢您的帮助。

下表说明了我正在努力实现的目标:

原始表格

<表类="s-表"> <头> id text_length <正文> 1 500 2 400 3 200 4 300 5 100 6 100 7 100 8 100 9 100 10 300

结果表

<表类="s-表"> <头> id text_length 批号 <正文> 1 500 1 2 400 1 3 200 2 4 300 2 5 100 2 6 100 2 7 100 2 8 100 3 9 100 3 10 300 3

最佳答案

如果您不是在寻找最优解,而是在寻找一种在 Spark 中解决问题而又不太复杂的方法,我们可以将问题分为两个步骤:

  1. 将数据分成 block ,每 block 5 行,忽略文本长度
  2. 如果一个 block 中的文本长度总和太大,则将该 block 拆分为更多 block

这个解决方案不是最优的,因为它产生了太多的批处理。

第 1 步可以使用 zipWithIndex 实现.创建批处理 ID 时,我们会留出足够的“空间”以便稍后划分批处理。在此步骤结束时,一个 block 中的所有行都被分组到一个列表中作为步骤 2 的输入:

df = ...

r = df.rdd.zipWithIndex().toDF() \
    .select("_1.id", "_1.text_length", "_2") \
    .withColumn("batch", F.expr("cast(_2 / 5 as long)*5")) \
    .withColumn("data", F.struct("id", "text_length", "batch")) \
    .groupBy("batch") \
    .agg(F.collect_list("data").alias("data"))

第 2 部分主要包含 udf检查在一批中是否超过了最大文本长度。如果是这样,则以下元素的批处理 ID 增加 1。由于我们在第 1 部分中跳过了足够多的批处理 ID,因此我们没有遇到任何冲突。

def splitBatchIfNecessary(data):
    text_length = 0
    batch = -1
    for d in data:
        text_length = text_length + d.text_length
        if text_length > 1000:
          if batch == -1:
            text_length = 0
            batch = d.batch + 1
            yield (d.id, d.text_length, d.batch)
          else:
            text_length = d.text_length
            batch = batch + 1
            yield (d.id, d.text_length, batch)          
        else:
          if batch == -1:
            batch = d.batch
          yield (d.id, d.text_length, batch)

schema=r.schema["data"].dataType
split_udf = F.udf(splitBatchIfNecessary, schema)

r = r.withColumn("data",split_udf(F.col("data")) ) \
      .selectExpr("explode(data)") \
      .select("col.*") 

输出:

+---+-----------+-----+                                                         
| id|text_length|batch|
+---+-----------+-----+
|  1|        500|    0|
|  2|        400|    0|
|  3|        200|    1|
|  4|        300|    1|
|  5|        100|    1|
|  6|        100|    5|
|  7|        100|    5|
|  8|        100|    5|
|  9|        100|    5|
| 10|        300|    5|
+---+-----------+-----+

可能的优化是将 zipWithIndex 替换为 zipWithUniqueIds (但会稍微多一些“不完整”的批处理)或使用 vectorized udf .

关于python - 根据spark中的移动和将批号添加到DataFrame,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66507942/

相关文章:

python - 在python中匹配html标签

python - 按列值分组并将其设置为 Pandas 中的索引

python-3.x - 替换数据框列中每个值的天数

python - 如果多次使用 RDD 是否需要缓存?

linux - 将 Windows 上的 Spark 文件存储到 HDFS

apache-spark - 如何使用 Spark DataFrames 进行分层抽样?

python - 从 API 获取数据并解析

python - 使用 python 3 使用 PyQt4 QWebView 查看 map

Python (tkinter) 错误 : "CRC check failed"

pandas 系列或整齐的数据框 : index level values to dataframe columns