I've been trying to write an asynchronous version of `map` in Python, for performing IO. To do so I'm using a queue with a producer/consumer. At first it seemed to work fine, but only without exceptions. In particular, if I use queue.join(), it works fine when there is no exception, but blocks when one is raised. If I use gather(*tasks), it works fine when there is an exception, but blocks otherwise. So it only sometimes finishes, and I just don't understand why.
Here is my implementation:
    import asyncio
    from asyncio import Queue
    from typing import Iterable, Callable, TypeVar

    Input = TypeVar("Input")
    Output = TypeVar("Output")

    STOP = object()


    def parallel_map(func: Callable[[Input], Output], iterable: Iterable[Input]) -> Iterable[Output]:
        """
        Parallel version of `map`, backed by asyncio.
        Only suitable to do IO in parallel (not for CPU intensive tasks, otherwise it will block).
        """
        number_of_parallel_calls = 9

        async def worker(input_queue: Queue, output_queue: Queue) -> None:
            while True:
                data = await input_queue.get()
                try:
                    output = func(data)
                    # Simulate an exception:
                    # raise RuntimeError("")
                    output_queue.put_nowait(output)
                finally:
                    input_queue.task_done()

        async def group_results(output_queue: Queue) -> Iterable[Output]:
            output = []
            while True:
                item = await output_queue.get()
                if item is not STOP:
                    output.append(item)
                output_queue.task_done()
                if item is STOP:
                    break
            return output

        async def procedure() -> Iterable[Output]:
            # First, produce a queue of inputs
            input_queue: Queue = asyncio.Queue()
            for i in iterable:
                input_queue.put_nowait(i)

            # Then, assign a pool of tasks to consume it (and also produce outputs in a new queue)
            output_queue: Queue = asyncio.Queue()
            tasks = []
            for _ in range(number_of_parallel_calls):
                task = asyncio.create_task(worker(input_queue, output_queue))
                tasks.append(task)

            # Wait for the input queue to be fully consumed (only works if no exception occurs in the tasks), blocks otherwise.
            await input_queue.join()
            # Gather tasks, only works when an exception is raised in a task, blocks otherwise
            # asyncio.gather(*tasks)

            for task in tasks:
                task.cancel()

            # Indicate that the output queue is complete, to stop the worker
            output_queue.put_nowait(STOP)

            # Consume the output_queue, and return its data as a list
            group_results_task = asyncio.create_task(group_results(output_queue))
            await output_queue.join()
            output = await group_results_task
            return output

        return asyncio.run(procedure())


    if __name__ == "__main__":
        def my_function(x):
            return x * x

        data = [1, 2, 3, 4]
        print(parallel_map(my_function, data))
I think I'm misunderstanding something basic but important about Python asyncio, but I'm not sure what.

Best answer
The problem is that you are not catching the exceptions. From the Python docs:

    The count of unfinished tasks goes up whenever an item is added to the queue. The count goes down whenever a consumer coroutine calls task_done() to indicate that the item was retrieved and all work on it is complete. When the count of unfinished tasks drops to zero, join() unblocks.

So a Queue essentially counts the calls to put() and decrements that counter by 1 on every task_done() call. If a worker stops before the whole queue has been processed, you will stay blocked at Queue.join(). In your worker code:
    async def worker(input_queue: Queue, output_queue: Queue) -> None:
        while True:
            data = await input_queue.get()
            try:
                output = func(data)
                output_queue.put_nowait(output)
            finally:
                input_queue.task_done()
your worker stops as soon as it hits an exception, so the number of task_done() calls no longer equals the number of input_queue.put() calls. That is why it hangs forever. Cancelling all tasks the moment you see an exception doesn't work either - that is a design flaw. Several design factors need to change, for ease of implementation, fewer failure modes, and performance.
For those design factors, I made the following changes:

- You only support a plain function, so it is better to support a coroutine function as well.
- The input_queue is guaranteed to be fully populated before the workers start, so checking Queue.empty() is enough to detect the end of the loop. With a pre-filled input_queue no sentinel is needed, and you know how long the given iterable is via Queue.qsize().
- Use await Queue.put() rather than put_nowait(): you cannot be sure the Queue is available at the precise moment you put.
- Make exceptions fail-safe: put the error into the results, process the whole queue, and then simply re-raise at the user's choice.
- A for loop with list.append is not needed for this task, and would hold back your script's performance.
- Spell out that the Queue comes from asyncio in the type hints; a bare Queue annotation would not tell the reader enough whether it is asyncio's Queue or the one from the queue module or any other built-in library.
- Instead of queue.join(), run group_results as a task and await it completely.

Working code:
    import asyncio


    def parallel_map(func, iterable, concurrent_limit=2, raise_error=False):

        async def worker(input_queue: asyncio.Queue, output_queue: asyncio.Queue):
            while not input_queue.empty():
                idx, item = await input_queue.get()
                try:
                    # Support both plain functions and coroutine functions.
                    if asyncio.iscoroutinefunction(func):
                        output = await func(item)
                    else:
                        output = func(item)
                    await output_queue.put((idx, output))
                except Exception as err:
                    await output_queue.put((idx, err))
                finally:
                    input_queue.task_done()

        async def group_results(input_size, output_queue: asyncio.Queue):
            output = {}  # using dict to remove the need to sort list
            for _ in range(input_size):
                idx, val = await output_queue.get()  # gets tuple(idx, result)
                output[idx] = val
                output_queue.task_done()
            return [output[i] for i in range(input_size)]

        async def procedure():
            # populating input queue
            input_queue: asyncio.Queue = asyncio.Queue()
            for idx, item in enumerate(iterable):
                input_queue.put_nowait((idx, item))

            # Remember size before using Queue
            input_size = input_queue.qsize()

            # Generate task pool, and start collecting data.
            output_queue: asyncio.Queue = asyncio.Queue()
            result_task = asyncio.create_task(group_results(input_size, output_queue))
            tasks = [
                asyncio.create_task(worker(input_queue, output_queue))
                for _ in range(concurrent_limit)
            ]

            # Wait for tasks complete
            await asyncio.gather(*tasks)

            # Wait for result fetching
            results = await result_task

            # Re-raise errors at once if raise_error
            if raise_error and (errors := [err for err in results if isinstance(err, Exception)]):
                # noinspection PyUnboundLocalVariable
                raise Exception(errors)  # It never runs before assignment, safe to ignore.

            return results

        return asyncio.run(procedure())
Test code:

    if __name__ == "__main__":
        import random
        import time

        data = [1, 2, 3]
        err_data = [1, 'yo', 3]

        def test_normal_function(data_, raise_=False):
            def my_function(x):
                t = random.uniform(1, 2)
                print(f"Sleep {t:.3} start")
                time.sleep(t)
                print(f"Awake after {t:.3}")
                return x * x

            print(f"Normal function: {parallel_map(my_function, data_, raise_error=raise_)}\n")

        def test_coroutine(data_, raise_=False):
            async def my_coro(x):
                t = random.uniform(1, 2)
                print(f"Coroutine sleep {t:.3} start")
                await asyncio.sleep(t)
                print(f"Coroutine awake after {t:.3}")
                return x * x

            print(f"Coroutine {parallel_map(my_coro, data_, raise_error=raise_)}\n")

        # Test starts
        print(f"Test for data {data}:")
        test_normal_function(data)
        test_coroutine(data)

        print(f"Test for data {err_data} without raise:")
        test_normal_function(err_data)
        test_coroutine(err_data)

        print(f"Test for data {err_data} with raise:")
        test_normal_function(err_data, True)
        test_coroutine(err_data, True)  # this line will not run, but works the same.
The above tests the following conditions for both a function and a coroutine. Even when exceptions occur, this does not cancel the tasks; it processes the entire queue instead.

Output:
Test for data [1, 2, 3]:
Sleep 1.71 start
Awake after 1.71
Sleep 1.74 start
Awake after 1.74
Sleep 1.83 start
Awake after 1.83
Normal function: [1, 4, 9]
Coroutine sleep 1.32 start
Coroutine sleep 1.01 start
Coroutine awake after 1.01
Coroutine sleep 1.98 start
Coroutine awake after 1.32
Coroutine awake after 1.98
Coroutine [1, 4, 9]
Test for data [1, 'yo', 3] without raise:
Sleep 1.57 start
Awake after 1.57
Sleep 1.98 start
Awake after 1.98
Sleep 1.39 start
Awake after 1.39
Normal function: [1, TypeError("can't multiply sequence by non-int of type 'str'"), 9]
Coroutine sleep 1.22 start
Coroutine sleep 2.0 start
Coroutine awake after 1.22
Coroutine sleep 1.96 start
Coroutine awake after 2.0
Coroutine awake after 1.96
Coroutine [1, TypeError("can't multiply sequence by non-int of type 'str'"), 9]
Test for data [1, 'yo', 3] with raise:
Sleep 1.99 start
Awake after 1.99
Sleep 1.74 start
Awake after 1.74
Sleep 1.52 start
Awake after 1.52
Traceback (most recent call last):
...
line 52, in procedure
raise Exception(errors)
Exception: [TypeError("can't multiply sequence by non-int of type 'str'")]
Note that I set concurrent_limit to 2 to demonstrate coroutines waiting for an available worker; that is why the 3 coroutine tasks did not all start at once. From the output you can also see that some tasks finish before others, yet the results stay in order.
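As a side note, the same two properties - a concurrency cap and results returned in input order - can also be sketched with asyncio.Semaphore instead of a fixed worker pool. This is a hypothetical alternative (bounded_map is my own name, not part of the answer above):

```python
import asyncio

async def bounded_map(func, iterable, limit=2):
    """Run func over iterable with at most `limit` concurrent calls."""
    sem = asyncio.Semaphore(limit)

    async def run_one(idx, item):
        async with sem:  # blocks while `limit` calls are already in flight
            return idx, func(item)

    # gather() preserves argument order; sorting by index is a belt-and-braces
    # step in case the pairs are ever collected differently.
    pairs = await asyncio.gather(
        *(run_one(i, x) for i, x in enumerate(iterable))
    )
    return [val for _, val in sorted(pairs)]

print(asyncio.run(bounded_map(lambda x: x * x, [1, 2, 3])))  # [1, 4, 9]
```

The queue-based version in the answer is more explicit about the producer/consumer flow, while the semaphore version trades that visibility for brevity; both bound concurrency the same way.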
P.S.

If importing Queue just for the type hints makes you violate the PEP-8 line-length limit, you can add the hints like this:

    async def worker(input_queue, output_queue) -> None:
        input_queue: asyncio.Queue
        output_queue: asyncio.Queue

or:

    async def worker(
        input_queue: asyncio.Queue,
        output_queue: asyncio.Queue
    ) -> None:

Although it is not as clean as the original way, it will help others read your code.
About "Python asyncio: Queue.join() finishes only when no exception is raised, why? (context: writing an async map function)" - a similar question was found on Stack Overflow: https://stackoverflow.com/questions/64138325/