深入理解 go sync.Waitgroup

本文基于 Go 1.19。

go 里面的 WaitGroup 是非常常见的一种并发控制方式,它可以让我们的代码等待一组 goroutine 的结束。 比如在主协程中等待几个子协程去做一些耗时的操作,如发起几个 HTTP 请求,然后等待它们的结果。

WaitGroup 示例

下面的代码展示了一个 goroutine 等待另外 2 个 goroutine 结束的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
func TestWaitgroup(t *testing.T) {
var wg sync.WaitGroup
// 计数器 +2
wg.Add(2)

go func() {
sendHttpRequest("https://baidu.com")
// 计数器 -1
wg.Done()
}()

go func() {
sendHttpRequest("https://baidu.com")
// 计数器 -1
wg.Done()
}()

// 阻塞。计数器为 0 的时候,Wait 返回
wg.Wait()
}

// 发起 HTTP GET 请求
func sendHttpRequest(url string) (string, error) {
method := "GET"

client := &http.Client{}
req, err := http.NewRequest(method, url, nil)

if err != nil {
return "", err
}

res, err := client.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()

body, err := io.ReadAll(res.Body)
if err != nil {
return "", err
}

return string(body), err
}

在这个例子中,我们做了如下事情:

  • 定义了一个 WaitGroup 对象 wg,调用 wg.Add(2) 将其计数器 +2
  • 启动两个新的 goroutine,在这两个 goroutine 中,使用 sendHttpRequest 函数发起了一个 HTTP 请求。
  • 在 HTTP 请求返回之后,调用 wg.Done 将计数器 -1
  • 在函数的最后,我们调用了 wg.Wait,这个方法会阻塞,直到 WaitGroup 的计数器的值为 0 才会解除阻塞状态。

WaitGroup 基本原理

WaitGroup 内部通过一个计数器来统计有多少协程被等待。这个计数器的值在我们启动 goroutine 之前先写入(使用 Add 方法), 然后在 goroutine 结束的时候,将这个计数器减 1(使用 Done 方法)。除此之外,在启动这些 goroutine 的协程中, 会调用 Wait 来进行等待,在 Wait 调用的地方会阻塞,直到 WaitGroup 内部的计数器减到 0。 也就实现了等待一组 goroutine 的目的

背景知识

在操作系统中,有多种实现进程/线程间同步的方式,如:test_and_setcompare_and_swap、互斥锁等。 除此之外,还有一种是信号量,它的功能类似于互斥锁,但是它能提供更为高级的方法,以便进程能够同步活动。

信号量

一个信号量(semaphore)S是一个整型变量,它除了初始化外只能通过两个标准的原子操作:wait()signal() 来访问。 操作 wait() 最初称为 P(荷兰语 proberen,测试);操作 signal() 最初称为 V(荷兰语 verhogen,增加),可按如下来定义 wait()

PV 原语。

1
2
3
4
5
wait(S) {
while (S <= 0)
; // 忙等待
S--;
}

可按如下来定义 signal()

1
2
3
signal(S) {
S++;
}

wait()signal() 操作中,信号量整数值的修改应不可分割地执行。也就是说,当一个进程修改信号量值时,没有其他进程能够同时修改同一信号量的值。

简单来说,信号量实现的功能是:

  • 当信号量>0 时,表示资源可用,则 wait 会对信号量执行减 1 操作。
  • 当信号量<=0 时,表示资源暂时不可用,获取信号量时,当前的进程/线程会阻塞,直到信号量为正时被唤醒。

WaitGroup 中的信号量

WaitGroup 中,使用了信号量来实现 goroutine 的阻塞以及唤醒:

  • 在调用 Wait 的地方,goroutine 会陷入阻塞,直到信号量大于等于 0 的时候解除阻塞状态,得以继续执行。
  • 在调用 Done 的时候,如果 WaitGroup 内的等待协程的计数器减到 0 的时候,信号量会进行递增,这样那些阻塞的协程会进行执行下去。

WaitGroup 数据结构

1
2
3
4
5
6
7
type WaitGroup struct {
noCopy noCopy

// 高 32 位为计数器,低 32 位为等待者数量
state atomic.Uint64
sema uint32
}

noCopy

我们发现,WaitGroup 中有一个字段 noCopy,顾名思义,它的目的是防止复制。 这个字段在运行时是没有什么影响的,但是我们通过 go vet 可以发现我们对 WaitGroup 的复制。 为什么不能复制呢?因为一旦复制,WaitGroup 内的计数器就不再准确了,比如下面这个例子:

1
2
3
4
5
6
7
8
9
10
func test(wg sync.WaitGroup) {
wg.Done()
}

func TestWaitGroup(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
test(wg)
wg.Wait()
}

go 里面的函数参数传递是值传递。调用 test(wg) 的时候将 WaitGroup 复制了一份。

在这个例子中,程序会永远阻塞下去,因为 test 中调用 wg.Done() 的时候,只是将 WaitGroup 副本的计数器减去了 1, 而 TestWaitGroup 里面的 WaitGroup 的计数器并没有发生改变,因此 Wait 会永远阻塞。

我们如果需要将 WaitGroup 作为参数,请传递指针:

1
2
3
func test(wg *sync.WaitGroup) {
wg.Done()
}

