《asyncio 系列》6. 在 asyncio 中引入多进程

发布时间 2023-06-02 11:40:44作者: tomato-haha

 


楔子

到目前为止我们使用 asyncio 获得的性能提升,一直专注在并发运行 IO 密集型工作上面,当然运行 IO 密集型工作是 asyncio 的主要工作,并且按照目前编写代码的方式,需要注意不要在协程中运行任何的 CPU 密集型代码。但这似乎严重限制了 asyncio 的使用,因为这个库能做的事情不仅仅限于处理 IO 密集型工作。

asyncio 有一个用于和 Python 的 multiprocessing 库进行互操作的 API,让我们可以使用 async await 语法以及具有多个进程的 asyncio API。通过这个 API,即使使用 CPU 密集型代码,我们也可以获得 asyncio 库带来的优势。这使我们能够为 CPU 密集型工作(如数学计算或数据处理)实现性能提升,避开全局解释器锁定,并充分利用多核机器资源。

在本篇文章中,我们将首先了解 multiprocessing 模块,以熟悉执行多个进程的概念。然后将了解进程池执行器以及如何将它们挂接到 asyncio,此后我们将利用这些知识来解决 MapReduce 的 CPU 密集型问题。我们还将学习管理多个进程之间的共享状态并且了解锁定的概念,以避免并发错误。最后,我们将看看如何使用 multiprocessing 来提高 IO 和 CPU 密集型应用程序的性能。

multiprocessing 库

我们知道 CPython 有一个全局解释器锁,全局解释器锁可防止多个 Python 字节码并行运行。这意味着对于 IO 密集型任务以外的其他任务,除了一些小的异常,使用多线程不会像在 Java 和 C++ 等语言中那样提供任何性能优势。对于 Python 中的并行 CPU 密集型工作,似乎我们可能无法提升性能,但可通过 multiprocessing 库来实现这一点。

不是通过父进程生成线程来并行处理,而是生成子进程来处理工作。每个子进程都有自己的 Python 解释器,且遵循 GIL,所以会有多个解释器,每个解释器都有自己的 GIL。假设运行在具有多个 CPU 核的机器上,这意味着可以有效地并行处理任何 CPU 密集型的工作负载。即使进程比内核数要多,操作系统也会使用抢占式多任务,来允许多个任务同时运行。这种设置既是并发的,也是并行的。

import time
from multiprocessing import Process

def count(to: int):
    start = time.perf_counter()
    counter = 0
    while counter < to:
        counter += 1
    end = time.perf_counter()

    print(f"在 {end - start} 秒内将 counter 增加到 {to}")

if __name__ == '__main__':
    start = time.perf_counter()
    task1 = Process(target=count, args=(100000000,))
    task2 = Process(target=count, args=(100000000,))
    # 启动进程
    task1.start()
    task2.start()
    # 该方法会一直阻塞主进程,直到子进程执行完成,并且 join 方法内部也可以接收一个超时时间
    # 如果子进程在规定时间内没有完成,那么主进程不再等待
    task1.join()
    task2.join()
    end = time.perf_counter()
    print(f"在 {end - start} 秒内完成")
"""
在 4.685973625 秒内将 counter 增加到 100000000
在 4.687154792 秒内将 counter 增加到 100000000
在 4.962888833 秒内完成
"""

我们看到总耗时却是 4.9 秒,如果你把子进程换成子线程,那么耗时就不一样了,我们来测试一下。

import time
from threading import Thread

def count(to: int):
    start = time.perf_counter()
    counter = 0
    while counter < to:
        counter += 1
    end = time.perf_counter()

    print(f"在 {end - start} 秒内将 counter 增加到 {to}")

if __name__ == '__main__':
    start = time.perf_counter()
    # 多线程和多进程相关的 API 是一致的,只需要将 Process 换成 Thread 即可
    task1 = Thread(target=count, args=(100000000,))
    task2 = Thread(target=count, args=(100000000,))
    task1.start()
    task2.start()
    task1.join()
    task2.join()
    end = time.perf_counter()
    print(f"在 {end - start} 秒内完成")
"""
在 8.974233167000001 秒内将 counter 增加到 100000000
在 8.985212875 秒内将 counter 增加到 100000000
在 8.991764792 秒内完成
"""

因为线程存在切换,所以不会运行完一个任务之后再运行下一个,而是并发运行的。但不管怎么并发,同一时刻只会有一个任务在运行。所以对于 CPU 密集型任务来说,上面的耗时是两个任务加起来的的时间,因为同一时刻只会用到一个 CPU 核心。而采用多进程,那么两个任务就是并发运行的了。

虽然启动多进程给我们带来了不错的性能提升,但是有一点却很尴尬,因为我们必须为启动的每个进程调用 start 和 join。并且我们也不知道哪个过程会先完成,如果想完成 asyncio.as_completed 之类的工作,并在结果完成时处理它们,那么上面的解决方案就无法满足要求了。此外 join 方法不会返回目标函数返回的值,事实上,目前在不使用共享进程间内存的情况下是无法获取函数的返回值的。

因此这个 API 适用于简单的情况,但如果我们有想要获取函数的返回值,或想要在结果生成时立即处理结果,它显然不起作用。幸运的是,进程池提供了一种解决方法。

在使用多进程或后续的进程池时,必须要加上 if __name__ == '__main__',否则会报错:An attempt has been made to start a new process before the current process has finished its bootstrapping phase。这样做的原因是为了防止其他人导入代码时不小心启动多个进程。

使用进程池

上面我们手动创建了进程,并调用 start 和 join 方法来运行并等待它们。但这种方法存在几个问题,包括代码质量以及无法访问返回的结果,于是 multiprocessing 模块有一个 API 可以让我们解决这个问题,称为进程池。

from multiprocessing import Pool

def say_hello(name) -> str:
    return f"hello, {name}"

if __name__ == '__main__':
    with Pool() as pool:
        hi1 = pool.apply(say_hello, args=("satori",))
        hi2 = pool.apply(say_hello, args=("koishi",))
        print(hi1)
        print(hi2)
"""
hello, satori
hello, koishi
"""

我们使用 with Pool() as pool 创建了一个进程池,这是一个上下文管理器,因为一旦使用了进程池,那么就需要适当地关闭创建的 Python 进程。如果不这样做,就存在进程泄露的风险,这可能导致资源利用的问题。当实例化这个池时,它会自动创建与你使用的机器上的 CPU 内核数量相等的 Python 进程。

可通过运行 multiprocessing.cpu_count() 函数来确定当前机器拥有的 CPU 核心数,并且在调用 Pool() 时也可以通过指定 processes 参数设置需要使用的核心数。一般情况下,使用默认值即可。

接下来使用进程池的 apply 方法在一个单独的进程中运行 say_hello 函数,这个方法看起来类似于我们之前对 Process 类所做的,我们传递了一个目标函数和一个参数元组。但区别是不需要自己启动进程或调用 join,并且还得到了函数的返回值,这在前面的例子中是无法完成的。

