From 0313effc53ee00ee453f526367737214cc1c339b Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Thu, 23 Mar 2023 11:40:06 +0800 Subject: [PATCH] opt: refactor the worker queue for reusability and readability of code --- pool.go | 34 +++++++------- pool_func.go | 74 ++++++++++-------------------- worker.go | 16 +++++++ worker_func.go | 16 +++++++ worker_loop_queue.go | 24 +++++----- worker_loop_queue_test.go | 12 ++--- worker_array.go => worker_queue.go | 30 +++++++----- worker_stack.go | 18 ++++---- worker_stack_test.go | 4 +- 9 files changed, 119 insertions(+), 109 deletions(-) rename worker_array.go => worker_queue.go (53%) diff --git a/pool.go b/pool.go index d281a1c..cbf008f 100644 --- a/pool.go +++ b/pool.go @@ -45,7 +45,7 @@ type Pool struct { lock sync.Locker // workers is a slice that store the available workers. - workers workerArray + workers workerQueue // state is used to notice the pool to closed itself. state int32 @@ -91,16 +91,16 @@ func (p *Pool) purgeStaleWorkers(ctx context.Context) { } p.lock.Lock() - expiredWorkers := p.workers.retrieveExpiry(p.options.ExpiryDuration) + staleWorkers := p.workers.staleWorkers(p.options.ExpiryDuration) p.lock.Unlock() // Notify obsolete workers to stop. // This notification must be outside the p.lock, since w.task // may be blocking and may consume a lot of time if many workers // are located on non-local CPUs. - for i := range expiredWorkers { - expiredWorkers[i].task <- nil - expiredWorkers[i] = nil + for i := range staleWorkers { + staleWorkers[i].finish() + staleWorkers[i] = nil } // There might be a situation where all workers have been cleaned up(no worker is running), @@ -160,12 +160,12 @@ func (p *Pool) nowTime() time.Time { // NewPool generates an instance of ants pool. func NewPool(size int, options ...Option) (*Pool, error) { - opts := loadOptions(options...) - if size <= 0 { size = -1 } + opts := loadOptions(options...) + if !opts.DisablePurge { if expiry := opts.ExpiryDuration; expiry < 0 { return nil, ErrInvalidPoolExpiry @@ -193,9 +193,9 @@ func NewPool(size int, options ...Option) (*Pool, error) { if size == -1 { return nil, ErrInvalidPreAllocSize } - p.workers = newWorkerArray(loopQueueType, size) + p.workers = newWorkerArray(queueTypeLoopQueue, size) } else { - p.workers = newWorkerArray(stackType, 0) + p.workers = newWorkerArray(queueTypeStack, 0) } p.cond = sync.NewCond(p.lock) @@ -218,12 +218,11 @@ func (p *Pool) Submit(task func()) error { if p.IsClosed() { return ErrPoolClosed } - var w *goWorker - if w = p.retrieveWorker(); w == nil { - return ErrPoolOverload + if w := p.retrieveWorker(); w != nil { + w.inputFunc(task) + return nil } - w.task <- task - return nil + return ErrPoolOverload } // Running returns the number of workers currently running. @@ -331,14 +330,13 @@ func (p *Pool) addWaiting(delta int) { } // retrieveWorker returns an available worker to run the tasks. -func (p *Pool) retrieveWorker() (w *goWorker) { +func (p *Pool) retrieveWorker() (w worker) { spawnWorker := func() { w = p.workerCache.Get().(*goWorker) w.run() } p.lock.Lock() - w = p.workers.detach() if w != nil { // first try to fetch the worker from the queue p.lock.Unlock() @@ -401,8 +399,7 @@ func (p *Pool) revertWorker(worker *goWorker) bool { return false } - err := p.workers.insert(worker) - if err != nil { + if err := p.workers.insert(worker); err != nil { p.lock.Unlock() return false } @@ -410,5 +407,6 @@ func (p *Pool) revertWorker(worker *goWorker) bool { // Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue. p.cond.Signal() p.lock.Unlock() + return true } diff --git a/pool_func.go b/pool_func.go index 40227db..7f07037 100644 --- a/pool_func.go +++ b/pool_func.go @@ -44,7 +44,7 @@ type PoolWithFunc struct { lock sync.Locker // workers is a slice that store the available workers. - workers []*goWorkerWithFunc + workers workerQueue // state is used to notice the pool to closed itself. state int32 @@ -80,7 +80,6 @@ func (p *PoolWithFunc) purgeStaleWorkers(ctx context.Context) { atomic.StoreInt32(&p.purgeDone, 1) }() - var expiredWorkers []*goWorkerWithFunc for { select { case <-ctx.Done(): @@ -92,38 +91,17 @@ func (p *PoolWithFunc) purgeStaleWorkers(ctx context.Context) { break } - criticalTime := time.Now().Add(-p.options.ExpiryDuration) - p.lock.Lock() - idleWorkers := p.workers - n := len(idleWorkers) - l, r, mid := 0, n-1, 0 - for l <= r { - mid = (l + r) / 2 - if criticalTime.Before(idleWorkers[mid].recycleTime) { - r = mid - 1 - } else { - l = mid + 1 - } - } - i := r + 1 - expiredWorkers = append(expiredWorkers[:0], idleWorkers[:i]...) - if i > 0 { - m := copy(idleWorkers, idleWorkers[i:]) - for i := m; i < n; i++ { - idleWorkers[i] = nil - } - p.workers = idleWorkers[:m] - } + staleWorkers := p.workers.staleWorkers(p.options.ExpiryDuration) p.lock.Unlock() // Notify obsolete workers to stop. // This notification must be outside the p.lock, since w.task // may be blocking and may consume a lot of time if many workers // are located on non-local CPUs. - for i, w := range expiredWorkers { - w.args <- nil - expiredWorkers[i] = nil + for i := range staleWorkers { + staleWorkers[i].finish() + staleWorkers[i] = nil } // There might be a situation where all workers have been cleaned up(no worker is running), @@ -221,8 +199,11 @@ func NewPoolWithFunc(size int, pf func(interface{}), options ...Option) (*PoolWi if size == -1 { return nil, ErrInvalidPreAllocSize } - p.workers = make([]*goWorkerWithFunc, 0, size) + p.workers = newWorkerArray(queueTypeLoopQueue, size) + } else { + p.workers = newWorkerArray(queueTypeStack, 0) } + p.cond = sync.NewCond(p.lock) p.goPurge() @@ -243,12 +224,11 @@ func (p *PoolWithFunc) Invoke(args interface{}) error { if p.IsClosed() { return ErrPoolClosed } - var w *goWorkerWithFunc - if w = p.retrieveWorker(); w == nil { - return ErrPoolOverload + if w := p.retrieveWorker(); w != nil { + w.inputParam(args) + return nil } - w.args <- args - return nil + return ErrPoolOverload } // Running returns the number of workers currently running. @@ -302,11 +282,7 @@ func (p *PoolWithFunc) Release() { return } p.lock.Lock() - idleWorkers := p.workers - for _, w := range idleWorkers { - w.args <- nil - } - p.workers = nil + p.workers.reset() p.lock.Unlock() // There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent // those callers blocking infinitely. @@ -360,19 +336,15 @@ func (p *PoolWithFunc) addWaiting(delta int) { } // retrieveWorker returns an available worker to run the tasks. -func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) { +func (p *PoolWithFunc) retrieveWorker() (w worker) { spawnWorker := func() { w = p.workerCache.Get().(*goWorkerWithFunc) w.run() } p.lock.Lock() - idleWorkers := p.workers - n := len(idleWorkers) - 1 - if n >= 0 { // first try to fetch the worker from the queue - w = idleWorkers[n] - idleWorkers[n] = nil - p.workers = idleWorkers[:n] + w = p.workers.detach() + if w != nil { // first try to fetch the worker from the queue p.lock.Unlock() } else if capacity := p.Cap(); capacity == -1 || capacity > p.Running() { // if the worker queue is empty and we don't run out of the pool capacity, @@ -404,8 +376,7 @@ func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) { spawnWorker() return } - l := len(p.workers) - 1 - if l < 0 { + if w = p.workers.detach(); w == nil { if nw < p.Cap() { p.lock.Unlock() spawnWorker() @@ -413,9 +384,6 @@ func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) { } goto retry } - w = p.workers[l] - p.workers[l] = nil - p.workers = p.workers[:l] p.lock.Unlock() } return @@ -437,10 +405,14 @@ func (p *PoolWithFunc) revertWorker(worker *goWorkerWithFunc) bool { return false } - p.workers = append(p.workers, worker) + if err := p.workers.insert(worker); err != nil { + p.lock.Unlock() + return false + } // Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue. p.cond.Signal() p.lock.Unlock() + return true } diff --git a/worker.go b/worker.go index a2bee57..856b46f 100644 --- a/worker.go +++ b/worker.go @@ -74,3 +74,19 @@ func (w *goWorker) run() { } }() } + +func (w *goWorker) finish() { + w.task <- nil +} + +func (w *goWorker) when() time.Time { + return w.recycleTime +} + +func (w *goWorker) inputFunc(fn func()) { + w.task <- fn +} + +func (w *goWorker) inputParam(interface{}) { + panic("unreachable") +} diff --git a/worker_func.go b/worker_func.go index d15224d..2d93589 100644 --- a/worker_func.go +++ b/worker_func.go @@ -74,3 +74,19 @@ func (w *goWorkerWithFunc) run() { } }() } + +func (w *goWorkerWithFunc) finish() { + w.args <- nil +} + +func (w *goWorkerWithFunc) when() time.Time { + return w.recycleTime +} + +func (w *goWorkerWithFunc) inputFunc(func()) { + panic("unreachable") +} + +func (w *goWorkerWithFunc) inputParam(arg interface{}) { + w.args <- arg +} diff --git a/worker_loop_queue.go b/worker_loop_queue.go index cc73934..89e8112 100644 --- a/worker_loop_queue.go +++ b/worker_loop_queue.go @@ -3,8 +3,8 @@ package ants import "time" type loopQueue struct { - items []*goWorker - expiry []*goWorker + items []worker + expiry []worker head int tail int size int @@ -13,7 +13,7 @@ type loopQueue struct { func newWorkerLoopQueue(size int) *loopQueue { return &loopQueue{ - items: make([]*goWorker, size), + items: make([]worker, size), size: size, } } @@ -41,7 +41,7 @@ func (wq *loopQueue) isEmpty() bool { return wq.head == wq.tail && !wq.isFull } -func (wq *loopQueue) insert(worker *goWorker) error { +func (wq *loopQueue) insert(w worker) error { if wq.size == 0 { return errQueueIsReleased } @@ -49,7 +49,7 @@ func (wq *loopQueue) insert(worker *goWorker) error { if wq.isFull { return errQueueIsFull } - wq.items[wq.tail] = worker + wq.items[wq.tail] = w wq.tail++ if wq.tail == wq.size { @@ -62,7 +62,7 @@ func (wq *loopQueue) insert(worker *goWorker) error { return nil } -func (wq *loopQueue) detach() *goWorker { +func (wq *loopQueue) detach() worker { if wq.isEmpty() { return nil } @@ -78,7 +78,7 @@ func (wq *loopQueue) detach() *goWorker { return w } -func (wq *loopQueue) retrieveExpiry(duration time.Duration) []*goWorker { +func (wq *loopQueue) staleWorkers(duration time.Duration) []worker { expiryTime := time.Now().Add(-duration) index := wq.binarySearch(expiryTime) if index == -1 { @@ -115,7 +115,7 @@ func (wq *loopQueue) binarySearch(expiryTime time.Time) int { nlen = len(wq.items) // if no need to remove work, return -1 - if wq.isEmpty() || expiryTime.Before(wq.items[wq.head].recycleTime) { + if wq.isEmpty() || expiryTime.Before(wq.items[wq.head].when()) { return -1 } @@ -137,7 +137,7 @@ func (wq *loopQueue) binarySearch(expiryTime time.Time) int { mid = l + ((r - l) >> 1) // calculate true mid position from mapped mid position tmid = (mid + basel + nlen) % nlen - if expiryTime.Before(wq.items[tmid].recycleTime) { + if expiryTime.Before(wq.items[tmid].when()) { r = mid - 1 } else { l = mid + 1 @@ -152,10 +152,10 @@ func (wq *loopQueue) reset() { return } -Releasing: +retry: if w := wq.detach(); w != nil { - w.task <- nil - goto Releasing + w.finish() + goto retry } wq.items = wq.items[:0] wq.size = 0 diff --git a/worker_loop_queue_test.go b/worker_loop_queue_test.go index 2c6e9f0..8f54d90 100644 --- a/worker_loop_queue_test.go +++ b/worker_loop_queue_test.go @@ -45,7 +45,7 @@ func TestLoopQueue(t *testing.T) { err := q.insert(&goWorker{recycleTime: time.Now()}) assert.Error(t, err, "Enqueue, error") - q.retrieveExpiry(time.Second) + q.staleWorkers(time.Second) assert.EqualValuesf(t, 6, q.len(), "Len error: %d", q.len()) } @@ -118,14 +118,14 @@ func TestRotatedArraySearch(t *testing.T) { // [expiry4, time, time, time, time, expiry5, time, time, time, time/head/tail] assert.EqualValues(t, -1, q.binarySearch(expiry2), "index should be -1") - assert.EqualValues(t, 9, q.binarySearch(q.items[9].recycleTime), "index should be 9") + assert.EqualValues(t, 9, q.binarySearch(q.items[9].when()), "index should be 9") assert.EqualValues(t, 8, q.binarySearch(time.Now()), "index should be 8") } func TestRetrieveExpiry(t *testing.T) { size := 10 q := newWorkerLoopQueue(size) - expirew := make([]*goWorker, 0) + expirew := make([]worker, 0) u, _ := time.ParseDuration("1s") // test [ time+1s, time+1s, time+1s, time+1s, time+1s, time, time, time, time, time] @@ -138,7 +138,7 @@ func TestRetrieveExpiry(t *testing.T) { for i := 0; i < size/2; i++ { _ = q.insert(&goWorker{recycleTime: time.Now()}) } - workers := q.retrieveExpiry(u) + workers := q.staleWorkers(u) assert.EqualValues(t, expirew, workers, "expired workers aren't right") @@ -151,7 +151,7 @@ func TestRetrieveExpiry(t *testing.T) { expirew = expirew[:0] expirew = append(expirew, q.items[size/2:]...) - workers2 := q.retrieveExpiry(u) + workers2 := q.staleWorkers(u) assert.EqualValues(t, expirew, workers2, "expired workers aren't right") @@ -171,7 +171,7 @@ func TestRetrieveExpiry(t *testing.T) { expirew = append(expirew, q.items[0:3]...) expirew = append(expirew, q.items[size/2:]...) - workers3 := q.retrieveExpiry(u) + workers3 := q.staleWorkers(u) assert.EqualValues(t, expirew, workers3, "expired workers aren't right") } diff --git a/worker_array.go b/worker_queue.go similarity index 53% rename from worker_array.go rename to worker_queue.go index a45f89a..b0da75d 100644 --- a/worker_array.go +++ b/worker_queue.go @@ -13,27 +13,35 @@ var ( errQueueIsReleased = errors.New("the queue length is zero") ) -type workerArray interface { +type worker interface { + run() + finish() + when() time.Time + inputFunc(func()) + inputParam(interface{}) +} + +type workerQueue interface { len() int isEmpty() bool - insert(worker *goWorker) error - detach() *goWorker - retrieveExpiry(duration time.Duration) []*goWorker + insert(worker) error + detach() worker + staleWorkers(duration time.Duration) []worker reset() } -type arrayType int +type queueType int const ( - stackType arrayType = 1 << iota - loopQueueType + queueTypeStack queueType = 1 << iota + queueTypeLoopQueue ) -func newWorkerArray(aType arrayType, size int) workerArray { - switch aType { - case stackType: +func newWorkerArray(qType queueType, size int) workerQueue { + switch qType { + case queueTypeStack: return newWorkerStack(size) - case loopQueueType: + case queueTypeLoopQueue: return newWorkerLoopQueue(size) default: return newWorkerStack(size) diff --git a/worker_stack.go b/worker_stack.go index 3f6aa1d..f8c7fa9 100644 --- a/worker_stack.go +++ b/worker_stack.go @@ -3,13 +3,13 @@ package ants import "time" type workerStack struct { - items []*goWorker - expiry []*goWorker + items []worker + expiry []worker } func newWorkerStack(size int) *workerStack { return &workerStack{ - items: make([]*goWorker, 0, size), + items: make([]worker, 0, size), } } @@ -21,12 +21,12 @@ func (wq *workerStack) isEmpty() bool { return len(wq.items) == 0 } -func (wq *workerStack) insert(worker *goWorker) error { - wq.items = append(wq.items, worker) +func (wq *workerStack) insert(w worker) error { + wq.items = append(wq.items, w) return nil } -func (wq *workerStack) detach() *goWorker { +func (wq *workerStack) detach() worker { l := wq.len() if l == 0 { return nil @@ -39,7 +39,7 @@ func (wq *workerStack) detach() *goWorker { return w } -func (wq *workerStack) retrieveExpiry(duration time.Duration) []*goWorker { +func (wq *workerStack) staleWorkers(duration time.Duration) []worker { n := wq.len() if n == 0 { return nil @@ -64,7 +64,7 @@ func (wq *workerStack) binarySearch(l, r int, expiryTime time.Time) int { var mid int for l <= r { mid = (l + r) / 2 - if expiryTime.Before(wq.items[mid].recycleTime) { + if expiryTime.Before(wq.items[mid].when()) { r = mid - 1 } else { l = mid + 1 @@ -75,7 +75,7 @@ func (wq *workerStack) binarySearch(l, r int, expiryTime time.Time) int { func (wq *workerStack) reset() { for i := 0; i < wq.len(); i++ { - wq.items[i].task <- nil + wq.items[i].finish() wq.items[i] = nil } wq.items = wq.items[:0] diff --git a/worker_stack_test.go b/worker_stack_test.go index 94b6055..462589c 100644 --- a/worker_stack_test.go +++ b/worker_stack_test.go @@ -19,7 +19,7 @@ func TestNewWorkerStack(t *testing.T) { } func TestWorkerStack(t *testing.T) { - q := newWorkerArray(arrayType(-1), 0) + q := newWorkerArray(queueType(-1), 0) for i := 0; i < 5; i++ { err := q.insert(&goWorker{recycleTime: time.Now()}) @@ -45,7 +45,7 @@ func TestWorkerStack(t *testing.T) { } } assert.EqualValues(t, 12, q.len(), "Len error") - q.retrieveExpiry(time.Second) + q.staleWorkers(time.Second) assert.EqualValues(t, 6, q.len(), "Len error") }