WaitGroup的用法和原理、常见错误

发布时间 2024-01-07 01:22:53作者: 李若盛开

WaitGroup的介绍

WaitGroup就是package sync用来做任务编排的一个并发原语,这个要解决的就是并发-等待的问题:现有一个goroutine A在检查点(chaeckpoint)等待一组goroutine全部完成,如果在执行任务的这些goroutine还没有全部完成,那么goroutine A就会阻塞在检查点,直到所有的goroutine都完成后才能继续执行。

WaitGroup的用法

创建一个WaitGroup对象后,可以使用Add方法向计数器添加值,然后使用Done方法从计数器减去值。最后,Wait方法将阻塞当前goroutine,直到计数器归零。

WaitGroup的实现原理

WaitGroup的实现原理比较简单。它有一个计数器,初始值为零。当调用Add方法时,它会将计数器加上传入的值。每次调用Done方法时,计数器减去一个值。最后,在调用Wait方法时,程序将阻塞,直到计数器归零。

WaitGroup的源代码:

type WaitGroup struct {
    counter int64
    mutex   sync.Mutex
    waiters sync.Cond
}

func (wg *WaitGroup) Add(delta int) {
    wg.mutex.Lock()
    defer wg.mutex.Unlock()
    wg.counter += int64(delta)
}

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

func (wg *WaitGroup) Wait() {
    wg.mutex.Lock()
    defer wg.mutex.Unlock()
    for wg.counter > 0 {
        wg.waiters.Wait()
    }
}

在这个实现中,WaitGroup使用了一个互斥锁和一个条件变量,以确保在多个goroutine之间同步计数器的值。Add方法和Done方法都使用互斥锁来保护计数器的访问。Wait方法使用条件变量来等待计数器归零。

WaitGroup的实践

Demo1:使用WaitGroup等待多个goroutine执行完毕后再继续执行主函数

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    // 模拟工作时间
    for i := 0; i < 100000000; i++ {

    }
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    for i := 1; i <= 5; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }
    wg.Wait()
    fmt.Println("All workers done")
}
View Code

说明:该程序会启动5个goroutine,每个goroutine都会打印出自己的ID并模拟一段工作时间后结束。主函数会等待所有goroutine都执行完毕后再打印"All workers done"。

Demo2:使用WaitGroup等待多个http请求完成后再继续执行主函数

package main

import (
    "fmt"
    "io/ioutil"
    "net/http"
    "sync"
)

func fetch(url string, wg *sync.WaitGroup) {
    defer wg.Done()
    resp, err := http.Get(url)
    if err != nil {
        fmt.Println(err)
        return
    }
    defer resp.Body.Close()
    body, err := ioutil.ReadAll(resp.Body)
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Printf("Fetched %s, Body size: %d\n", url, len(body))
}

func main() {
    var wg sync.WaitGroup
    urls := []string{
        "https://www.baidu.com",
        "https://www.google.com",
        "https://www.bing.com",
    }
    for _, url := range urls {
        wg.Add(1)
        go fetch(url, &wg)
    }
    wg.Wait()
    fmt.Println("All fetches done")
}
View Code

说明:该程序会启动3个goroutine,每个goroutine都会向一个URL发送http请求并打印出返回的body大小。主函数会等待所有goroutine都执行完毕后再打印"All fetches done"。

Demo3:使用WaitGroup实现协程池

package main

import (
    "fmt"
    "sync"
)

const (
    workerCount = 5
    taskCount   = 20
)

func worker(id int, tasks <-chan int, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d processing task %d\n", id, task)
        // 模拟工作时间
        for i := 0; i < 100000000; i++ {

        }
    }
}

func main() {
    var wg sync.WaitGroup
    tasks := make(chan int, taskCount)
    for i := 1; i <= workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, &wg)
    }
    for i := 1; i <= taskCount; i++ {
        tasks <- i
    }
    close(tasks)
    wg.Wait()
    fmt.Println("All tasks done")
}
View Code

说明:该程序会创建一个由5个goroutine构成的协程池,并向任务通道中发送20个任务。每个goroutine会从任务通道中获取一个任务并处理,直到任务通道关闭。主函数会等待所有协程都执行完毕后再打印"All tasks done"。