上面代码可以成功执行,但有一个问题,apply 方法会一直阻塞,直到函数执行完成。这意味着,如果每次调用 say_hello 需要 10 秒,那么整个程序的运行时间将是大约 20秒,因为我们是串行运行的,无法并行运行。因此可以将 apply 换成 apply_async 来解决这个问题,一旦调用的是 apply_async 方法,那么返回的就不再是目标函数的返回值了,而是一个 AsyncResult 对象,进程会在后台运行。

如果想要返回值,那么可以调用 AsyncResult 的 get 方法,该方法会阻塞并获取目标函数的返回值。

from multiprocessing import Pool

def say_hello(name) -> str:
    return f"hello, {name}"

if __name__ == '__main__':
    with Pool() as pool:
        hi1_async = pool.apply_async(say_hello, args=("satori",))
        hi2_async = pool.apply_async(say_hello, args=("koishi",))
        # 可以接收一个超时时间,如果在规定时间内没有完成
        # 那么抛出 multiprocessing.context.TimeoutError,默认会一直阻塞
        print(hi1_async.get())
        print(hi2_async.get())
"""
hello, satori
hello, koishi
"""

调用 apply_async 时,对 say_hello 的两个调用会立即在不同的进程中开始执行。然后调用 get 方法时,父进程会阻塞,直到每个进程都返回一个值。但这里还隐藏了一个问题,如果 hi1_async 需要 10 秒,hi2_async 需要 1 秒,会发生什么呢?因为我们首先调用 hi1_async 的 get 方法,所以在第二个 print 在打印之前需要先阻塞 10 秒,即使 hi2_async 只需要 1 秒就完成了。

如果想在事情完成后立即作出回应,就会遇到问题。这种情况下,我们真正想要的是类似于 asyncio.as_completed 返回的对象。接下来,看看如何将进程池执行器与 asyncio 一起使用,以便我们解决这个问题。但在此之前,我们需要先了解一个模块。

concurrent.futures 的相关用法

本来这模块不应该放在这里介绍的,但如果它不说,我们后面的内容就不方便展开,所以我们就来先聊聊这个模块。

concurrent.futures 模块提供了使用线程池或进程池运行任务的接口,线程池和进程池的 API 是一致的,所以应用只需要做最小的修改就可以在线程和进程之间进行切换。这个模块提供了两种类型的类与这些池交互:执行器(executor)用来管理工作线程或进程池,future 用来管理计算的结果。要使用一个工作线程或进程池,应用要创建适当的执行器类的一个实例,然后向它提交任务来运行。

该模块和 asyncio 里面的一些概念非常相似,或者说 asyncio 在设计的时候借鉴了 concurrent.futures 的很多理念。

Future 对象

当我们将一个函数提交到线程池里面运行时,会立即返回一个对象,这个对象就叫做 Future 对象,里面包含了函数的执行状态等等,当然我们也可以手动创建一个 Future 对象。

Future 对象和 asyncio 里面的 Future 在概念上是类似的。

from concurrent.futures import Future

# 创建一个Future对象
future = Future()

def callback(future):
    print("当set_result的时候,执行回调,我也可以拿到返回值:", future.result())


# 通过调用add_done_callback方法,可以将该future绑定一个回调函数
# 这里只需要传入函数名即可,future会自动传递给callback的第一个参数
# 如果这里需要多个参数的话,怎么办呢?很简单,使用偏函数即可
future.add_done_callback(callback)

# 当什么时候会触发回调函数的执行呢?
# 当future执行set_result的时候
future.set_result("return value")
"""
当set_result的时候,执行回调,我也可以拿到返回值: return value
"""

我们说过,将函数提交到线程池里面运行的时候,会立即返回,从而得到一个 Future 对象。这个 Future 对象里面就包含了函数的执行状态,比如此时是处于暂停、运行中还是完成等等,并且在函数执行完毕之后,还可以拿到返回值。

from concurrent.futures import ThreadPoolExecutor
import time

def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"

# 创建一个线程池,里面还可以指定 max_workers 参数,表示最多创建多少个线程
# 如果不指定,那么每一个任务都会为其创建一个线程
executor = ThreadPoolExecutor()

# 通过 submit 就直接将任务提交到线程池里面了,一旦提交,就会立刻运行
# 提交之后,相当于开启了一个新的线程,主线程会继续往下走
# 参数按照函数名,对应参数提交即可,切记不可写成 task("古明地觉", 16, 3),这样就变成调用了
future = executor.submit(task, "古明地觉", 16, 3)

# 由于 n=3,所以会休眠 3 秒,此时任务处于 running 状态
print(future)  # <Future at 0x226b860 state=running>

# 让主程序也休眠 3s
time.sleep(3)

# 此时再打印
print(future)  # <Future at 0x226b860 state=finished returned str>

"""
可以看到,一开始任务处于running,正在运行状态
3s 过后,任务处于 finished,完成状态,并告诉我们返回了一个 str
"""

然后获取任务的返回值:

from concurrent.futures import ThreadPoolExecutor
import time


def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

future = executor.submit(task, "古明地觉", 16, 3)

start_time = time.perf_counter()
print(future.result())  # name is 古明地觉, age is 16, sleep 3s
print(f"耗时:{time.perf_counter() - start_time}")  # 耗时:2.999359371

可以看到,打印 future.result() 这一步花了将近 3s。其实也不难理解,future.result() 是干嘛的,就是为了获取函数的返回值,可函数都还没有执行完毕,它又从哪里获取呢?所以只能先等待函数执行完毕,将返回值通过 set_result 自动地设置到 future 里面之后,外界 future.result() 才能够获取到值。所以 future.result() 这一步实际上是会阻塞的,会等待任务执行完毕。

当然也可以绑定一个回调:

from concurrent.futures import ThreadPoolExecutor
import time

def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"

def callback(future):
    print(future.result())

executor = ThreadPoolExecutor()
future = executor.submit(task, "古明地觉", 16, 3)
time.sleep(5)
future.add_done_callback(callback)
"""
name is 古明地觉, age is 16, sleep 3s
"""

等到函数执行完毕之后,依旧会获取到返回值,但这里我加上 time.sleep(5),只是为了证明即使等函数完成之后再去添加回调,依旧是可以的。函数完成之前添加回调,那么会在函数执行完毕后触发回调;函数完成之后添加回调,由于函数已经执行完成,代表此时的 future 已经有值了,或者说已经 set_result 了,那么会立即触发回调,因此 time.sleep(5) 完全可以去掉。

提交多个函数

提交函数的话,可以提交任意多个,我们来看一下:

from concurrent.futures import ThreadPoolExecutor
import time

def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

futures = [executor.submit(task, "古明地觉", 16, 3),
           executor.submit(task, "古明地觉", 16, 4),
           executor.submit(task, "古明地觉", 16, 1),
           ]

# 此时都处于 running
print(futures)  
"""
[<Future at 0x226b860 state=running>, 
 <Future at 0x9f4b160 state=running>, 
 <Future at 0x9f510f0 state=running>]
"""
time.sleep(3.5)

