python - BranchPythonOperator 意外跳过后的 Airflow 任务

标签 python python-3.x python-2.x airflow

我的 dag 定义如下。虽然 flag1flag2 都是 y,但它们不知何故被跳过了。

from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
import pandas as pd
from itertools import compress



default_args = {
    'owner': 'alex'
    , 'retries': 2
    , 'retry_delay': timedelta(minutes=15)
    , 'depends_on_past': False
    , 'start_date': datetime(2018, 11, 22)
}

dag = DAG(
    'test_dag'
    , catchup = False
    , default_args = default_args
    , schedule_interval = '@daily'
)

task1 = DummyOperator(
        task_id='task1',
        dag=dag,
    )

task2 = DummyOperator(
        task_id='task2',
        dag=dag,
    )

task3 = DummyOperator(
        task_id='task3',
        dag=dag,
    )


# 1 means yes, 0 means no
flag1 = 'y'
flag2 = 'y'
flag3 = 'y'

tasks_name = ['task1', 'task2', 'task3']
flags = [flag1, flag2, flag3]


def generate_branches(tasks_name, flags):
    res = []
    idx = 1
    root_name = 'switch'
    for sub_task, sub_flag in zip(tasks_name, flags):
        tmp_branch_operator = BranchPythonOperator(
            task_id=root_name+str(idx), # switch1, switch2, ...
            python_callable= lambda: sub_task if sub_flag == 'y' else 'None',
            dag=dag,
        )
        res.append(tmp_branch_operator)
        idx += 1
    return res


def set_dependencies(switches, transfer_operators):
    for sub_switch, sub_transfer_operator in zip(switches, transfer_operators):
        sub_switch.set_downstream(sub_transfer_operator)


transfer_operators = [task1, task2, task3]
gen_branches_op = generate_branches(tasks_name, flags)
set_dependencies(gen_branches_op, transfer_operators)

enter image description here

最佳答案

该问题是由 lambda 的延迟绑定(bind)行为引起的。因为 lambda 在调用时进行评估,所以每次您的 lambda 总是返回列表中的最后一个元素,即 task3

如果查看 switch1 和 switch2 的日志,您会发现它们分别有以下分支 task3 而不是 task1task2

为避免这种情况,您可以通过更改 generate_branches() 中的 python_callable 来强制在定义 lambda 时对其求值:

def generate_branches(tasks_name, flags):
    res = []
    idx = 1
    root_name = 'switch'
    for sub_task, sub_flag in zip(tasks_name, flags):
        tmp_branch_operator = BranchPythonOperator(
            task_id=root_name+str(idx), # switch1, switch2, ...
            python_callable=lambda sub_task=sub_task: sub_task if sub_flag == "y", else "None"
            dag=dag,
        )
        res.append(tmp_branch_operator)
        idx += 1
    return res

关于python - BranchPythonOperator 意外跳过后的 Airflow 任务,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55172736/

相关文章:

python - 检测图像是否像素化的最佳方法是什么?

python - Python 2.7 中的 Google Cloud 客户端库

Excel 输出中的 Python Pandas 自定义时间格式

python - 仅将一列从字符串转换为 int

python - 重复元组的正确约定是什么?

session - 具有 session 处理功能的 Python 2 SSL xmlrpc 服务器

python - 带有单引号字符的 csv.writer

python - 获取matplotlib中的数字列表

python - 如何使用 python 作为服务器端语言?

python-3.x - 在 windows64 中用于 python 3.6 的 nltk