Python: Concurrency: Multithreading

 24th March 2021 at 2:16pm

Python 的多进程实现主要在 threading 模块中。

模块设计

threading 模块的实现跟 Java 类似:

The design of this module is loosely based on Java’s threading model. However, where Java makes locks and condition variables basic behavior of every object, they are separate objects in Python. Python’s Thread class supports a subset of the behavior of Java’s Thread class; currently, there are no priorities, no thread groups, and threads cannot be destroyed, stopped, suspended, resumed, or interrupted. The static methods of Java’s Thread class, when implemented, are mapped to module-level functions.

GIL 影响

因为 GIL 的存在,CPython 中同时只能有一个线程在执行 Python 字节码,而一般 Python 只会在执行 I/O 释放 GIL。这使得 Python 的多线程不能有效利用多个 CPU 内核,适合做 I/O 密集的工作而不是 CPU 密集的。

但 GIL 并不保证某个线程函数执行到一半不会被中断,因此对于共享变量,仍然要使用锁等机制来保护。

Linux 平台实现

threading 在不同平台有不同的底层实现。在 Linux 下使用的是 pthread 模型。threading 模块中配合使用的各类 object,对应的底层实现如下:

线程本地变量

线程本地变量的概念是在 pthread 模型中提出的,可以在 TLPI: Ch31 理解。我认为对 Python 这种不需要手动内存管理、对象都在堆上生成的语言,这种机制没有太大价值

考虑这样一个程序:

from threading import Thread

count = 0

def incr(times):
    global count
    for i in range(times):
        count += 1

if __name__ == '__main__':
    t1 = Thread(target=incr, args=(1000000,))
    t2 = Thread(target=incr, args=(1000000,))

    t1.start()
    t2.start()
    t1.join()
    t2.join()

    print(count)     # 不会是 2000000,因为没有给临界区上锁

它的实现是有问题的,没有上锁。如果 incr() 不使用全局变量,那么两个线程各自加 100 万次,没有问题:

from threading import Thread

def incr(times):
    count = 0
    for i in range(times):
        count += 1
    print(count)      # 1000000

if __name__ == '__main__':
    t1 = Thread(target=incr, args=(1000000,))
    t2 = Thread(target=incr, args=(1000000,))

    t1.start()
    t2.start()
    t1.join()
    t2.join()

因为 t1 和 t2 各自跑的 incr() 函数中分别拥有自己的一份 count 变量。但是多线程是共享同个内存地址空间的,因此假如因为某种机制(比如 weakref)t2 可以 refer 到 t1 的 count 变量的值,并对它做修改,那就会出现混乱。

threading.local() 就是用来避免这种情况的。使用 threading.local() 得到的 对象,在不同的线程(包括主线程)中使用时,都只属于该线程本地使用:

from threading import Thread, local

data = local()

def incr(times):
    # 注意要使用 data.count = 0,而不能 data = 0;
    # 因为 Python 的世界里,name 只是个指向 value 的 label,一旦 data = 0,
    # data 就与 thread.local() 失去了关系
    data.count = 0
    for i in range(times):
        data.count += 1
    print(data.count)           # 1000000

if __name__ == '__main__':
    t1 = Thread(target=incr, args=(1000000,))
    t2 = Thread(target=incr, args=(1000000,))

    t1.start()
    t2.start()
    t1.join()
    t2.join()