Go sync WaitGroup 源码剖析

11-29 6,343 views

        今天来说说sync库WaitGroup的具体实现。WaitGroup用于等待goroutine集合执行完成。main goroutine调用Add方法来设置要等待的goroutine的数量。然后每个goroutine运行,并在完成后调用Done,goroutine等待计数器减1。同时,Wait可用来阻塞main goroutine,直到所有的goroutine完成(Wait计数器为0)。其实,也可以使用CSP模式实现阻塞,但从性能来说,sync.WaitGroup性能更好。

        其实WaitGroup的实现还是很简单的,接下来咱们对WaitGroup源码进行剖析。

type WaitGroup struct {
        // 使用字节数组作为64位整数,用高低位表达两个计数器。
	// 剩下的4字节,用于和sema32位补位对齐(否则就按4字节对齐了)

        noCopy noCopy    // 禁止拷贝
        state1 [12]byte    // 状态位; 高32位记录Add/Done计数器; 低32位记录Wait计数器;
        sema   uint32      // 信号量
}

         64位原子操作需要64位内存对齐,但32位编译器不能确保它。所以要按照64位进行对齐,WaitGroup将一个64位分成高32位和低32位保存状态(8byte),然后sema信号量32位(4byte)。按照目前分配内存方式共96位,但64位内存对齐需要128位。则需多分配出4byte用于内存对齐使用。

// 返回state1状态信息
// 返回64或32位目标地址
func (wg *WaitGroup) state() *uint64 {
        if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {

        	// 32位
                return (*uint64)(unsafe.Pointer(&wg.state1))
        } else {

        	// 64位
                return (*uint64)(unsafe.Pointer(&wg.state1[4]))
        }
}

// Add 操作会增加(或减少,比如Done)高位的计数器.
func (wg *WaitGroup) Add(delta int) {

	// 累加高位计数器
        statep := wg.state()
        state := atomic.AddUint64(statep, uint64(delta)<<32)


        // v 高位计数器, w低位计数器(等待计数)
        v := int32(state >> 32)
        w := uint32(state)
        
        // 对应 Done: 递减后, 如果计数器依然大于0,或者没有等待者,则直接返回
        if v > 0 || w == 0 {
                return
        }
        
        // 计数等于0,且有等待者(v ==0 &&  w < 0)
        // 重置计数器后(两个),依次唤醒所有等待者
        *statep = 0
        for ; w != 0; w-- {
                runtime_Semrelease(&wg.sema)
        }
}

func (wg *WaitGroup) Done() {
        wg.Add(-1)
}

        WaitGroup主要是对二进制位移的操作,其它没有什么难点。上图就是具体的位移实现,在这里就不多阐述了。

// 等待操作,使用了Free-Lock模式,重试CAS模式,直到累加等待计数器成功后阻塞休眠
func (wg *WaitGroup) Wait() {
        statep := wg.state()

        // 基于CAS实现Free-Lock
        for {
        		// v 高位计数器, w等待计数器
                state := atomic.LoadUint64(statep)
                v := int32(state >> 32)
                w := uint32(state)

                // 计数为0,无需等待
                if v == 0 {
                        // Counter is 0, no need to wait.
                        return
                }

                // 增加等待计数(低位)
                if atomic.CompareAndSwapUint64(statep, state, state+1) {
                		// 休眠,等待唤醒信号
                        runtime_Semacquire(&wg.sema)
                        return
                }
        }
}

         整个的执行流程大概是这样:

        main goroutine 调用Add函数,v自增1,如果Wait函数获取到信号,并且判断v不等于0,则使用CAS原子操作对Wait计数器加1,然后进入休眠;如果Done获取到信号,则v-1。计算后,如果v大于0或者w等于0,则直接退出。由其它获取信号继续执行。如果v等于0而且Wait大于0,则循环Wait计数器,依次唤醒Wait休眠的信号。直到Wait函数内的等待计数器为空等于0或v计数器为空,退出WaitGroup,执行main goroutine。

    • 嗯嗯,写反了func (wg *WaitGroup) state() *uint64 { if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 { // 64位目标地址 return (*uint64)(unsafe.Pointer(&wg.state1)) } else { // 32位目标地址 return (*uint64)(unsafe.Pointer(&wg.state1 )) }}感觉纠正~~~~