Python asyncio : Queue. join() 仅在未引发异常时完成,为什么? (上下文:编写异步映射函数)

标签 python python-asyncio

我一直在尝试编写 map 的异步版本Python 中用于执行 IO 的函数。
为此,我使用了一个带有生产者/消费者的队列。
起初它似乎运行良好,但无一异常(exception)。
特别是,如果我使用 queue.join() ,它在没有异常时运行良好,但在异常情况下阻塞。
如果我使用 gather(*tasks) ,当有异常时它运行良好,但如果没有则阻塞。
所以它只是有时完成,我只是不明白为什么。
这是我实现的代码:

import asyncio
from asyncio import Queue
from typing import Iterable, Callable, TypeVar

Input = TypeVar("Input")
Output = TypeVar("Output")
STOP = object()


def parallel_map(func: Callable[[Input], Output], iterable: Iterable[Input]) -> Iterable[Output]:
    """
    Parallel version of `map`, backed by asyncio.
    Only suitable to do IO in parallel (not for CPU intensive tasks, otherwise it will block).
    """

    number_of_parallel_calls = 9

    async def worker(input_queue: Queue, output_queue: Queue) -> None:
        while True:
            data = await input_queue.get()
            try:
                output = func(data)
                # Simulate an exception:
                # raise RuntimeError("")
                output_queue.put_nowait(output)
            finally:
                input_queue.task_done()

    async def group_results(output_queue: Queue) -> Iterable[Output]:
        output = []
        while True:
            item = await output_queue.get()
            if item is not STOP:
                output.append(item)
            output_queue.task_done()
            if item is STOP:
                break
        return output

    async def procedure() -> Iterable[Output]:
        # First, produce a queue of inputs
        input_queue: Queue = asyncio.Queue()
        for i in iterable:
            input_queue.put_nowait(i)

        # Then, assign a pool of tasks to consume it (and also produce outputs in a new queue)
        output_queue: Queue = asyncio.Queue()
        tasks = []
        for _ in range(number_of_parallel_calls):
            task = asyncio.create_task(worker(input_queue, output_queue))
            tasks.append(task)

        # Wait for the input queue to be fully consumed (only works if no exception occurs in the tasks), blocks otherwise.
        await input_queue.join()
        # Gather tasks, only works when an exception is raised in a task, blocks otherwise
        # asyncio.gather(*tasks)

        for task in tasks:
            task.cancel()

        # Indicate that the output queue is complete, to stop the worker
        output_queue.put_nowait(STOP)

        # Consume the output_queue, and return its data as a list
        group_results_task = asyncio.create_task(group_results(output_queue))
        await output_queue.join()
        output = await group_results_task
        return output

    return asyncio.run(procedure())

if __name__ == "__main__()":
    def my_function(x):
        return x * x

    data = [1, 2, 3, 4]
    print(parallel_map(my_function, data))
我认为我误解了 Python asyncio 的一个基本但重要的内容,但不确定是什么。

最佳答案

问题是,您没有捕获异常。
来自 Python doc

The count of unfinished tasks goes up whenever an item is added to the queue. The count goes down whenever a consumer coroutine calls task_done() to indicate that the item was retrieved and all work on it is complete. When the count of unfinished tasks drops to zero, join() unblocks.


所以Queue实质上是在计算 put() 上的调用次数,并在每个 task_done() 上将计数器减 1称呼。如果worker在处理所有队列之前停止,你会被阻塞在Queue.join() .
在您的工作人员代码中:
    async def worker(input_queue: Queue, output_queue: Queue) -> None:
        while True:
            data = await input_queue.get()
            try:
                output = func(data)
                output_queue.put_nowait(output)
            finally:
                input_queue.task_done()
您的工作人员在遇到异常时停止。所以task_done()的数量调用次数不等于 input_queue.put() 上的调用次数.这就是它一直挂的原因。您不能在看到异常时取消所有任务 - 这是一个设计缺陷。

