Go sync.WaitGroup

Published: 2023-04-19 16:34:44  Author: 每天提醒自己要学习

The main purpose of WaitGroup is to let one or more goroutines wait for another group of goroutines to finish.
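A minimal usage sketch of this idea; sumSquares is a hypothetical helper written only for illustration. The calling goroutine calls Add before starting each worker, every worker calls Done when it exits, and Wait blocks until the counter is back to 0.

```go
package main

import (
	"fmt"
	"sync"
)

// sumSquares starts one goroutine per input and uses a WaitGroup so that
// the caller blocks until every worker has finished.
func sumSquares(nums []int) int {
	var wg sync.WaitGroup
	results := make([]int, len(nums)) // each goroutine writes only its own slot, so no lock is needed

	for i, n := range nums {
		wg.Add(1) // one more goroutine to wait for
		go func(i, n int) {
			defer wg.Done() // decrement the counter when this goroutine exits
			results[i] = n * n
		}(i, n)
	}

	wg.Wait() // block until the counter drops back to 0

	total := 0
	for _, r := range results {
		total += r
	}
	return total
}

func main() {
	fmt.Println(sumSquares([]int{1, 2, 3, 4})) // 30
}
```

Calling Add in the parent goroutine, before the `go` statement, matters: if Add ran inside the worker, Wait could observe a counter of 0 and return before the worker had even registered.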

Data structure

The WaitGroup data structure has changed over time; I did not look up exactly which release made the change.

Structure in Go 1.13

type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}

noCopy exists to prevent copying: the go vet tool can check for it, and if it finds that a WaitGroup has been copied, it reports an error.
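A sketch of the same pattern in user code. The noCopy type below copies the shape of the one inside package sync (a zero-size type whose pointer has Lock and Unlock methods, which is what vet's copylocks check looks for); counterBox and fill are hypothetical names used only for illustration.

```go
package main

import (
	"fmt"
	"sync"
)

// noCopy mirrors the pattern used inside package sync: a zero-size type whose
// pointer has Lock and Unlock methods. go vet's copylocks check treats any
// struct containing such a field as something that must not be copied.
type noCopy struct{}

func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

// counterBox is a hypothetical type that opts into the same vet check.
type counterBox struct {
	_ noCopy
	n int
}

// fill hands the WaitGroup to each goroutine by pointer. A sync.WaitGroup
// parameter taken by value would be reported by go vet, and Done would then
// decrement the copy, so the Wait below would block forever.
func fill(out []int) {
	var wg sync.WaitGroup
	for i := range out {
		wg.Add(1)
		go func(i int, wg *sync.WaitGroup) {
			defer wg.Done()
			out[i] = i * 10
		}(i, &wg)
	}
	wg.Wait()
}

func main() {
	out := make([]int, 3)
	fill(out)
	box := counterBox{n: len(out)} // constructing a value is not a copy, so vet is fine with it
	fmt.Println(out, box.n)        // [0 10 20] 3
}
```

The check is static only: copying still compiles and runs, which is why running go vet matters for code that passes WaitGroups around.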

state1 is 12 bytes of data: an 8-byte state word (statep), whose high 32 bits hold counter and low 32 bits hold waiter, plus a 4-byte sema.

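counter records the number of goroutines still being waited for (Add increases it, Done decreases it); waiter records the number of goroutines blocked in Wait; sema is the semaphore used to block and wake those goroutines.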
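The packing of counter and waiter into one 64-bit word can be sketched with two small helpers. unpack and addDelta are hypothetical names; the shifts and conversions follow the expressions used in the Add and Wait source shown later (`uint64(delta) << 32`, `int32(state >> 32)`, `uint32(state)`).

```go
package main

import "fmt"

// unpack splits the 64-bit state the same way WaitGroup does:
// the high 32 bits are the counter, the low 32 bits are the waiter count.
func unpack(state uint64) (counter int32, waiters uint32) {
	return int32(state >> 32), uint32(state)
}

// addDelta mirrors how Add updates the state: delta is shifted into the high
// 32 bits. A negative delta (Done is Add(-1)) works through two's-complement
// wraparound, so it subtracts from the counter without touching the waiters.
func addDelta(state uint64, delta int) uint64 {
	return state + uint64(delta)<<32
}

func main() {
	var s uint64
	s = addDelta(s, 3)  // Add(3): counter = 3
	s++                 // one Wait call registers a waiter (the CAS from state to state+1)
	s = addDelta(s, -1) // Done(): counter = 2
	c, w := unpack(s)
	fmt.Println(c, w) // 2 1
}
```

Keeping both numbers in one word is what lets Add and Wait read and update them together with a single atomic operation.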

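A 64-bit atomic operation requires its operand to be 64-bit (8-byte) aligned; an unaligned operand is not safe. A 32-bit compiler, however, does not guarantee 64-bit alignment, so both states are stored in a single 12-byte state1 field, and the storage order is chosen according to whether state1 happens to be 8-byte aligned.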

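Order when 8-byte aligned: state (counter, waiter) in state1[0..1], then sema in state1[2]
Order when not 8-byte aligned: sema in state1[0], then state (counter, waiter) in state1[1..2]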

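  • On a 64-bit machine state1 is normally 8-byte aligned, so the first order is used
  • On a 32-bit machine
    • if the memory happens to be 8-byte aligned when it is allocated, the first order is used
    • if it is only 4-byte aligned, the second order is used, so that the state word starting at state1[1] is still 8-byte aligned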

When reading the fields, it is enough to check whether the address of state1 is 8-byte aligned:

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
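The address check in state() can be mirrored as a pure function so both cases are easy to try. slots is a hypothetical helper: it takes an assumed base address rather than a real pointer, and assumes that address is at least 4-byte aligned, as a [3]uint32 always is.

```go
package main

import "fmt"

// slots reproduces the decision made by the Go 1.13 state() method for a
// given base address of state1. It returns the index of the first uint32 of
// the 64-bit state word and the index of sema.
func slots(base uintptr) (stateIdx, semaIdx int) {
	if base%8 == 0 {
		// 8-byte aligned: state1[0..1] is the state, state1[2] is sema.
		return 0, 2
	}
	// Only 4-byte aligned: state1[1] starts at base+4, which is 8-byte
	// aligned, so the state lives in state1[1..2] and sema moves to state1[0].
	return 1, 0
}

func main() {
	for _, base := range []uintptr{0x1000, 0x1004} {
		si, mi := slots(base)
		stateAddr := base + uintptr(si)*4 // each uint32 occupies 4 bytes
		fmt.Printf("base=%#x state@%#x aligned=%v sema=state1[%d]\n",
			base, stateAddr, stateAddr%8 == 0, mi)
	}
}
```

Either way the 8-byte state word ends up on an 8-byte boundary, and the 4-byte sema simply takes whichever slot is left over.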

Structure in Go 1.20

type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}

Go 1.20 splits the old state1 into two fields, state and sema, and relies on the atomic.Uint64 type to guarantee that state is 8-byte aligned:

type Uint64 struct {
	_ noCopy
	_ align64
	v uint64
}

atomic.Uint64 contains a blank field of type align64:

// align64 may be added to structs that must be 64-bit aligned.
// This struct is recognized by a special case in the compiler
// and will not work if copied to any other package.
type align64 struct{}

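When the compiler sees this field it handles it specially and aligns the containing value to 8 bytes, so the field is guaranteed to be 8-byte aligned even on 32-bit platforms.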

Add

func (wg *WaitGroup) Add(delta int) {
	if race.Enabled {
		if delta < 0 {
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	state := wg.state.Add(uint64(delta) << 32)
	v := int32(state >> 32)
	w := uint32(state)
	if race.Enabled && delta > 0 && v == int32(delta) {
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(&wg.sema))
	}
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	wg.state.Store(0)
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}

Add flow (ignoring the race-detector code):

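  1. Atomically add delta<<32 to state (that is, add delta to the counter in the high 32 bits), then take the high 32 bits as counter v and the low 32 bits as waiter w
  2. If counter is less than 0, panic
  3. If waiter is not 0, delta is greater than 0, and counter equals delta, the counter has just gone up from 0 while goroutines are already blocked in Wait, which means Add was called concurrently with Wait; panic
  4. If counter is greater than 0, or waiter is 0, there is nobody to wake yet, so return. Otherwise counter is 0 and waiter is greater than 0, and the remaining steps wake the waiters
  5. Load state again; if it differs from the local state, something modified it concurrently after the counter reached 0 (Add was called again while Wait was in progress); panic
  6. Set state to 0 (which also resets the waiter count), then call runtime_Semrelease once per waiter to wake every blocked goroutine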
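Step 2 can be observed directly: calling Done more often than Add drives the counter below 0, and Add panics. doneTooOften is a hypothetical helper that recovers that panic and returns its message.

```go
package main

import (
	"fmt"
	"sync"
)

// doneTooOften calls Done more times than Add and recovers the panic that
// Add raises when the counter would go negative, returning its message.
func doneTooOften() (msg string) {
	defer func() {
		if r := recover(); r != nil {
			msg = fmt.Sprint(r)
		}
	}()
	var wg sync.WaitGroup
	wg.Add(1)
	wg.Done()
	wg.Done() // counter becomes -1, so Add(-1) panics
	return ""
}

func main() {
	fmt.Println(doneTooOften())
}
```

With the source shown above, the recovered message is the string passed to panic in the `v < 0` branch.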

Done

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

It is just a thin wrapper around Add (Add(-1)).

Wait

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
	if race.Enabled {
		race.Disable()
	}
	for {
		state := wg.state.Load()
		v := int32(state >> 32)
		w := uint32(state)
		if v == 0 {
			// Counter is 0, no need to wait.
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// Increment waiters count.
		if wg.state.CompareAndSwap(state, state+1) {
			if race.Enabled && w == 0 {
				// Wait must be synchronized with the first Add.
				// Need to model this is as a write to race with the read in Add.
				// As a consequence, can do the write only for the first waiter,
				// otherwise concurrent Waits will race with each other.
				race.Write(unsafe.Pointer(&wg.sema))
			}
			runtime_Semacquire(&wg.sema)
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}

Wait flow (the body runs inside a for loop):

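  1. Load state; take the high 32 bits as counter and the low 32 bits as waiter
  2. If counter is 0, there is nothing to wait for, so return
  3. Use CAS to add one waiter to state (state+1). If the CAS fails because state changed in the meantime, go back to step 1; if it succeeds, call runtime_Semacquire to block the current goroutine
  4. Once woken, check state: if it is not 0, the WaitGroup was reused before the previous Wait returned, so panic; otherwise return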
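To see how Add and Wait cooperate through the packed state, here is a simplified re-implementation sketch. miniWaitGroup, newMiniWaitGroup and run are hypothetical names; runtime_Semacquire and runtime_Semrelease are not accessible to user code, so an unbuffered channel stands in for sema, and the race-detector hooks and most misuse checks are left out. It assumes Go 1.19+ for atomic.Uint64 and atomic.Int64.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// miniWaitGroup is a simplified, illustration-only copy of the Go 1.20 design.
type miniWaitGroup struct {
	state atomic.Uint64 // high 32 bits: counter, low 32 bits: waiters
	sema  chan struct{} // stand-in for the runtime semaphore: one token per waiter
}

func newMiniWaitGroup() *miniWaitGroup {
	return &miniWaitGroup{sema: make(chan struct{})}
}

func (wg *miniWaitGroup) Add(delta int) {
	state := wg.state.Add(uint64(delta) << 32)
	v, w := int32(state>>32), uint32(state)
	if v < 0 {
		panic("miniWaitGroup: negative counter")
	}
	if v > 0 || w == 0 {
		return // nobody to wake yet
	}
	// Counter reached 0 with w blocked waiters: reset the state, then wake each one.
	wg.state.Store(0)
	for ; w != 0; w-- {
		wg.sema <- struct{}{} // plays the role of runtime_Semrelease
	}
}

func (wg *miniWaitGroup) Done() { wg.Add(-1) }

func (wg *miniWaitGroup) Wait() {
	for {
		state := wg.state.Load()
		if int32(state>>32) == 0 {
			return // counter is 0: nothing to wait for
		}
		// Register as a waiter; if the state changed underneath us, retry.
		if wg.state.CompareAndSwap(state, state+1) {
			<-wg.sema // plays the role of runtime_Semacquire
			return
		}
	}
}

// run starts n workers plus a second waiting goroutine, waits itself,
// and returns how many workers finished.
func run(n int) int64 {
	wg := newMiniWaitGroup()
	var finished atomic.Int64
	wg.Add(n)
	waited := make(chan struct{})
	go func() { // a second goroutine waiting on the same group
		wg.Wait()
		close(waited)
	}()
	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()
			finished.Add(1)
		}()
	}
	wg.Wait()
	<-waited
	return finished.Load()
}

func main() {
	fmt.Println(run(5)) // 5
}
```

Add sends exactly w tokens, and exactly w goroutines registered through a successful CAS, so every blocked waiter receives one token. A Wait that loaded the state before the counter hit 0 but tries its CAS afterwards fails the CAS, reloads, sees counter 0, and returns without touching the channel.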

Summary

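  1. WaitGroup can be used to keep track of how many goroutines have been started and are still running, and to make one group of goroutines wait for another group to finish
  2. The design of WaitGroup takes memory alignment into account; any lock-free code that performs atomic accesses on 64-bit values needs to consider the same issue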