opt: refactor the worker queue for reusability and readability of code

This commit is contained in:
Andy Pan 2023-03-23 11:40:06 +08:00
parent b880b659f5
commit 0313effc53
9 changed files with 119 additions and 109 deletions

34
pool.go
View File

@ -45,7 +45,7 @@ type Pool struct {
lock sync.Locker lock sync.Locker
// workers is a slice that store the available workers. // workers is a slice that store the available workers.
workers workerArray workers workerQueue
// state is used to notice the pool to closed itself. // state is used to notice the pool to closed itself.
state int32 state int32
@ -91,16 +91,16 @@ func (p *Pool) purgeStaleWorkers(ctx context.Context) {
} }
p.lock.Lock() p.lock.Lock()
expiredWorkers := p.workers.retrieveExpiry(p.options.ExpiryDuration) staleWorkers := p.workers.staleWorkers(p.options.ExpiryDuration)
p.lock.Unlock() p.lock.Unlock()
// Notify obsolete workers to stop. // Notify obsolete workers to stop.
// This notification must be outside the p.lock, since w.task // This notification must be outside the p.lock, since w.task
// may be blocking and may consume a lot of time if many workers // may be blocking and may consume a lot of time if many workers
// are located on non-local CPUs. // are located on non-local CPUs.
for i := range expiredWorkers { for i := range staleWorkers {
expiredWorkers[i].task <- nil staleWorkers[i].finish()
expiredWorkers[i] = nil staleWorkers[i] = nil
} }
// There might be a situation where all workers have been cleaned up(no worker is running), // There might be a situation where all workers have been cleaned up(no worker is running),
@ -160,12 +160,12 @@ func (p *Pool) nowTime() time.Time {
// NewPool generates an instance of ants pool. // NewPool generates an instance of ants pool.
func NewPool(size int, options ...Option) (*Pool, error) { func NewPool(size int, options ...Option) (*Pool, error) {
opts := loadOptions(options...)
if size <= 0 { if size <= 0 {
size = -1 size = -1
} }
opts := loadOptions(options...)
if !opts.DisablePurge { if !opts.DisablePurge {
if expiry := opts.ExpiryDuration; expiry < 0 { if expiry := opts.ExpiryDuration; expiry < 0 {
return nil, ErrInvalidPoolExpiry return nil, ErrInvalidPoolExpiry
@ -193,9 +193,9 @@ func NewPool(size int, options ...Option) (*Pool, error) {
if size == -1 { if size == -1 {
return nil, ErrInvalidPreAllocSize return nil, ErrInvalidPreAllocSize
} }
p.workers = newWorkerArray(loopQueueType, size) p.workers = newWorkerArray(queueTypeLoopQueue, size)
} else { } else {
p.workers = newWorkerArray(stackType, 0) p.workers = newWorkerArray(queueTypeStack, 0)
} }
p.cond = sync.NewCond(p.lock) p.cond = sync.NewCond(p.lock)
@ -218,12 +218,11 @@ func (p *Pool) Submit(task func()) error {
if p.IsClosed() { if p.IsClosed() {
return ErrPoolClosed return ErrPoolClosed
} }
var w *goWorker if w := p.retrieveWorker(); w != nil {
if w = p.retrieveWorker(); w == nil { w.inputFunc(task)
return ErrPoolOverload
}
w.task <- task
return nil return nil
}
return ErrPoolOverload
} }
// Running returns the number of workers currently running. // Running returns the number of workers currently running.
@ -331,14 +330,13 @@ func (p *Pool) addWaiting(delta int) {
} }
// retrieveWorker returns an available worker to run the tasks. // retrieveWorker returns an available worker to run the tasks.
func (p *Pool) retrieveWorker() (w *goWorker) { func (p *Pool) retrieveWorker() (w worker) {
spawnWorker := func() { spawnWorker := func() {
w = p.workerCache.Get().(*goWorker) w = p.workerCache.Get().(*goWorker)
w.run() w.run()
} }
p.lock.Lock() p.lock.Lock()
w = p.workers.detach() w = p.workers.detach()
if w != nil { // first try to fetch the worker from the queue if w != nil { // first try to fetch the worker from the queue
p.lock.Unlock() p.lock.Unlock()
@ -401,8 +399,7 @@ func (p *Pool) revertWorker(worker *goWorker) bool {
return false return false
} }
err := p.workers.insert(worker) if err := p.workers.insert(worker); err != nil {
if err != nil {
p.lock.Unlock() p.lock.Unlock()
return false return false
} }
@ -410,5 +407,6 @@ func (p *Pool) revertWorker(worker *goWorker) bool {
// Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue. // Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue.
p.cond.Signal() p.cond.Signal()
p.lock.Unlock() p.lock.Unlock()
return true return true
} }

View File

@ -44,7 +44,7 @@ type PoolWithFunc struct {
lock sync.Locker lock sync.Locker
// workers is a slice that store the available workers. // workers is a slice that store the available workers.
workers []*goWorkerWithFunc workers workerQueue
// state is used to notice the pool to closed itself. // state is used to notice the pool to closed itself.
state int32 state int32
@ -80,7 +80,6 @@ func (p *PoolWithFunc) purgeStaleWorkers(ctx context.Context) {
atomic.StoreInt32(&p.purgeDone, 1) atomic.StoreInt32(&p.purgeDone, 1)
}() }()
var expiredWorkers []*goWorkerWithFunc
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -92,38 +91,17 @@ func (p *PoolWithFunc) purgeStaleWorkers(ctx context.Context) {
break break
} }
criticalTime := time.Now().Add(-p.options.ExpiryDuration)
p.lock.Lock() p.lock.Lock()
idleWorkers := p.workers staleWorkers := p.workers.staleWorkers(p.options.ExpiryDuration)
n := len(idleWorkers)
l, r, mid := 0, n-1, 0
for l <= r {
mid = (l + r) / 2
if criticalTime.Before(idleWorkers[mid].recycleTime) {
r = mid - 1
} else {
l = mid + 1
}
}
i := r + 1
expiredWorkers = append(expiredWorkers[:0], idleWorkers[:i]...)
if i > 0 {
m := copy(idleWorkers, idleWorkers[i:])
for i := m; i < n; i++ {
idleWorkers[i] = nil
}
p.workers = idleWorkers[:m]
}
p.lock.Unlock() p.lock.Unlock()
// Notify obsolete workers to stop. // Notify obsolete workers to stop.
// This notification must be outside the p.lock, since w.task // This notification must be outside the p.lock, since w.task
// may be blocking and may consume a lot of time if many workers // may be blocking and may consume a lot of time if many workers
// are located on non-local CPUs. // are located on non-local CPUs.
for i, w := range expiredWorkers { for i := range staleWorkers {
w.args <- nil staleWorkers[i].finish()
expiredWorkers[i] = nil staleWorkers[i] = nil
} }
// There might be a situation where all workers have been cleaned up(no worker is running), // There might be a situation where all workers have been cleaned up(no worker is running),
@ -221,8 +199,11 @@ func NewPoolWithFunc(size int, pf func(interface{}), options ...Option) (*PoolWi
if size == -1 { if size == -1 {
return nil, ErrInvalidPreAllocSize return nil, ErrInvalidPreAllocSize
} }
p.workers = make([]*goWorkerWithFunc, 0, size) p.workers = newWorkerArray(queueTypeLoopQueue, size)
} else {
p.workers = newWorkerArray(queueTypeStack, 0)
} }
p.cond = sync.NewCond(p.lock) p.cond = sync.NewCond(p.lock)
p.goPurge() p.goPurge()
@ -243,12 +224,11 @@ func (p *PoolWithFunc) Invoke(args interface{}) error {
if p.IsClosed() { if p.IsClosed() {
return ErrPoolClosed return ErrPoolClosed
} }
var w *goWorkerWithFunc if w := p.retrieveWorker(); w != nil {
if w = p.retrieveWorker(); w == nil { w.inputParam(args)
return ErrPoolOverload
}
w.args <- args
return nil return nil
}
return ErrPoolOverload
} }
// Running returns the number of workers currently running. // Running returns the number of workers currently running.
@ -302,11 +282,7 @@ func (p *PoolWithFunc) Release() {
return return
} }
p.lock.Lock() p.lock.Lock()
idleWorkers := p.workers p.workers.reset()
for _, w := range idleWorkers {
w.args <- nil
}
p.workers = nil
p.lock.Unlock() p.lock.Unlock()
// There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent // There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent
// those callers blocking infinitely. // those callers blocking infinitely.
@ -360,19 +336,15 @@ func (p *PoolWithFunc) addWaiting(delta int) {
} }
// retrieveWorker returns an available worker to run the tasks. // retrieveWorker returns an available worker to run the tasks.
func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) { func (p *PoolWithFunc) retrieveWorker() (w worker) {
spawnWorker := func() { spawnWorker := func() {
w = p.workerCache.Get().(*goWorkerWithFunc) w = p.workerCache.Get().(*goWorkerWithFunc)
w.run() w.run()
} }
p.lock.Lock() p.lock.Lock()
idleWorkers := p.workers w = p.workers.detach()
n := len(idleWorkers) - 1 if w != nil { // first try to fetch the worker from the queue
if n >= 0 { // first try to fetch the worker from the queue
w = idleWorkers[n]
idleWorkers[n] = nil
p.workers = idleWorkers[:n]
p.lock.Unlock() p.lock.Unlock()
} else if capacity := p.Cap(); capacity == -1 || capacity > p.Running() { } else if capacity := p.Cap(); capacity == -1 || capacity > p.Running() {
// if the worker queue is empty and we don't run out of the pool capacity, // if the worker queue is empty and we don't run out of the pool capacity,
@ -404,8 +376,7 @@ func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) {
spawnWorker() spawnWorker()
return return
} }
l := len(p.workers) - 1 if w = p.workers.detach(); w == nil {
if l < 0 {
if nw < p.Cap() { if nw < p.Cap() {
p.lock.Unlock() p.lock.Unlock()
spawnWorker() spawnWorker()
@ -413,9 +384,6 @@ func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) {
} }
goto retry goto retry
} }
w = p.workers[l]
p.workers[l] = nil
p.workers = p.workers[:l]
p.lock.Unlock() p.lock.Unlock()
} }
return return
@ -437,10 +405,14 @@ func (p *PoolWithFunc) revertWorker(worker *goWorkerWithFunc) bool {
return false return false
} }
p.workers = append(p.workers, worker) if err := p.workers.insert(worker); err != nil {
p.lock.Unlock()
return false
}
// Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue. // Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue.
p.cond.Signal() p.cond.Signal()
p.lock.Unlock() p.lock.Unlock()
return true return true
} }

