From 2929cede543983f0b31f22dd9356de1900fdbf86 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Sat, 19 May 2018 18:24:36 +0800 Subject: [PATCH] use sync.Pool --- ants.go | 8 ++++- ants_test.go | 10 ++++-- pool.go | 92 +++++++++++++++++++++++++++++++++++++++------------- worker.go | 6 ++-- 4 files changed, 89 insertions(+), 27 deletions(-) diff --git a/ants.go b/ants.go index 2968a96..4abef68 100644 --- a/ants.go +++ b/ants.go @@ -1,6 +1,8 @@ package ants -const DEFAULT_POOL_SIZE = 1000 +import "math" + +const DEFAULT_POOL_SIZE = math.MaxInt32 var defaultPool = NewPool(DEFAULT_POOL_SIZE) @@ -19,3 +21,7 @@ func Cap() int { func Free() int { return defaultPool.Free() } + +func Wait() { + defaultPool.Wait() +} diff --git a/ants_test.go b/ants_test.go index 24afdb1..8f43b8a 100644 --- a/ants_test.go +++ b/ants_test.go @@ -7,10 +7,14 @@ import ( "runtime" ) -var n = 100000 +var n = 10 func demoFunc() { - for i := 0; i < 1000000; i++ {} + var n int + for i := 0; i < 10000; i++ { + n += i + } + fmt.Printf("finish task with result:%d\n", n) } func TestDefaultPool(t *testing.T) { @@ -22,6 +26,8 @@ func TestDefaultPool(t *testing.T) { t.Logf("running workers number:%d", ants.Running()) t.Logf("free workers number:%d", ants.Free()) + ants.Wait() + mem := runtime.MemStats{} runtime.ReadMemStats(&mem) fmt.Println("memory usage:", mem.TotalAlloc/1024) diff --git a/pool.go b/pool.go index 3b76159..28b165e 100644 --- a/pool.go +++ b/pool.go @@ -4,28 +4,40 @@ import ( "runtime" "sync/atomic" "sync" + "math" ) type sig struct{} type f func() +//type er interface{} + type Pool struct { capacity int32 running int32 - tasks chan f - workers chan *Worker - destroy chan sig - m sync.Mutex + //tasks chan er + //workers chan er + tasks *sync.Pool + workers *sync.Pool + freeSignal chan sig + launchSignal chan sig + destroy chan sig + m *sync.Mutex + wg *sync.WaitGroup } func NewPool(size int) *Pool { p := &Pool{ capacity: int32(size), - tasks: make(chan f, size), - //workers: &sync.Pool{New: func() interface{} { return &Worker{} }}, - workers: make(chan *Worker, size), - destroy: make(chan sig, runtime.GOMAXPROCS(-1)), + //tasks: make(chan er, size), + //workers: make(chan er, size), + tasks: &sync.Pool{}, + workers: &sync.Pool{}, + freeSignal: make(chan sig, math.MaxInt32), + launchSignal: make(chan sig, math.MaxInt32), + destroy: make(chan sig, runtime.GOMAXPROCS(-1)), + wg: &sync.WaitGroup{}, } p.loop() return p @@ -38,8 +50,8 @@ func (p *Pool) loop() { go func() { for { select { - case task := <-p.tasks: - p.getWorker().sendTask(task) + case <-p.launchSignal: + p.getWorker().sendTask(p.tasks.Get().(f)) case <-p.destroy: return } @@ -52,7 +64,10 @@ func (p *Pool) Push(task f) error { if len(p.destroy) > 0 { return nil } - p.tasks <- task + //p.tasks <- task + p.tasks.Put(task) + p.launchSignal <- sig{} + p.wg.Add(1) return nil } func (p *Pool) Running() int { @@ -67,6 +82,10 @@ func (p *Pool) Cap() int { return int(atomic.LoadInt32(&p.capacity)) } +func (p *Pool) Wait() { + p.wg.Wait() +} + func (p *Pool) Destroy() error { p.m.Lock() defer p.m.Unlock() @@ -83,6 +102,10 @@ func (p *Pool) reachLimit() bool { } func (p *Pool) newWorker() *Worker { + if p.reachLimit() { + <-p.freeSignal + return p.getWorker() + } worker := &Worker{ pool: p, task: make(chan f), @@ -92,18 +115,43 @@ func (p *Pool) newWorker() *Worker { return worker } +//func (p *Pool) newWorker() *Worker { +// worker := &Worker{ +// pool: p, +// task: make(chan f), +// exit: make(chan sig), +// } +// worker.run() +// return worker +//} + +//func (p *Pool) getWorker() *Worker { +// defer atomic.AddInt32(&p.running, 1) +// var worker *Worker +// if p.reachLimit() { +// worker = (<-p.workers).(*Worker) +// } else { +// select { +// case w := <-p.workers: +// return w.(*Worker) +// default: +// worker = p.newWorker() +// } +// } +// return worker +//} + func (p *Pool) getWorker() *Worker { defer atomic.AddInt32(&p.running, 1) - var worker *Worker - if p.reachLimit() { - worker = <-p.workers - } else { - select { - case worker = <-p.workers: - return worker - default: - worker = p.newWorker() - } + if w := p.workers.Get(); w != nil { + return w.(*Worker) + } + return p.newWorker() +} + +func (p *Pool) PutWorker(worker *Worker) { + p.workers.Put(worker) + if p.reachLimit() { + p.freeSignal <- sig{} } - return worker } diff --git a/worker.go b/worker.go index 76e8248..f6b0098 100644 --- a/worker.go +++ b/worker.go @@ -14,9 +14,11 @@ func (w *Worker) run() { select { case f := <-w.task: f() - w.pool.workers <- w - atomic.AddInt32(&w.pool.running, -1) + //w.pool.workers <- w + w.pool.workers.Put(w) + w.pool.wg.Done() case <-w.exit: + atomic.AddInt32(&w.pool.running, -1) return } }