Go简单自定义协程池

发布时间: 2023-12-27 10:02:38 作者: 朝阳1
package main

import (
	"fmt"
	"sync"
)

// Task wraps one unit of work that can be submitted to a Pool.
type Task struct {
	// f is the function a worker calls to run the task.
	f func() error
}

// wg counts the tasks that are still outstanding: main adds one before
// submitting each task, Hello marks one done, and main waits on it before
// returning.
var wg sync.WaitGroup

// Pool distributes submitted tasks across a fixed number of workers.
type Pool struct {
	// JobQueue carries the tasks passed to AddTask on to the dispatcher.
	JobQueue chan Task
	// WorkerQueue holds the job channels of workers that are ready for a task.
	WorkerQueue chan chan Task
	// MaxWorkers is the number of workers started by Run.
	MaxWorkers int
}

// NewPool returns a Pool configured for maxWorkers workers. No goroutine is
// started until Run is called.
//
// A maxWorkers below 1 is raised to 1: a negative value would make the channel
// allocation panic, and a pool without workers would accept tasks but never
// run them.
func NewPool(maxWorkers int) *Pool {
	// Number of tasks that can wait in JobQueue before a send on it blocks.
	const jobQueueSize = 10

	if maxWorkers < 1 {
		maxWorkers = 1
	}

	return &Pool{
		JobQueue: make(chan Task, jobQueueSize),
		// One slot per worker, so a worker never blocks while registering
		// itself as idle.
		WorkerQueue: make(chan chan Task, maxWorkers),
		MaxWorkers:  maxWorkers,
	}
}

// Run creates and starts MaxWorkers workers, numbered from 1, all sharing the
// pool's WorkerQueue, and then launches the dispatcher in its own goroutine.
// It returns without waiting for any of them.
func (p *Pool) Run() {
	for id := 1; id <= p.MaxWorkers; id++ {
		w := NewWorker(id, p.WorkerQueue)
		w.Start()
	}

	go p.dispatch()
}

// dispatch hands every task received from JobQueue to the next idle worker.
// It is meant to run in its own goroutine (Run starts it that way) and returns
// once JobQueue has been closed and drained.
func (p *Pool) dispatch() {
	// Ranging over the channel replaces the single-case select and stops when
	// the channel is closed; a bare receive in an endless loop would instead
	// keep yielding zero-value tasks after a close.
	for job := range p.JobQueue {
		fmt.Println("new job")
		// Blocks until some worker has announced itself as idle.
		worker := <-p.WorkerQueue
		fmt.Println("append job")
		// Blocks until that worker takes the task off its job channel.
		worker <- job
		// Printed after the hand-off, not after the task has finished.
		fmt.Println("after run job")
	}
}

// AddTask submits task to the pool by sending it on JobQueue. The call blocks
// while the queue has no free slot.
func (p *Pool) AddTask(task Task) {
	queue := p.JobQueue
	queue <- task
}

// Worker runs tasks one at a time in its own goroutine once Start is called.
type Worker struct {
	id          int            // identifier shown in this worker's log output
	WorkerQueue chan chan Task // shared queue on which the worker announces that it is idle
	JobChannel  chan Task      // channel on which this worker receives its tasks
	quitChan    chan struct{}  // receives the stop signal sent by Stop
}

// NewWorker returns a Worker with the given id that will announce itself on
// workerQueue. Its job and quit channels are both unbuffered. The worker does
// nothing until Start is called.
func NewWorker(id int, workerQueue chan chan Task) Worker {
	fmt.Println("newWorker")

	var w Worker
	w.id = id
	w.WorkerQueue = workerQueue
	w.JobChannel = make(chan Task)
	w.quitChan = make(chan struct{})
	return w
}

// Start launches the worker's goroutine and returns immediately. The goroutine
// repeatedly announces the worker as idle on WorkerQueue and then waits for
// either a task on JobChannel or a stop signal on quitChan.
//
// A task whose function is nil is skipped and a task that returns an error is
// reported; in both cases the worker keeps running.
//
// NOTE(review): when the worker quits, the job channel it queued just before
// is still in WorkerQueue, so a dispatcher that later picks it up would block
// sending to a worker that is gone. Confirm whether callers stop workers while
// tasks may still arrive.
func (w *Worker) Start() {
	fmt.Println("worker start")
	go func() {
		for {
			// Put our own job channel on the worker queue to signal that we
			// are ready for the next task.
			w.WorkerQueue <- w.JobChannel
			select {
			case task := <-w.JobChannel:
				if task.f == nil {
					// Calling a nil func would panic and take the whole
					// process down with this goroutine.
					fmt.Printf("worker%d skipped task without a function\n", w.id)
					continue
				}
				fmt.Printf("worker%d start job\n", w.id)
				if err := task.f(); err != nil {
					fmt.Printf("worker%d job failed: %v\n", w.id, err)
					continue
				}
				fmt.Printf("worker%d finished job\n", w.id)
			case <-w.quitChan:
				fmt.Printf("worker%d quit\n", w.id)
				return
			}
		}
	}()
}

// Stop asks the worker to exit without blocking the caller: the stop signal is
// sent on quitChan from a separate goroutine, which stays blocked until the
// worker's loop receives it.
func (w *Worker) Stop() {
	signal := func() {
		w.quitChan <- struct{}{}
	}
	go signal()
}

// Hello is the sample task body: it prints a greeting, marks one unit of work
// as done on the package-level wg, and always returns nil.
func Hello() error {
	const greeting = "Hello World"

	fmt.Println(greeting)
	wg.Done()
	return nil
}

// main starts a pool of five workers, submits ten Hello tasks, and waits until
// every one of them has reported completion through wg.
func main() {
	const (
		workerCount = 5
		taskCount   = 10
	)

	pool := NewPool(workerCount)
	pool.Run()

	for submitted := 0; submitted != taskCount; submitted++ {
		// Register the task before queueing it so Wait cannot return early.
		wg.Add(1)
		pool.AddTask(Task{f: Hello})
	}

	wg.Wait()
}