11-29 11,398 views
今天来说说sync库WaitGroup的具体实现。WaitGroup用于等待goroutine集合执行完成。main goroutine调用Add方法来设置要等待的goroutine的数量。然后每个goroutine运行,并在完成后调用Done,goroutine等待计数器减1。同时,Wait可用来阻塞main goroutine,直到所有的goroutine完成(Wait计数器为0)。其实,也可以使用CSP模式实现阻塞,但从性能来说,sync.WaitGroup性能更好。
其实WaitGroup的实现还是很简单的,接下来咱们对WaitGroup源码进行剖析。
type WaitGroup struct { // 使用字节数组作为64位整数,用高低位表达两个计数器。 // 剩下的4字节,用于和sema32位补位对齐(否则就按4字节对齐了) noCopy noCopy // 禁止拷贝 state1 [12]byte // 状态位; 高32位记录Add/Done计数器; 低32位记录Wait计数器; sema uint32 // 信号量 }
64位原子操作需要64位内存对齐,但32位编译器不能确保它。所以要按照64位进行对齐,WaitGroup将一个64位分成高32位和低32位保存状态(8byte),然后sema信号量32位(4byte)。按照目前分配内存方式共96位,但64位内存对齐需要128位。则需多分配出4byte用于内存对齐使用。
// 返回state1状态信息 // 返回64或32位目标地址 func (wg *WaitGroup) state() *uint64 { if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 { // 32位 return (*uint64)(unsafe.Pointer(&wg.state1)) } else { // 64位 return (*uint64)(unsafe.Pointer(&wg.state1[4])) } } // Add 操作会增加(或减少,比如Done)高位的计数器. func (wg *WaitGroup) Add(delta int) { // 累加高位计数器 statep := wg.state() state := atomic.AddUint64(statep, uint64(delta)<<32) // v 高位计数器, w低位计数器(等待计数) v := int32(state >> 32) w := uint32(state) // 对应 Done: 递减后, 如果计数器依然大于0,或者没有等待者,则直接返回 if v > 0 || w == 0 { return } // 计数等于0,且有等待者(v ==0 && w < 0) // 重置计数器后(两个),依次唤醒所有等待者 *statep = 0 for ; w != 0; w-- { runtime_Semrelease(&wg.sema) } } func (wg *WaitGroup) Done() { wg.Add(-1) }
WaitGroup主要是对二进制位移的操作,其它没有什么难点。上图就是具体的位移实现,在这里就不多阐述了。
// 等待操作,使用了Free-Lock模式,重试CAS模式,直到累加等待计数器成功后阻塞休眠 func (wg *WaitGroup) Wait() { statep := wg.state() // 基于CAS实现Free-Lock for { // v 高位计数器, w等待计数器 state := atomic.LoadUint64(statep) v := int32(state >> 32) w := uint32(state) // 计数为0,无需等待 if v == 0 { // Counter is 0, no need to wait. return } // 增加等待计数(低位) if atomic.CompareAndSwapUint64(statep, state, state+1) { // 休眠,等待唤醒信号 runtime_Semacquire(&wg.sema) return } } }
整个的执行流程大概是这样:
main goroutine 调用Add函数,v自增1,如果Wait函数获取到信号,并且判断v不等于0,则使用CAS原子操作对Wait计数器加1,然后进入休眠;如果Done获取到信号,则v-1。计算后,如果v大于0或者w等于0,则直接退出。由其它获取信号继续执行。如果v等于0而且Wait大于0,则循环Wait计数器,依次唤醒Wait休眠的信号。直到Wait函数内的等待计数器为空等于0或v计数器为空,退出WaitGroup,执行main goroutine。
state函数里关于32位和64位的注释是不是反了哇?
嗯嗯,写反了func (wg *WaitGroup) state() *uint64 { if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 { // 64位目标地址 return (*uint64)(unsafe.Pointer(&wg.state1)) } else { // 32位目标地址 return (*uint64)(unsafe.Pointer(&wg.state1
)) }}感觉纠正~~~~