传递指针之后,我们在 test 中调用 wg.Done() 修改的就是 TestWaitGroup 里面同一个 WaitGroup。 从而,Wait 方法可以正常返回。

state

WaitGroup 里面的 state 是一个 64 位的 atomic.Uint64 类型,它的高 32 位用来保存 counter(也就是上面说的计数器),低 32 位用来保存 waiter(也就是阻塞在 Wait 上的 goroutine 数量。)

waitgroup_1

sema

WaitGroup 通过 sema 来记录信号量:

  • runtime_Semrelease 表示将信号量递增(对应信号量中的 signal 操作)
  • runtime_Semacquire 表示将信号量递减(对应信号量中的 wait 操作)

简单来说,在调用 runtime_Semacquire 的时候 goroutine 会阻塞,而调用 runtime_Semrelease 会唤醒阻塞在同一个信号量上的 goroutine。

WaitGroup 的三个基本操作

  • Add: 这会将 WaitGroup 里面的 counter 加上一个整数(也就是传递给 Add 的函数参数)。
  • Done: 这会将 WaitGroup 里面的 counter 减去 1。
  • Wait: 这会将 WaitGroup 里面的 waiter 加上 1,并且调用 Wait 的地方会阻塞。(有可能会有多个 goroutine 等待一个 WaitGroup

WaitGroup 的实现

Add 的实现

Add 做了下面两件事:

  1. delta 加到 state 的高 32 位上
  2. 如果 counter0 了,并且 waiter 大于 0,表示所有被等待的 goroutine 都完成了,而还有在等待的 goroutine,这会唤醒那些阻塞在 Wait 上的 goroutine。

源码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
func (wg *WaitGroup) Add(delta int) {
// wg.state 的计数器加上 delta
//(加到 state 的高 32 上)
state := wg.state.Add(uint64(delta) << 32) // 高 32 位加上 delta
v := int32(state >> 32) // 高 32 位(counter)
w := uint32(state) // 低 32 位(waiter)
// 计数器不能为负数(加上 delta 之后不能为负数,最小只能到 0)
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// 正常使用情况下,是先调用 Add 再调用 Wait 的,这种情况下,w 是 0,v > 0
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// v > 0,计数器大于 0
// w == 0,没有在 Wait 的协程
// 说明还没有到唤醒 waiter 的时候
if v > 0 || w == 0 {
return
}

// Add 负数的时候,v 会减去对应的数值,减到最后 v 是 0。
// 计数器是 0,并且有等待的协程,现在要唤醒这些协程。

// 存在等待的协程时,goroutine 已将计数器设置为0。
// 现在不可能同时出现状态突变:
// - Add 不能与 Wait 同时发生,
// - 如果看到计数器==0,则 Wait 不会增加等待的协程。
// 仍然要做一个廉价的健康检查,以检测 WaitGroup 的误用。
if wg.state.Load() != state { // 不能在 Add 的同时调用 Wait
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}

// 将等待的协程数量设置为 0。
wg.state.Store(0)
for ; w != 0; w-- {
// signal,调用 Wait 的地方会解除阻塞
runtime_Semrelease(&wg.sema, false, 0) // goyield
}
}

Done 的实现

WaitGroup 里的 Done 其实只是对 Add 的调用,但是它的效果是,将计数器的值减去 1。 背后的含义是:一个被等待的协程执行完毕了

Wait 的实现

Wait 主要功能是阻塞当前的协程:

  1. Wait 会先判断计数器是否为 0,为 0 说明没有任何需要等待的协程,那么就可以直接返回了。
  2. 如果计数器还不是 0,说明有协程还没执行完,那么调用 Wait 的地方就需要被阻塞起来,等待所有的协程完成。

源码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
func (wg *WaitGroup) Wait() {
for {
// 获取当前计数器
state := wg.state.Load()
// 计数器
v := int32(state >> 32)
// waiter 数量
w := uint32(state)
// v 为 0,不需要等待,直接返回
if v == 0 {
// 计数器是 0,不需要等待
return
}

// 增加 waiter 数量。
// 调用一次 Wait,waiter 数量会加 1。
if wg.state.CompareAndSwap(state, state+1) {
// 这会阻塞,直到 sema (信号量)大于 0
runtime_Semacquire(&wg.sema) // goparkunlock
// state 不等 0
// wait 还没有返回又继续使用了 WaitGroup
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
// 解除阻塞状态了,可以返回了
return
}
// 状态没有修改成功(state 没有成功 +1),开始下一次尝试。
}
}

总结

  • WaitGroup 使用了信号量来实现了并发资源控制,sema 字段表示信号量。
  • 使用 runtime_Semacquire 会使得 goroutine 阻塞直到计数器减少至 0,而使用 runtime_Semrelease 会使得信号量递增,这等于是通知之前阻塞在信号量上的协程,告诉它们可以继续执行了。
  • WaitGroup 作为参数传递的时候,需要传递指针作为参数,否则在被调用函数内对 Add 或者 Done 的调用,在 caller 里面调用的 Wait 会观测不到。
  • WaitGroup 使用一个 64 位的数来保存计数器(高 32 位)和 waiter(低 32 位,正在等待的协程的数量)。
  • WaitGroup 使用 Add 增加计数器,使用 Done 来将计数器减 1,使用 Wait 来等待 goroutine。Wait 会阻塞直到计数器减少到 0