# 主程序 sleep 3.5s 后,futures[0] 和 futures[2] 处于 finished,futures[1] 处于running
print(futures)
"""
[<Future at 0x271642c2e50 state=finished returned str>, 
 <Future at 0x2717b5f6e50 state=running>, 
 <Future at 0x2717b62f1f0 state=finished returned str>]
"""

获取任务的返回值:

from concurrent.futures import ThreadPoolExecutor
import time

def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

futures = [executor.submit(task, "古明地觉", 16, 5),
           executor.submit(task, "古明地觉", 16, 2),
           executor.submit(task, "古明地觉", 16, 4),
           executor.submit(task, "古明地觉", 16, 3),
           executor.submit(task, "古明地觉", 16, 6),
           ]

# 此时的 futures 里面相当于有了 5 个 future
# 记做 future1,future2,future3,future4,future5
for future in futures:
    print(future.result())
"""
name is 古明地觉, age is 16, sleep 5s
name is 古明地觉, age is 16, sleep 2s
name is 古明地觉, age is 16, sleep 4s
name is 古明地觉, age is 16, sleep 3s
name

当我们使用 for 循环的时候,实际上会依次遍历这 5 个 future,所以返回值的顺序就是我们添加的 future 的顺序。但 future1 对应的任务休眠了 5s,那么必须等到 5s 后,future1 里面才会有值。由于这五个任务是并发执行的,future2、future3、future4 只休眠了 2s、4s、3s,所以肯定会先执行完毕,然后执行 set_result,将返回值设置到对应的 future 里。

但 Python 的 for 循环,不可能在第一次迭代还没有结束,就去执行第二次迭代。因为 futures 里面的几个 future 的顺序已经一开始就被定好了,只有当第一个 future.result() 执行完成之后,才会执行第二个 future.result()、第三个。。。即便后面的任务已经执行完毕,但由于 for 循环的顺序,也只能等着,直到前面的 future.result() 执行完毕。

所以会先打印 "name is 古明地觉, age is 16, sleep 5s",当这句打印完时,由于后面的任务早已执行完毕,只是由于第一个 future.result() 太慢,又把路给堵住了,才导致后面的无法输出。因此第一个future.result()执行完毕之后,后面的 3 个 future.result() 会瞬间执行,从而立刻打印。

而最后一个任务由于是 6s,因此再过 1s 后,打印 "name is 古明地觉, age is 16, sleep 6s"。

查看函数是否执行完毕

我们之前说 future 里面包含了函数的执行状态,所以我们可以通过 future.done() 查看任务是否完成。

from concurrent.futures import ThreadPoolExecutor
import time


def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

# 之前我们说过,可以打印future来查看任务的状态,其实还有一种方法来确定任务是否完成
future = executor.submit(task, "椎名真白", 16, 3)

while True:
    if future.done():
        print(f"任务执行完毕:{future.done()}")
        break

    else:
        print(f"任务尚未执行完毕:{future.done()}")

    time.sleep(1)
"""
任务尚未执行完毕:False
任务尚未执行完毕:False
任务尚未执行完毕:False
任务尚未执行完毕:False
任务执行完毕:True
"""

# 当任务尚未执行完毕的时候,future.done() 是 False,执行完毕之后打印为 True

除此之外,还有一个 future.running() ,表示任务是否正在运行。如果正在运行返回 True,运行结束或者失败,返回 False。

使用 map 来提交多个函数

使用 map 来提交会更简单一些,如果任务的量比较多,并且不关心某个具体任务设置回调的话,可以使用 map。那么如何使用 map 提交任务呢?

# 如果我想将以下这种用submit提交的方式,改用map要怎么做呢?
"""
futures = [executor.submit(task, "椎名真白", 16, 5),
           executor.submit(task, "古明地觉", 16, 2),
           executor.submit(task, "古明地恋", 15, 4),
           executor.submit(task, "坂上智代", 19, 3),
           executor.submit(task, "春日野穹", 16, 6)]
"""
# 可以直接改成
"""
results = executor.map(task,
                       ["椎名真白", "古明地觉", "古明地恋", "坂上智代", "春日野穹"],
                       [16, 16, 15, 19, 16],
                       [5, 2, 4, 3, 6])
"""

map 这样写确实是简化了不少,但是我们也可以看到使用这种方式就无法为某个具体的任务添加回调函数了。并且 map 内部也是使用了 submit。

我们测试一下:

from concurrent.futures import ThreadPoolExecutor
import time

def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

results = executor.map(task,
                       ["椎名真白", "古明地觉", "古明地恋", "坂上智代", "春日野穹"],
                       [16, 16, 15, 19, 16],
                       [5, 2, 4, 3, 6])

# 此时返回的是一个生成器,里面存放的就是函数的返回值
print(results)  
"""
<generator object Executor.map.<locals>.result_iterator at 0x0000000009F4C840>
"""

# 此时的 future,相当于submit当中的future.result()
for result in results:
    print(result)
"""
name is 椎名真白, age is 16, sleep 5s
name is 古明地觉, age is 16, sleep 2s
name is 古明地恋, age is 15, sleep 4s
name is 坂上智代, age is 19, sleep 3s
name is 春日野穹, age is 16, sleep 6s
"""

如果想等所有任务都执行完毕之后再一并处理的话,可以直接调用一个 list。

from concurrent.futures import ThreadPoolExecutor
import time
import pprint

def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

results = executor.map(task,
                       ["椎名真白", "古明地觉", "古明地恋", "坂上智代", "春日野穹"],
                       [16, 16, 15, 19, 16],
                       [5, 2, 4, 3, 6])

# 如果这里我改一下,改成 list(results)
# 分析:
# 由于 results 是一个生成器,当我转化为 list 之后,会将里面所有的值全部生产出来
# 这就意味着,要将所有任务的返回值都获取到才行。
# 尽管我们不需要调用 result(),但 result 这一步是无法避免的,只是 map 内部自动帮我们调用了
# 因此调用 result() 方法是不可避免的,调用的时候依旧会阻塞
# 而耗时最长的任务是 6s,因此这一步会阻塞 6s,6s 过后,会打印所有任务的返回值
start_time = time.perf_counter()
pprint.pprint(list(results))
print(f"总耗时:{time.perf_counter() - start_time}")
"""
['name is 椎名真白, age is 16, sleep 5s',
 'name is 古明地觉, age is 16, sleep 2s',
 'name is 古明地恋, age is 15, sleep 4s',
 'name is 坂上智代, age is 19, sleep 3s',
 'name is 春日野穹, age is 16, sleep 6s']