为了易于实现、减少失败和性能,有多个设计因素需要改变。
对于设计因素,我进行了以下更改:
  • 如果它正在运行就毫无意义function只是,所以最好支持coroutine也。
  • input_queue保证在 worker 之前填充.正在检查 Queue.empty()足以确定循环的结束。
  • 如果您从填写 input_queue 开始那么不需要哨兵,你知道多久给定iterable是,由 queue.qszie() .
  • 使用 await Queue.put()而不是 put_nowait() , 你不能确定 Queue在您放置的那个精确时间不可用。
  • 使用额外的索引参数保持输入和输出的顺序,您也无法确定并发任务的输出顺序。
  • 而不是在 Exception 中实现故障安全,将错误放入结果并处理所有队列,然后根据用户的选择简单地重新提出它。
  • 在任务生成中使用 genexp,for此任务不需要 - 和 list.append会妨碍您的脚本的性能。
  • 不进口Queue来自 asyncio ,如果来自 queue,它不会警告用户足够多或任何其他内置库 Queue对象。
  • 无需等待 queue.join运行 group_results - await它完全。

  • 功能代码:
    import asyncio
    
    
    def parallel_map(func, iterable, concurrent_limit=2, raise_error=False):
        async def worker(input_queue: asyncio.Queue, output_queue: asyncio.Queue):
    
            while not input_queue.empty():
                idx, item = await input_queue.get()
                try:
                    # Support both coroutine and function. Coroutine function I mean!
                    if asyncio.iscoroutinefunction(func):
                        output = await func(item)
                    else:
                        output = func(item)
                    await output_queue.put((idx, output))
    
                except Exception as err:
                    await output_queue.put((idx, err))
    
                finally:
                    input_queue.task_done()
    
        async def group_results(input_size, output_queue: asyncio.Queue):
            output = {}  # using dict to remove the need to sort list
    
            for _ in range(input_size):
                idx, val = await output_queue.get()  # gets tuple(idx, result)
                output[idx] = val
                output_queue.task_done()
    
            return [output[i] for i in range(input_size)]
    
        async def procedure():
            # populating input queue
            input_queue: asyncio.Queue = asyncio.Queue()
            for idx, item in enumerate(iterable):
                input_queue.put_nowait((idx, item))
            
            # Remember size before using Queue
            input_size = input_queue.qsize()
    
            # Generate task pool, and start collecting data.
            output_queue: asyncio.Queue = asyncio.Queue()
            result_task = asyncio.create_task(group_results(input_size, output_queue))
            tasks = [
                asyncio.create_task(worker(input_queue, output_queue))
                for _ in range(concurrent_limit)
            ]
            
            # Wait for tasks complete
            await asyncio.gather(*tasks)
            
            # Wait for result fetching
            results = await result_task
            
            # Re-raise errors at once if raise_error
            if raise_error and (errors := [err for err in results if isinstance(err, Exception)]):
                # noinspection PyUnboundLocalVariable
                raise Exception(errors)  # It never runs before assignment, safe to ignore.
    
            return results
    
        return asyncio.run(procedure())
    
    测试代码:
    if __name__ == "__main__":
        import random
        import time
    
        data = [1, 2, 3]
        err_data = [1, 'yo', 3]
    
        def test_normal_function(data_, raise_=False):
            def my_function(x):
                t = random.uniform(1, 2)
                print(f"Sleep {t:.3} start")
    
                time.sleep(t)
                print(f"Awake after {t:.3}")
    
                return x * x
    
            print(f"Normal function: {parallel_map(my_function, data_, raise_error=raise_)}\n")
    
        def test_coroutine(data_, raise_=False):
            async def my_coro(x):
                t = random.uniform(1, 2)
                print(f"Coroutine sleep {t:.3} start")
    
                await asyncio.sleep(t)
                print(f"Coroutine awake after {t:.3}")
    
                return x * x
    
            print(f"Coroutine {parallel_map(my_coro, data_, raise_error=raise_)}\n")
    
        # Test starts
        print(f"Test for data {data}:")
        test_normal_function(data)
        test_coroutine(data)
    
        print(f"Test for data {err_data} without raise:")
        test_normal_function(err_data)
        test_coroutine(err_data)
    
        print(f"Test for data {err_data} with raise:")
        test_normal_function(err_data, True)
        test_coroutine(err_data, True)  # this line will not run, but works same.
    
    以上将为 function 测试以下条件和 coroutine :
  • 正常数据
  • 有缺陷的数据,不要引发错误
  • 有缺陷的数据,引发它遇到的所有错误

  • 即使有异常,这也不会取消任务,而是处理所有队列。
    输出:
    Test for data [1, 2, 3]:
    Sleep 1.71 start
    Awake after 1.71
    Sleep 1.74 start
    Awake after 1.74
    Sleep 1.83 start
    Awake after 1.83
    Normal function: [1, 4, 9]
    
    Coroutine sleep 1.32 start
    Coroutine sleep 1.01 start
    Coroutine awake after 1.01
    Coroutine sleep 1.98 start
    Coroutine awake after 1.32
    Coroutine awake after 1.98
    Coroutine [1, 4, 9]
    
    Test for data [1, 'yo', 3] without raise:
    Sleep 1.57 start
    Awake after 1.57
    Sleep 1.98 start
    Awake after 1.98
    Sleep 1.39 start
    Awake after 1.39
    Normal function: [1, TypeError("can't multiply sequence by non-int of type 'str'"), 9]
    
    Coroutine sleep 1.22 start
    Coroutine sleep 2.0 start
    Coroutine awake after 1.22
    Coroutine sleep 1.96 start
    Coroutine awake after 2.0
    Coroutine awake after 1.96
    Coroutine [1, TypeError("can't multiply sequence by non-int of type 'str'"), 9]
    
    Test for data [1, 'yo', 3] with raise:
    Sleep 1.99 start
    Awake after 1.99
    Sleep 1.74 start
    Awake after 1.74
    Sleep 1.52 start
    Awake after 1.52
    Traceback (most recent call last):
    ...
    line 52, in procedure
        raise Exception(errors)
    Exception: [TypeError("can't multiply sequence by non-int of type 'str'")]
    
    请注意,我已经设置了 concurrent_limit 2 演示协程等待可用的worker。这就是为什么 3 个协程任务没有立即运行的原因。
    从输出中您还可以看到一些任务先于其他任务完成,但结果是有序的。

    附言
    如果您要导入 Queue由于类型提示违反了 PEP-8 行限制,您可以添加类型提示,如下所示:
    async def worker(input_queue, output_queue) -> None:
        input_queue: asyncio.Queue
        output_queue: asyncio.Queue
    
    或者
    async def worker(
            input_queue: asyncio.Queue,
            output_queue: asyncio.Queue
        ) -> None:
    
    虽然它不像原始方式那么干净,但这将有助于其他人阅读您的代码。

    关于Python asyncio : Queue. join() 仅在未引发异常时完成,为什么? (上下文:编写异步映射函数),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64138325/

    相关文章:

    python - 使用 exec 进行变量赋值是Pythonic吗?

    Python 套接字服务器接收重复数据

    python - 脚本在执行过程中的某个时刻抛出一些错误

    python - pyinstaller ImportError 错误 - 如何解决?

    python - python 的 googlesearch 库

    python - 在 ApplicationSession 的注册端点内访问 RPC 调用者的 IP 和 HTTP 连接 header

    python - 如何在 Django Channels 中创建服务器端计时器/倒计时?

    python - 在 Python 3.6.10 上运行异步 Flask 2.0.0 时出错

    python - 使用 contextvar 跟踪 Python 中的异步循环

    python - 为什么我的流套接字一次只能从我的浏览器排队一个连接?