View File

@ -74,3 +74,19 @@ func (w *goWorker) run() {
} }
}() }()
} }
func (w *goWorker) finish() {
w.task <- nil
}
func (w *goWorker) when() time.Time {
return w.recycleTime
}
func (w *goWorker) inputFunc(fn func()) {
w.task <- fn
}
func (w *goWorker) inputParam(interface{}) {
panic("unreachable")
}

View File

@ -74,3 +74,19 @@ func (w *goWorkerWithFunc) run() {
} }
}() }()
} }
func (w *goWorkerWithFunc) finish() {
w.args <- nil
}
func (w *goWorkerWithFunc) when() time.Time {
return w.recycleTime
}
func (w *goWorkerWithFunc) inputFunc(func()) {
panic("unreachable")
}
func (w *goWorkerWithFunc) inputParam(arg interface{}) {
w.args <- arg
}

View File

@ -3,8 +3,8 @@ package ants
import "time" import "time"
type loopQueue struct { type loopQueue struct {
items []*goWorker items []worker
expiry []*goWorker expiry []worker
head int head int
tail int tail int
size int size int
@ -13,7 +13,7 @@ type loopQueue struct {
func newWorkerLoopQueue(size int) *loopQueue { func newWorkerLoopQueue(size int) *loopQueue {
return &loopQueue{ return &loopQueue{
items: make([]*goWorker, size), items: make([]worker, size),
size: size, size: size,
} }
} }
@ -41,7 +41,7 @@ func (wq *loopQueue) isEmpty() bool {
return wq.head == wq.tail && !wq.isFull return wq.head == wq.tail && !wq.isFull
} }
func (wq *loopQueue) insert(worker *goWorker) error { func (wq *loopQueue) insert(w worker) error {
if wq.size == 0 { if wq.size == 0 {
return errQueueIsReleased return errQueueIsReleased
} }
@ -49,7 +49,7 @@ func (wq *loopQueue) insert(worker *goWorker) error {
if wq.isFull { if wq.isFull {
return errQueueIsFull return errQueueIsFull
} }
wq.items[wq.tail] = worker wq.items[wq.tail] = w
wq.tail++ wq.tail++
if wq.tail == wq.size { if wq.tail == wq.size {
@ -62,7 +62,7 @@ func (wq *loopQueue) insert(worker *goWorker) error {
return nil return nil
} }
func (wq *loopQueue) detach() *goWorker { func (wq *loopQueue) detach() worker {
if wq.isEmpty() { if wq.isEmpty() {
return nil return nil
} }
@ -78,7 +78,7 @@ func (wq *loopQueue) detach() *goWorker {
return w return w
} }
func (wq *loopQueue) retrieveExpiry(duration time.Duration) []*goWorker { func (wq *loopQueue) staleWorkers(duration time.Duration) []worker {
expiryTime := time.Now().Add(-duration) expiryTime := time.Now().Add(-duration)
index := wq.binarySearch(expiryTime) index := wq.binarySearch(expiryTime)
if index == -1 { if index == -1 {
@ -115,7 +115,7 @@ func (wq *loopQueue) binarySearch(expiryTime time.Time) int {
nlen = len(wq.items) nlen = len(wq.items)
// if no need to remove work, return -1 // if no need to remove work, return -1
if wq.isEmpty() || expiryTime.Before(wq.items[wq.head].recycleTime) { if wq.isEmpty() || expiryTime.Before(wq.items[wq.head].when()) {
return -1 return -1
} }
@ -137,7 +137,7 @@ func (wq *loopQueue) binarySearch(expiryTime time.Time) int {
mid = l + ((r - l) >> 1) mid = l + ((r - l) >> 1)
// calculate true mid position from mapped mid position // calculate true mid position from mapped mid position
tmid = (mid + basel + nlen) % nlen tmid = (mid + basel + nlen) % nlen
if expiryTime.Before(wq.items[tmid].recycleTime) { if expiryTime.Before(wq.items[tmid].when()) {
r = mid - 1 r = mid - 1
} else { } else {
l = mid + 1 l = mid + 1
@ -152,10 +152,10 @@ func (wq *loopQueue) reset() {
return return
} }
Releasing: retry:
if w := wq.detach(); w != nil { if w := wq.detach(); w != nil {
w.task <- nil w.finish()
goto Releasing goto retry
} }
wq.items = wq.items[:0] wq.items = wq.items[:0]
wq.size = 0 wq.size = 0

View File

@ -45,7 +45,7 @@ func TestLoopQueue(t *testing.T) {
err := q.insert(&goWorker{recycleTime: time.Now()}) err := q.insert(&goWorker{recycleTime: time.Now()})
assert.Error(t, err, "Enqueue, error") assert.Error(t, err, "Enqueue, error")
q.retrieveExpiry(time.Second) q.staleWorkers(time.Second)
assert.EqualValuesf(t, 6, q.len(), "Len error: %d", q.len()) assert.EqualValuesf(t, 6, q.len(), "Len error: %d", q.len())
} }
@ -118,14 +118,14 @@ func TestRotatedArraySearch(t *testing.T) {
// [expiry4, time, time, time, time, expiry5, time, time, time, time/head/tail] // [expiry4, time, time, time, time, expiry5, time, time, time, time/head/tail]
assert.EqualValues(t, -1, q.binarySearch(expiry2), "index should be -1") assert.EqualValues(t, -1, q.binarySearch(expiry2), "index should be -1")
assert.EqualValues(t, 9, q.binarySearch(q.items[9].recycleTime), "index should be 9") assert.EqualValues(t, 9, q.binarySearch(q.items[9].when()), "index should be 9")
assert.EqualValues(t, 8, q.binarySearch(time.Now()), "index should be 8") assert.EqualValues(t, 8, q.binarySearch(time.Now()), "index should be 8")
} }
func TestRetrieveExpiry(t *testing.T) { func TestRetrieveExpiry(t *testing.T) {
size := 10 size := 10
q := newWorkerLoopQueue(size) q := newWorkerLoopQueue(size)
expirew := make([]*goWorker, 0) expirew := make([]worker, 0)
u, _ := time.ParseDuration("1s") u, _ := time.ParseDuration("1s")
// test [ time+1s, time+1s, time+1s, time+1s, time+1s, time, time, time, time, time] // test [ time+1s, time+1s, time+1s, time+1s, time+1s, time, time, time, time, time]
@ -138,7 +138,7 @@ func TestRetrieveExpiry(t *testing.T) {
for i := 0; i < size/2; i++ { for i := 0; i < size/2; i++ {
_ = q.insert(&goWorker{recycleTime: time.Now()}) _ = q.insert(&goWorker{recycleTime: time.Now()})
} }
workers := q.retrieveExpiry(u) workers := q.staleWorkers(u)
assert.EqualValues(t, expirew, workers, "expired workers aren't right") assert.EqualValues(t, expirew, workers, "expired workers aren't right")
@ -151,7 +151,7 @@ func TestRetrieveExpiry(t *testing.T) {
expirew = expirew[:0] expirew = expirew[:0]
expirew = append(expirew, q.items[size/2:]...) expirew = append(expirew, q.items[size/2:]...)
workers2 := q.retrieveExpiry(u) workers2 := q.staleWorkers(u)
assert.EqualValues(t, expirew, workers2, "expired workers aren't right") assert.EqualValues(t, expirew, workers2, "expired workers aren't right")
@ -171,7 +171,7 @@ func TestRetrieveExpiry(t *testing.T) {
expirew = append(expirew, q.items[0:3]...) expirew = append(expirew, q.items[0:3]...)
expirew = append(expirew, q.items[size/2:]...) expirew = append(expirew, q.items[size/2:]...)
workers3 := q.retrieveExpiry(u) workers3 := q.staleWorkers(u)
assert.EqualValues(t, expirew, workers3, "expired workers aren't right") assert.EqualValues(t, expirew, workers3, "expired workers aren't right")
} }

View File

@ -13,27 +13,35 @@ var (
errQueueIsReleased = errors.New("the queue length is zero") errQueueIsReleased = errors.New("the queue length is zero")
) )
type workerArray interface { type worker interface {
run()
finish()
when() time.Time
inputFunc(func())
inputParam(interface{})
}
type workerQueue interface {
len() int len() int
isEmpty() bool isEmpty() bool
insert(worker *goWorker) error insert(worker) error
detach() *goWorker detach() worker
retrieveExpiry(duration time.Duration) []*goWorker staleWorkers(duration time.Duration) []worker
reset() reset()
} }
type arrayType int type queueType int
const ( const (
stackType arrayType = 1 << iota queueTypeStack queueType = 1 << iota
loopQueueType queueTypeLoopQueue
) )
func newWorkerArray(aType arrayType, size int) workerArray { func newWorkerArray(qType queueType, size int) workerQueue {
switch aType { switch qType {
case stackType: case queueTypeStack:
return newWorkerStack(size) return newWorkerStack(size)
case loopQueueType: case queueTypeLoopQueue:
return newWorkerLoopQueue(size) return newWorkerLoopQueue(size)
default: default:
return newWorkerStack(size) return newWorkerStack(size)

View File

@ -3,13 +3,13 @@ package ants
import "time" import "time"
type workerStack struct { type workerStack struct {
items []*goWorker items []worker
expiry []*goWorker expiry []worker
} }
func newWorkerStack(size int) *workerStack { func newWorkerStack(size int) *workerStack {
return &workerStack{ return &workerStack{
items: make([]*goWorker, 0, size), items: make([]worker, 0, size),
} }
} }
@ -21,12 +21,12 @@ func (wq *workerStack) isEmpty() bool {
return len(wq.items) == 0 return len(wq.items) == 0
} }
func (wq *workerStack) insert(worker *goWorker) error { func (wq *workerStack) insert(w worker) error {
wq.items = append(wq.items, worker) wq.items = append(wq.items, w)
return nil return nil
} }
func (wq *workerStack) detach() *goWorker { func (wq *workerStack) detach() worker {
l := wq.len() l := wq.len()
if l == 0 { if l == 0 {
return nil return nil
@ -39,7 +39,7 @@ func (wq *workerStack) detach() *goWorker {
return w return w
} }
func (wq *workerStack) retrieveExpiry(duration time.Duration) []*goWorker { func (wq *workerStack) staleWorkers(duration time.Duration) []worker {
n := wq.len() n := wq.len()
if n == 0 { if n == 0 {
return nil return nil
@ -64,7 +64,7 @@ func (wq *workerStack) binarySearch(l, r int, expiryTime time.Time) int {
var mid int var mid int
for l <= r { for l <= r {
mid = (l + r) / 2 mid = (l + r) / 2
if expiryTime.Before(wq.items[mid].recycleTime) { if expiryTime.Before(wq.items[mid].when()) {
r = mid - 1 r = mid - 1
} else { } else {
l = mid + 1 l = mid + 1
@ -75,7 +75,7 @@ func (wq *workerStack) binarySearch(l, r int, expiryTime time.Time) int {
func (wq *workerStack) reset() { func (wq *workerStack) reset() {
for i := 0; i < wq.len(); i++ { for i := 0; i < wq.len(); i++ {
wq.items[i].task <- nil wq.items[i].finish()
wq.items[i] = nil wq.items[i] = nil
} }
wq.items = wq.items[:0] wq.items = wq.items[:0]

View File

@ -19,7 +19,7 @@ func TestNewWorkerStack(t *testing.T) {
} }
func TestWorkerStack(t *testing.T) { func TestWorkerStack(t *testing.T) {
q := newWorkerArray(arrayType(-1), 0) q := newWorkerArray(queueType(-1), 0)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
err := q.insert(&goWorker{recycleTime: time.Now()}) err := q.insert(&goWorker{recycleTime: time.Now()})
@ -45,7 +45,7 @@ func TestWorkerStack(t *testing.T) {
} }
} }
assert.EqualValues(t, 12, q.len(), "Len error") assert.EqualValues(t, 12, q.len(), "Len error")
q.retrieveExpiry(time.Second) q.staleWorkers(time.Second)
assert.EqualValues(t, 6, q.len(), "Len error") assert.EqualValues(t, 6, q.len(), "Len error")
} }