总耗时:6.00001767
"""

当然调用 list,和我们直接使用 for 循环本质上是一样的。并且我们看到,这个和 asyncio 的 gather 是比较类似的。

按照顺序等待

现在有这么一个需求,就是哪个任务先完成,哪个就先返回,这要怎么做呢?

from concurrent.futures import ThreadPoolExecutor, as_completed
import time

def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

futures = [executor.submit(task, "椎名真白", 16, 5),
           executor.submit(task, "古明地觉", 16, 2),
           executor.submit(task, "古明地恋", 15, 4),
           executor.submit(task, "坂上智代", 19, 3),
           executor.submit(task, "春日野穹", 16, 6)]


for future in as_completed(futures):
    print(future.result())
"""
name is 古明地觉, age is 16, sleep 2s
name is 坂上智代, age is 19, sleep 3s
name is 古明地恋, age is 15, sleep 4s
name is 椎名真白, age is 16, sleep 5s
name is 春日野穹, age is 16, sleep 6s
"""

# 只需要将futures传递给as_completed即可

如何取消一个任务

我们可以将任务添加到线程池当中,但是如果我们想取消怎么办呢?

from concurrent.futures import ThreadPoolExecutor
import time


def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

future1 = executor.submit(task, "椎名真白", 16, 5)
future2 = executor.submit(task, "古明地觉", 16, 2)
future3 = executor.submit(task, "古明地恋", 15, 4)

# 取消任务,可以使用future.cancel
print(future3.cancel())  # False

但是我们发现调用 cancel 方法的时候,返回的是 False,这是为什么?因为任务已经被提交到线程池里面了,任务已经运行了,只有在任务还没有运行时,取消才会成功。可这不矛盾了吗?任务一旦提交就会运行,只有不运行才会取消成功,这怎么办?还记得线程池的一个叫做 max_workers 的参数吗?控制线程池内线程数量的,我们可以将最大的任务数设置为 2,那么当第三个任务进去的时候,就不会执行了,而是处于等待状态。

from concurrent.futures import ThreadPoolExecutor
import time

def task(name, age, n):
    print(f"sleep {n}")
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


# 此时最多只能同时执行两个任务
executor = ThreadPoolExecutor(max_workers=2)

future1 = executor.submit(task, "椎名真白", 16, 5)
future2 = executor.submit(task, "古明地觉", 16, 2)
future3 = executor.submit(task, "古明地恋", 15, 4)

print(future3.cancel())  # True
"""
sleep 5
sleep 2
"""
# 可以看到打印为 True,说明取消成功了
# 而 sleep 4 也没有被打印

而事实上我们在启动线程池的时候,肯定是需要设置容量的,不然处理几千个任务要几千个线程吗。

任务中的异常

如果任务当中产生了一个异常,同样会被保存到 future 当中,可以通过 future.exception 获取。

from concurrent.futures import ThreadPoolExecutor

def task1():
    1 / 0

def task2():
    pass


executor = ThreadPoolExecutor(max_workers=2)

future1 = executor.submit(task1)
future2 = executor.submit(task2)

print(future1.exception())  # division by zero
print(future2.exception())  # None


# 或者
try:
    future1.result()
except Exception as e:
    print(e)  # division by zero

等待所有任务完成

一种方法是遍历所有的 future,调用它们 result 方法。

from concurrent.futures import ThreadPoolExecutor
import time


def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

future1 = executor.submit(task, "椎名真白", 16, 5)
future2 = executor.submit(task, "古明地觉", 16, 2)
future3 = executor.submit(task, "古明地恋", 15, 4)

# 这里是不会阻塞的
print(123)

for future in [future1, future2, future3]:
    print(future.result())

"""
123
name is 椎名真白, age is 16, sleep 5s
name is 古明地觉, age is 16, sleep 2s
name is 古明地恋, age is 15, sleep 4s
"""

或者使用 wait 方法:

from concurrent.futures import ThreadPoolExecutor, wait
import time


def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()

future1 = executor.submit(task, "椎名真白", 16, 5)
future2 = executor.submit(task, "古明地觉", 16, 2)
future3 = executor.submit(task, "古明地恋", 15, 4)

# 这里是不会阻塞的
print(123)
"""
123
"""

# 直到所有的 future 完成,这里的 return_when 有三个可选
# FIRST_COMPLETED,当任意一个任务完成或者取消
# FIRST_EXCEPTION,当任意一个任务出现异常,如果都没出现异常等同于 ALL_COMPLETED
# ALL_COMPLETED,所有任务都完成,默认是这个值
# 会卡在这一步,直到所有的任务都完成
fs = wait([future1, future2, future3], return_when="ALL_COMPLETED")

# 此时返回的 fs 是 DoneAndNotDoneFutures 类型的 namedtuple
# 里面有两个值,一个是 done,一个是 not_done
print(fs.done)
"""
{<Future at 0x1df1400 state=finished returned str>, 
 <Future at 0x2f08e48 state=finished returned str>, 
 <Future at 0x9f7bf60 state=finished returned str>}
"""

print(fs.not_done)
"""
set()
"""

for f in fs.done:
    print(f.result())
    """
    name is 椎名真白, age is 16, sleep 5s
    name is 古明地觉, age is 16, sleep 2s
    name is 古明地恋, age is 15, sleep 4s
    """

使用上下文管理:

from concurrent.futures import ThreadPoolExecutor
import time


def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


with ThreadPoolExecutor() as executor:
    future1 = executor.submit(task, "椎名真白", 16, 5)
    future2 = executor.submit(task, "古明地觉", 16, 2)
    future3 = executor.submit(task, "古明地恋", 15, 4)

print(future1.result())
print(future2.result())
print(future3.result())
# 直到with语句全部执行完毕,才会往下走
print(123)
"""
name is 椎名真白, age is 16, sleep 5s
name is 古明地觉, age is 16, sleep 2s
name is 古明地恋, age is 15, sleep 4s
123
"""

或者调用 executor 的 shutdown:

from concurrent.futures import ThreadPoolExecutor
import time


def task(name, age, n):
    time.sleep(n)
    return f"name is {name}, age is {age}, sleep {n}s"


executor = ThreadPoolExecutor()
future1 = executor.submit(task, "椎名真白", 16, 5)
future2 = executor.submit(task, "古明地觉", 16, 2)
future3 = executor.submit(task, "古明地恋", 15, 4)
executor.shutdown()
print(future1.result())
print(future2.result())
print(future3.result())
print(123)
"""
name is 椎名真白, age is 16, sleep 5s
name is 古明地觉, age is 16, sleep 2s
name is 古明地恋, age is 15, sleep 4s
123
"""

以上就是 concurrent.futures 的基本用法, 这里为了方便介绍,使用时线程池执行器。如果想换成进程池,那么只需要将 ThreadPoolExecutor 换成 ProcessPoolExecutor。

from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor

并且在看这个模块的时候,会发现它里面的很多概念和 asyncio 基本是一致的,所以当发现概念上和 asyncio 有重叠时,不用担心,在 asyncio 里面学到的知识放在 concurrent.futures 里面也是适用的。

进程池执行器与 asyncio

我们已经了解了如何使用进程池同时运行 CPU 密集型操作,这些池适用于简单的用例,但 Python 在 concurrent.futures 模块中提供了进程池的一个抽象。该模块包含进程和线程的执行器,它们可以单独使用,也可与 asyncio 互操作。

因为 Python 的进程池 API 与进程强耦合,而 multiprocessing 是实现抢占式多任务的两种方法之一,另一种方法是多线程。如果我们需要轻松改变处理并发的方式,在进程和线程之间无缝切换怎么办?如果想要这样的设计,我们需要构建一个抽象,它包含将工作分配到资源池的核心内容,而不关心这些资源是进程、线程还是其他构造。

所以 concurrent.futures 模块通过 Executor 抽象类提供了这个抽象,该类定义了两种异步执行工作的方法:第一个是 submit,它将接收一个可调用对象并返回一个 future,类似于 pool.apply_async 方法。第二个则是 map,该方法将采用可调用函数和函数参数列表,然后异步执行列表中的每个参数。map 方法会返回调用结果的迭代器,类似于 gather,因为结果一旦完成就可被使用。

Executor 有两个具体的实现:ProcessPoolExecutor 和 ThreadPoolExecutor,分别用于进程池和线程池。

而我们上面已经介绍了这个模块,虽然介绍的是线程池执行器,但进程池执行器和它的用法是一样的。那么下面就来学习一下,如何将其挂接到 asyncio 中,以便使用其 API 函数的强大功能,例如 gather。

from concurrent.futures import ProcessPoolExecutor
import time

def count(to: int) -> int:
    start = time.perf_counter()
    counter = 0
    while counter < to:
        counter += 1
    end = time.perf_counter()
    print(f"在 {end - start} 秒内将 counter 增加到 {to}")
    return counter

if __name__ == '__main__':
    with ProcessPoolExecutor() as executor:
        numbers = [1, 3, 5, 22, 100000000]
        for result in executor.map(count, numbers):
            print(result)
"""
在 6.670000000097431e-07 秒内将 counter 增加到 1
在 5.419999999922709e-07 秒内将 counter 增加到 3
在 4.999999999588667e-07 秒内将 counter 增加到 5
在 1.0420000000066487e-06 秒内将 counter 增加到 22
1
3
5
22
在 4.5546771669999995 秒内将 counter 增加到 100000000
100000000
"""

当运行代码时,会看到对较小数字的调用将很快完成,并几乎立即输出。然而当参数为 100000000 的调用会花费很长时间,并在几个较小的数字之后输出。

虽然看起来这与 asyncio.as_completed 的工作方式相同,但迭代顺序是根据在数字列表中传递的顺序确定的。这意味着如果 100000000 是传入的第一个数字,我们将等待该调用完成,然后才能输出之前完成的其他结果,因为我们无法像 asyncio.as_completed 那样可以响应迅速。

带有异步事件循环的进程池执行器

现在我们已经了解了进程池执行器如何工作的基础知识,让我们看看如何将它挂接到 asyncio 事件循环中。

首先创建一个与 asyncio 一起使用的进程池执行器与刚刚介绍的没有什么不同,也就是说,可在上下文管理器中进行创建。一旦有了进程池,就可以在异步事件循环上使用一个特殊的方法,称为 run_in_executor。该方法接收一个可调用对象,并将在池(可以是线程池或进程池)中运行该可调用对象。然后它返回一个 awaitable 对象,我们可在 await 语句中使用它或传递给一个 API 函数,例如 gather。

from concurrent.futures import ProcessPoolExecutor
import asyncio
from asyncio.events import AbstractEventLoop

def count(to: int) -> int:
    counter = 0
    while counter < to:
        counter += 1
    return counter

async def main():
    with ProcessPoolExecutor() as pool:
        loop: AbstractEventLoop = asyncio.get_running_loop()
        numbers = [1, 3, 5, 22, 100000000]
        tasks = [loop.run_in_executor(pool, count, n) for n in numbers]
        results = await asyncio.gather(*tasks)
        print(results)

if __name__ == '__main__':
    asyncio.run(main())
"""
[1, 3, 5, 22, 100000000]
"""

首先创建一个进程池执行器,就像我们之前所做的那样。一旦有了进程池执行器,就可以进行 asyncio 事件循环,run_in_executor 是 AbstractEventLoop 上的一个方法。我们将这些耗时的调用扔到线程池,跟踪它返回的可等待对象,然后使用 asyncio.gather 获取这个列表,等待所有操作完成。

如有必要,也可使用 asyncio.as_completed 在子进程完成时立即获得结果,这将解决进程池的 map 方法中的问题。

async def main():
    with ProcessPoolExecutor() as pool:
        loop: AbstractEventLoop = asyncio.get_running_loop()
        numbers = [100000000, 1, 3, 5, 22]
        tasks = [loop.run_in_executor(pool, count, n) for n in numbers]
        for result in asyncio.as_completed(tasks):
            print(await result)

if __name__ == '__main__':
    asyncio.run(main())
"""
1
3
5
22
100000000
"""

需要注意的是:run_in_executor 返回的是一个可等待对象,如果希望它完成之后才能往下执行,那么可以直接放在 await 表达式中。

await loop.run_in_executor(线程池或进程池执行器, 函数, 参数1, 参数2, ...)

而说白了 run_in_executor 返回的就是 asyncio 的 future,另外 concurrent.futures 里面的也有 future,这两者在概念上非常相似,它们的设计理念是一样的。另外 asyncio 也提供了一个函数 wrapped_future,可以将 concurrent.futures 里面 future 转成 asyncio 里面 future。

不知道你有没有感到奇怪,我们扔到进程池执行器里面运行,为啥返回的是 asyncio 里面的 future 呢?其实就是就是通过该函数转换了。

from concurrent.futures import ProcessPoolExecutor, Future
import asyncio
from asyncio.events import AbstractEventLoop

def count(to: int) -> int:
    counter = 0
    while counter < to:
        counter += 1
    return counter

async def main():
    with ProcessPoolExecutor() as pool:
        # 提交任务返回 future(Future 对象)
        future = pool.submit(count, 100000000)
        # 虽然两个 future 在设计上很相似,但 concurrent.futures 里面的 future 不能直接 await
        print(type(future) is Future)
        
        # 转成 asyncio.Future
        future2 = asyncio.wrap_future(future)
        print(type(future2) is asyncio.Future)
        
        print(await future2)
        print(future2.result())
        print(future.result())

if __name__ == '__main__':
    asyncio.run(main())
"""
True
True
100000000
100000000
100000000
"""

比较简单,现在我们已经看到了使用 asyncio 进程池所需要的一切。接下来,让我们看看如何使用 multiprocessing 和 asyncio 提高实际性能。

使用 asyncio 解决 MapReduce 问题

为了理解可以用 MapReduce 解决的问题类型,我们先引入一个假设的问题,然后通过对它的理解,来解决一个类似的问题,这里将使用一个免费的大型数据集。

我们假设网站通过客户在线提交信息收到大量的文本数据,由于站点访问人数较多,这个客户反馈数据集的大小可能是 TB 级的,并且每天都在增长。为了更好地了解用户面临的常见问题,我们的任务是在这个数据集内找到最常用的词。一个简单的解决方案是使用单个进程循环每个评论,并跟踪每个单词出现的次数。这样做可以实现目标,但由于数据很大,因此串行执行该操作可能需要很长时间。有没有更快的方法可以解决此类问题呢?

这正是 MapReduce 可以解决的问题,MapReduce 编程模型首先将大型数据集划分为较小的块来解决问题。然后可以针对较小的数据子集而不是整个集合来解决问题(这被称为映射,mapping),因为我们将数据映射到部分结果。一旦解决了每个子集的问题,就可将结果组合成最终答案,此步骤称为归约(reducing),因为我们将多个答案归约为一个。

计算大型文本数据集中单词的频率是一个典型的 MapReduce 问题,如果我们有足够大的数据集,将其分成更小的块可以带来性能优势,因为每个映射操作都可以并行执行。

像 Hadoop 和 Spark 这样的系统是为真正的大型数据集在计算机集群中执行 MapReduce 操作而存在的,然而许多较小的工作负载可以通过 multiprocessing 在台计算机上完成。在本问中,我们将看到如何实现具有 multiprocessing 功能的 MapReduce 工作流,从而查找自 1500 年以来某些单词在文献中出现的频率。

简单的 MapReduce 示例

为了充分理解 MapReduce 是如何工作的,让我们来看一个具体例子。假设文件的每一行都有文本数据,对于这个例子,假设有四行文本需要处理:

I know what I know
I know that I know
I don't know that much
They don't know much

我们想要计算每个不同的单词在这个数据集中出现的次数,这个示例非常小,可以用一个简单的 for 循环来解决它,但此处使用 MapReduce 模型来处理它。

首先需要将这个数据集分割成更小的块,为简单起见,我们将一行文本定义为一个块。接下来需要定义映射操作,因为我们想要计算单词频率,所以使用空格对文本行进行分隔,这将得到由单词组成的数组。然后可以对其进行循环,跟踪字典文本行中每个不同的单词。

最后需要定义一个 reduce 操作,这将从 map 操作中获取一个或多个结果,并将它们组合成一个答案。

from typing import Dict
from functools import reduce

def map_frequency(text: str) -> Dict[str, int]:
    """
    计算每一行文本中,单词的频率
    """
    words = text.split(" ")
    frequencies = {}
    for word in words:
        if word in frequencies:
            frequencies[word] += 1
        else:
            frequencies[word] = 1

    return frequencies

def merge_dict(first: Dict[str, int],
               second: Dict[str, int]) -> Dict[str, int]:
    """
    对两行文本统计出的词频进行合并
    """
    keys = first.keys() | second.keys()
    return {key: first.get(key, 0) + second.get(key, 0) for key in keys}

lines = ["I know what I know", "I know that I know",
         "I don't know that much", "They don't know much"]

mapped_results = [map_frequency(line) for line in lines]
print(reduce(merge_dict, mapped_results))
"""
{'that': 2, 'know': 6, 'what': 1, 'much': 2, 'They': 1, "don't": 2, 'I': 5}
"""

现在我们已经了解了 MapReduce 的基础知识,并学习了一个示例,下面将看到如何将其应用到实际的数据集,在整个数据集中,使用 multiprocessing 可以带来性能提升。

Google Books Ngram 数据集

我们需要通过足够大的数据集来了解 MapReduce 与 multiprocessing 的优势,如果数据集太小,将看不到 MapReduce 带来的任何收益,并且可能会因为管理流程的开销而导致性能下降。

Google Books Ngram 是一个足够大的数据集,为了理解这个数据集是什么,我们首先定义一下什么是 n-gram。

n-gam 是来自自然语言处理的概念,它可以将文本中的连续 N 个单词或字符作为一个单元来进行处理。比如短语 "the fast dog" 就有 6 个 n-gram,怎么计算的呢?

  • 3 个 1-gram,分别是:the、fast 和 dog
  • 2 个 2-gram,分别是:the fast 和 fast dog
  • 1 个 3-gram,即:the fast dog

Google Books Ngram 数据集是对一组超过 8000000 本书的 n-gram 扫描,可追溯到 1500 年,占所有已出版书籍的 6% 以上。它计算不同 n-gram 在文本中出现的次数,并按出现的年份分组。该数据集以制表符,对 1-gram 到 5-gram 的所有内容进行了分隔,每一行都有一个 n-gram,它出现的年份、出现的次数以及出现在多少本书中。让我们看一下数据集中的前几个条目中关于单词 aardvark 的情况:

aardvark 1822 2 1
aardvark 1824 3 1
aardvark 1827 10 7

这意味着在 1822 年,aardvark 单词在 1 本书中出现了 2 次,然后 1827 年,aardvark 单词在 7 本不同的书中出现了 10 次。

该数据集大小约为 1.8 GB,里面有大量的单词,现在我们要解决一个问题:自 1500 年以来,aardvark 单词在文学作品中出现了多少次?

我们要使用的相关文件可以从这个地址下载。

import time

freqs = {}
with open("googlebooks-eng-all-1gram-20120701-a", encoding="utf-8") as f:
    lines = f.readlines()  # 将所有的行读到内存
    start = time.perf_counter()
    for line in lines:
        data = line.split("\t")  # 将每一行按照 \t 分隔
        word = data[0]  # 获取单词
        count = int(data[2])  # 获取次数
        if word in freqs:
            freqs[word] += count
        else:
            freqs[word] = count
    end = time.perf_counter()
    print(f"耗时: {end - start}")
"""
耗时: 22.4049245
"""

为了测试 CPU 密集型操作花费了多长时间,我们将只计算词频统计花费的时间,而不包括加载文件所需的时间。

第一步:切分数据集

让我们定义一个分区生成器,它可以获取大数据集,并分割为任意大小的块。

def partition(data: List, chunk_size: int) -> List:
    """
    将一个大型列表,以 chunk_size 为单位,分割成若干个小列表(chunk)
    """
    for i in range(0, len(data), chunk_size):
        yield data[i: i + chunk_size]

可使用这个分区生成器来创建大小为 chunk_size 的数据切片,从而传递给 map 函数,实现并行运行。

第二步:定义映射函数

def map_frequencies(chunk: List[str]) -> Dict[str, int]:
    """
    计算一个 chunk 中,单词的频率
    """
    frequencies = {}
    # chunk 的每一行都是如下格式:单词\t年份\t出现次数\t出现在多少本书中
    for line in chunk:
        word, _, count, _ = line.split("\t")
        if word in frequencies:
            frequencies[word] += int(count)
        else:
            frequencies[word] = int(count)

    return frequencies

接下来就可以创建一个进程池,并利用进程池中的资源为每个分区运行 map_frequencies。现在我们几乎拥有所需的一切,但还有一个问题:分区应该设置为多大?

对此没有一个简单的答案,一个经验法则是 Goldilocks 方法,即分区不宜过大或过小。分区大小不应该很小的原因是,当创建分区时,它们会被序列化并发送到 worker 进程,然后 worker 进程将它们解开。序列化和反序列化这些数据的过程可能会占用大量时间,如果我们经常这样做,就会抵消并行所带来的性能提升。因为数据量的固定的,那么当分区大小过小时,分区数量就会过大,这样就会产生大量的序列化和反序列化操作。

这个 HDFS 存储文件是一个道理,HDFS 存储文件的时候,如果块过小,那么 NameNode 的寻址时间甚至可能会超过文件的处理时间。

当然我们也不希望分区太大,否则可能无法充分利用机器的算力。例如有 16 个 CPU 内核,但分区太大导致只创建了两个分区,那么就浪费了可以并行运行工作负载的 14 个内核。

对于我们当前的数据集来说,总共 86618505 行,而我的机器有 24 个核心,所以我就把 chunk_size 定位 4000000,然后我们来编写代码测试一下。

import time
from typing import Dict, List
from functools import reduce
from concurrent.futures import ProcessPoolExecutor
import asyncio

def partition(data: List, chunk_size: int) -> List:
    """
    将一个大型列表,以 chunk_size 为单位,分割成若干个小列表(chunk)
    """
    for i in range(0, len(data), chunk_size):
        yield data[i: i + chunk_size]

def map_frequencies(chunk: List[str]) -> Dict[str, int]:
    """
    计算一个 chunk 中,单词的频率
    """
    frequencies = {}
    # chunk 的每一行都是如下格式:单词\t年份\t出现次数\t出现在多少本书中
    for line in chunk:
        word, _, count, _ = line.split("\t")
        if word in frequencies:
            frequencies[word] += int(count)
        else:
            frequencies[word] = int(count)

    return frequencies

def merge_dict(first: Dict[str, int],
               second: Dict[str, int]) -> Dict[str, int]:

    keys = first.keys() | second.keys()
    return {key: first.get(key, 0) + second.get(key, 0) for key in keys}

async def main(chunk_size):
    with open(r"googlebooks-eng-all-1gram-20120701-a", encoding="utf-8") as f:
        contents = f.readlines()
        loop = asyncio.get_running_loop()
        tasks = []
        start = time.perf_counter()
        with ProcessPoolExecutor() as pool:
            for chunk in partition(contents, chunk_size):
                tasks.append(
                    loop.run_in_executor(pool, map_frequencies, chunk)
                )
            middle_results = await asyncio.gather(*tasks)
            final_results = reduce(merge_dict, middle_results)
            print(f"Aardvark 总共出现了 {final_results['Aardvark']} 次")
            end = time.perf_counter()
            print(f"MapReduce 总耗时: {end - start}")

if __name__ == '__main__':
    asyncio.run(main(4000000))

在主协程中,我们创建一个进程池,并对数据进行分区。对于每个分区,我们在单独的进程中启动 map_frequencies 函数,然后使用 asyncio.gather 等待所有中间字典完成。一旦所有 map 操作完成,将运行 reduce 操作来生成最终结果。

共享数据和锁

在前面的章节中,我们讨论了这样一种情况:在多进程中,每个进程都有自己的内存,与其他进程分开。当共享要跟踪的状态时,我们将遇到挑战。如果它们的内存空间都是不同的,我们如何在进程之间共享数据呢?

multiprocessing 支持称为共享内存对象的概念,共享内存对象是分配给一组独立进程可以访问的一块内存。如下图所示,每个进程可以根据需要读取和写入该内存空间。

共享状态很复杂,如果实施不当,可能导致难以重现的错误。通常,如果可能最好避免共享状态,也就是说,只有在必要时才引入共享状态,如共享计数器。

共享数据和竞争条件

multiprocessing 支持两种共享数据方法:值和数组。值是奇异值,例如整数或浮点数。数组是奇异值的数组(array.array),我们可以在内存中共享的数据类型取决于 Python array 模块中定义的类型,可以通过 https:/docs.python.org/3/library/array.html#module-array 查看具体信息。

要创建一个值或数组,我们首先需要使用 array 模块中的类型代码,它只是一个字符。

  • 'b':有符号 8 位整数;
  • 'B':无符号 8 位整数;
  • 'h':有符号 16 位整数;
  • 'H':无符号 16 位整数;
  • 'i':有符号 32 位整数;
  • 'I':无符号 32 位整数;
  • 'l':有符号 32/64 位整数,具体是 32 位还是 64 位取决于系统;
  • 'L':无符号 32/64 位整数,具体是 32 位还是 64 位取决于系统;
  • 'q':有符号 64 位整数;
  • 'Q':无符号 64 位整数;
  • 'f':单精度浮点数;
  • 'd':双精度浮点数;
  • 'u':Unicode 字符

像 struct 等模块,凡是需要使用字符来表示类型的,标准都是上面那个。

让我们创建两个共享的数据:一个整数值和一个整数数组,然后再创建两个进程来并行增加这些共享数据。

from multiprocessing import Process, Value, Array

def increment_value(shared_int: Value):
    shared_int.value += 1

def increment_array(shared_array: Array):
    for index, integer in enumerate(shared_array):
        shared_array[index] = integer + 1

if __name__ == '__main__':
    integer = Value("i", 0)
    integer_array = Array("i", [0, 0])
    procs = [Process(target=increment_value, args=(integer,)),
             Process(target=increment_array, args=(integer_array,))]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(integer.value)  # 1
    print(integer_array[:])  # [1, 1]

我们创建了两个进程,一个用于递增共享整数值,另一个用于递增共享数组中的每个元素,一且两个子进程完成,就可以输出数据。

由于两条数据从未被不同的进程接触过,因此这段代码运行良好。但如果有多个进程修改相同的共享数据,这段代码还会正常运行吗?让我们通过创建两个进程,并行增加一个共享整数值来测试这一点。我们将在循环中重复运行这段代码,看看我们是否可以得到一致的结果。由于有两个进程,每个进程都将共享计数器加,因此一旦进程完成,我们希望共享值始终为 2。

from multiprocessing import Process, Value

def increment_value(shared_int: Value):
    shared_int.value += 1

if __name__ == '__main__':
    for _ in range(100):
        integer = Value("i", 0)
        procs = [Process(target=increment_value, args=(integer,)),
                 Process(target=increment_value, args=(integer,))]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        print(integer.value)
        assert integer.value == 2
"""
2
2
2
2
2
2
2
Traceback (most recent call last):
  File "....py", line 17, in <module>
    assert integer.value == 2
AssertionError
1
"""

结果并不总是为 2,有时结果是 1,为什么是这样?这种情况就被称为竞态条件。当一组操作的结果取决于哪个操作先完成时,就会出现竞态条件。可将这些操作想象成相互竞争的一组操作,如果操作以正确的顺序完成,那么一切正常。如果他们以错误的顺序执行完成,就会导致不可预测的结果。

那么在我们的示例中,竞争发生在哪里?很简答,问题在于增加一个值涉及读取和写入操作。要增加一个值,我们首先需要读取该值,将其加 1,然后将结果写回内存。每个进程在共享数据中看到的值完全取决于它读取共享值的时间,如果进程按以下顺序运行,则一切正常,如下图所示:

进程 1 在进程 2 读取该值并在发生竞争之前递增该值,由于进程 2 排在第二位,这意味着将看到正确的值 1,并将对它进行加 1,从而生成正确的最终值。

这种情况下,进程1和 2 都读取初始值零,然后将该值增加到 1,同时将其写回从而产生不正确的结果。

因为递增被写为两个操作(就算简单的赋值操作,也不是原子性的),就导致了这个问题,这使得它不是原子性的,或不是线程安全的。

这些类型的错误很棘手,因为它们通常难以重现。它们不像普通的错误,而取决于操作系统运行程序的顺序,当使用 multiprocessing 时,这是我们无法控制的。那么该如何修复这个错误呢?

使用锁进行同步

可通过同步访问我们想要修改的任何共享数据来避免竞态条件,同步访问是什么意思呢?重新审视前面的示例,这意味着将控制对任何共享数据的访问,以便所做的任何操作都以有意义的顺序完成执行。如果处于两个操作之间可能出现"平局"的情况,我们会明确阻止第二个操作运行,直到第一个操作完成,从而保证操作以一致的方完成执行。

可将其想象为终点线的裁判,看到平局即将发生,并告诉跑步者:等一下,一次通过一个选手!并选择一名跑者让其等待,而另一名选手越过终点线。

同步访问共享数据的一种机制是锁,也称为互斥锁。这些结构允许单个进程"锁定"一段代码,防止其他进程运行该代码,代码的锁定部分通常称为临界区。这意味着如果一个进程正在执行锁定部分的代码,而第二个进程试图访问该代码,则第二个进程需要等待,直到第一个进程完成锁定部分。

锁支持两种主要操作:获取和释放,当一个进程获得锁时,可以保证它是运行该代码段的唯一进程。一旦需要同步访问的代码部分完成执行,我们就释放锁。这允许其他进程获取锁,并运行临界区中的任何代码。如果一个进程试图运行被另一个进程锁定的代码,将发生阻塞,直到另一个进程释放该锁。

在上图中,进程 1 首先成功获取锁,并读取及递增共享数据。第二个进程也尝试获取锁,但在第一个进程释放锁之前,被阻止继续运行。一旦第一个进程释放锁,第二个进程就可以成功获取锁,并对共享数据进行增加。这可以防止竞争条件,因为锁可以防止多个进程同时读取和写入共享数据。

那么我们如何实现与共享数据的同步呢?multiprocessing API 的实现者考虑到这一点,并且很好地包含一个方法来获取值和数组的锁定。要获取锁,我们调用 acquire(),要释放锁,我们调用 release()。

from multiprocessing import Process, Value

def increment_value(shared_int: Value):
    shared_int.get_lock().acquire()  # 加锁
    shared_int.value += 1
    shared_int.get_lock().release()  # 解锁

if __name__ == '__main__':
    for _ in range(100):
        integer = Value("i", 0)
        procs = [Process(target=increment_value, args=(integer,)),
                 Process(target=increment_value, args=(integer,))]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        print(integer.value)
        assert integer.value == 2

运行上面代码时,得到的每个值都会是 2,因为我们已经修复了竞态条件。但请注意,锁也是上下文管理器,为了清理代码,可使用 with 块编写 increment_value,这将自动获取和释放锁:

def increment_value(shared_int: Value):
    with shared_int.get_lock():
        shared_int.value += 1

注意:我们采用了并发代码,虽然强制它是顺序执行,但却否定了并行运行的价值。这是一个重要的观察结果,通常是对并发性中的同步和共享数据的警告。为避免竞态条件,必须使并行代码在临界区中是连续的,而这会损害 multiprocessing 代码的性能。为此我们应该只锁定绝对必须锁定的部分,以便应用程序的其他部分可以并发执行。所以当遇到竞态条件错误时,很容易使用锁来保护代码,这可以修复问题,但可能降低应用程序的性能(为此我们应该只锁定关键部分,不要把代码全锁上了)。

与进程池共享数据

我们刚刚看到了如何在几个进程中共享数据,那么如何将这些知识应用到进程池中呢?进程池的运行方式与手动创建进程略有不同,这为共享数据带来了挑战。为什么会这样?

将任务提交到进程池时,它可能不会立即运行,因为池中的进程可能正忙于其他任务。那进程池如何处理这个问题呢?在后台,进程池执行器通过一个任务队列来管理它,向进程池提交任务时,它的参数会被序列化,并放入任务队列。然后,每个工作进程在准备好工作时,从队列中请求一个任务。当工作进程将任务从队列中拉出时,它会对参数反序列化,并开始执行任务。

根据定义,共享数据是在工作进程之间共享的,因此对它进行序列化和反序列化,从而在进程之间来回发送几乎没有意义。事实上,无论是 Value 还是 Array 对象都不能被序列化,所以如果像以前一样,尝试将共享数据作为参数传递给函数,会得到一个类似于 can't pickle Value objects 的错误。

为了处理这个问题,需要将共享计数器放在一个全局变量中,并以某种方式让工作进程知道它的存在。

from multiprocessing import Value
from concurrent.futures import ProcessPoolExecutor
import asyncio

shared_counter: Value

def init(counter: Value):
    global shared_counter
    shared_counter = counter

def increment():
    with shared_counter.get_lock():
        shared_counter.value += 1

async def main():
    counter = Value("i", 0)
    # 告诉进程池,创建进程的时候使用参数 counter 执行 init 函数
    with ProcessPoolExecutor(initializer=init, initargs=(counter,)) as pool:
        await asyncio.get_running_loop().run_in_executor(pool, increment)
        print(counter.value)

if __name__ == '__main__':
    asyncio.run(main())
"""
1
"""

首先定义一个全局变量 shared_counter,它将包含对我们创建的共享值对象的引用。在 init 函数中,接收一个 Value 并将 shared_counter 初始化为该值。然后在主进程中,创建计数器并将其初始化为 0,然后在创建进程池时将 init 函数和计数器传递给初始化程序以及 initargs 参数。这将为进程池创建的每个进程调用 init 函数,并把 shared_counter 正确初始化为在主协程中创建的那个对象。

你可能会问,我们为什么要这么麻烦?我们直接将全局变量初始化为 shared_counter: Value = Value('i', 0) 不就好啦。我们不能这样做的原因是,当创建每个进程时,创建它的脚本将在每个进程中再次运行。这意味着每个启动的进程都将执行 shared_counter: Value = Value('i', 0),如果有 100 个进程,将得到 100 个 shared_counter 值,每个值都设置为 0,导致不正确的结果。

小结

在本篇文章中,我们学习了以下内容:

  • 在进程池中并行运行多个 Python 函数;
  • 创建进程池执行器,以及并行运行 Python 函数。进程池执行器允许使用 gather 等 asyncio API 方法并发地运行多个进程,并等待结果;
  • 使用进程池和 asyncio 来解决 MapReduce 的问题。这个工作流不仅适用于 MapReduce,也可以用于任何 CPU 密集型工作,可将这些工作分割成多个较小的块;
  • 在多个进程之间共享状态,从而能够跟踪与我们启动的子进程相关的数据,例如状态计数器;
  • 使用锁来避免竞态条件,当多个进程试图同时访问数据时,就会出现竞态条件,这可能导致难以重现的 bug;