opt: refactor the worker queue for reusability and readability of code

This commit is contained in:
Andy Pan 2023-03-23 11:40:06 +08:00
parent b880b659f5
commit 0313effc53
9 changed files with 119 additions and 109 deletions

34
pool.go
View File

@ -45,7 +45,7 @@ type Pool struct {
lock sync.Locker
// workers is a slice that store the available workers.
workers workerArray
workers workerQueue
// state is used to notice the pool to closed itself.
state int32
@ -91,16 +91,16 @@ func (p *Pool) purgeStaleWorkers(ctx context.Context) {
}
p.lock.Lock()
expiredWorkers := p.workers.retrieveExpiry(p.options.ExpiryDuration)
staleWorkers := p.workers.staleWorkers(p.options.ExpiryDuration)
p.lock.Unlock()
// Notify obsolete workers to stop.
// This notification must be outside the p.lock, since w.task
// may be blocking and may consume a lot of time if many workers
// are located on non-local CPUs.
for i := range expiredWorkers {
expiredWorkers[i].task <- nil
expiredWorkers[i] = nil
for i := range staleWorkers {
staleWorkers[i].finish()
staleWorkers[i] = nil
}
// There might be a situation where all workers have been cleaned up(no worker is running),
@ -160,12 +160,12 @@ func (p *Pool) nowTime() time.Time {
// NewPool generates an instance of ants pool.
func NewPool(size int, options ...Option) (*Pool, error) {
opts := loadOptions(options...)
if size <= 0 {
size = -1
}
opts := loadOptions(options...)
if !opts.DisablePurge {
if expiry := opts.ExpiryDuration; expiry < 0 {
return nil, ErrInvalidPoolExpiry
@ -193,9 +193,9 @@ func NewPool(size int, options ...Option) (*Pool, error) {
if size == -1 {
return nil, ErrInvalidPreAllocSize
}
p.workers = newWorkerArray(loopQueueType, size)
p.workers = newWorkerArray(queueTypeLoopQueue, size)
} else {
p.workers = newWorkerArray(stackType, 0)
p.workers = newWorkerArray(queueTypeStack, 0)
}
p.cond = sync.NewCond(p.lock)
@ -218,12 +218,11 @@ func (p *Pool) Submit(task func()) error {
if p.IsClosed() {
return ErrPoolClosed
}
var w *goWorker
if w = p.retrieveWorker(); w == nil {
return ErrPoolOverload
if w := p.retrieveWorker(); w != nil {
w.inputFunc(task)
return nil
}
w.task <- task
return nil
return ErrPoolOverload
}
// Running returns the number of workers currently running.
@ -331,14 +330,13 @@ func (p *Pool) addWaiting(delta int) {
}
// retrieveWorker returns an available worker to run the tasks.
func (p *Pool) retrieveWorker() (w *goWorker) {
func (p *Pool) retrieveWorker() (w worker) {
spawnWorker := func() {
w = p.workerCache.Get().(*goWorker)
w.run()
}
p.lock.Lock()
w = p.workers.detach()
if w != nil { // first try to fetch the worker from the queue
p.lock.Unlock()
@ -401,8 +399,7 @@ func (p *Pool) revertWorker(worker *goWorker) bool {
return false
}
err := p.workers.insert(worker)
if err != nil {
if err := p.workers.insert(worker); err != nil {
p.lock.Unlock()
return false
}
@ -410,5 +407,6 @@ func (p *Pool) revertWorker(worker *goWorker) bool {
// Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue.
p.cond.Signal()
p.lock.Unlock()
return true
}

View File

@ -44,7 +44,7 @@ type PoolWithFunc struct {
lock sync.Locker
// workers is a slice that store the available workers.
workers []*goWorkerWithFunc
workers workerQueue
// state is used to notice the pool to closed itself.
state int32
@ -80,7 +80,6 @@ func (p *PoolWithFunc) purgeStaleWorkers(ctx context.Context) {
atomic.StoreInt32(&p.purgeDone, 1)
}()
var expiredWorkers []*goWorkerWithFunc
for {
select {
case <-ctx.Done():
@ -92,38 +91,17 @@ func (p *PoolWithFunc) purgeStaleWorkers(ctx context.Context) {
break
}
criticalTime := time.Now().Add(-p.options.ExpiryDuration)
p.lock.Lock()
idleWorkers := p.workers
n := len(idleWorkers)
l, r, mid := 0, n-1, 0
for l <= r {
mid = (l + r) / 2
if criticalTime.Before(idleWorkers[mid].recycleTime) {
r = mid - 1
} else {
l = mid + 1
}
}
i := r + 1
expiredWorkers = append(expiredWorkers[:0], idleWorkers[:i]...)
if i > 0 {
m := copy(idleWorkers, idleWorkers[i:])
for i := m; i < n; i++ {
idleWorkers[i] = nil
}
p.workers = idleWorkers[:m]
}
staleWorkers := p.workers.staleWorkers(p.options.ExpiryDuration)
p.lock.Unlock()
// Notify obsolete workers to stop.
// This notification must be outside the p.lock, since w.task
// may be blocking and may consume a lot of time if many workers
// are located on non-local CPUs.
for i, w := range expiredWorkers {
w.args <- nil
expiredWorkers[i] = nil
for i := range staleWorkers {
staleWorkers[i].finish()
staleWorkers[i] = nil
}
// There might be a situation where all workers have been cleaned up(no worker is running),
@ -221,8 +199,11 @@ func NewPoolWithFunc(size int, pf func(interface{}), options ...Option) (*PoolWi
if size == -1 {
return nil, ErrInvalidPreAllocSize
}
p.workers = make([]*goWorkerWithFunc, 0, size)
p.workers = newWorkerArray(queueTypeLoopQueue, size)
} else {
p.workers = newWorkerArray(queueTypeStack, 0)
}
p.cond = sync.NewCond(p.lock)
p.goPurge()
@ -243,12 +224,11 @@ func (p *PoolWithFunc) Invoke(args interface{}) error {
if p.IsClosed() {
return ErrPoolClosed
}
var w *goWorkerWithFunc
if w = p.retrieveWorker(); w == nil {
return ErrPoolOverload
if w := p.retrieveWorker(); w != nil {
w.inputParam(args)
return nil
}
w.args <- args
return nil
return ErrPoolOverload
}
// Running returns the number of workers currently running.
@ -302,11 +282,7 @@ func (p *PoolWithFunc) Release() {
return
}
p.lock.Lock()
idleWorkers := p.workers
for _, w := range idleWorkers {
w.args <- nil
}
p.workers = nil
p.workers.reset()
p.lock.Unlock()
// There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent
// those callers blocking infinitely.
@ -360,19 +336,15 @@ func (p *PoolWithFunc) addWaiting(delta int) {
}
// retrieveWorker returns an available worker to run the tasks.
func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) {
func (p *PoolWithFunc) retrieveWorker() (w worker) {
spawnWorker := func() {
w = p.workerCache.Get().(*goWorkerWithFunc)
w.run()
}
p.lock.Lock()
idleWorkers := p.workers
n := len(idleWorkers) - 1
if n >= 0 { // first try to fetch the worker from the queue
w = idleWorkers[n]
idleWorkers[n] = nil
p.workers = idleWorkers[:n]
w = p.workers.detach()
if w != nil { // first try to fetch the worker from the queue
p.lock.Unlock()
} else if capacity := p.Cap(); capacity == -1 || capacity > p.Running() {
// if the worker queue is empty and we don't run out of the pool capacity,
@ -404,8 +376,7 @@ func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) {
spawnWorker()
return
}
l := len(p.workers) - 1
if l < 0 {
if w = p.workers.detach(); w == nil {
if nw < p.Cap() {
p.lock.Unlock()
spawnWorker()
@ -413,9 +384,6 @@ func (p *PoolWithFunc) retrieveWorker() (w *goWorkerWithFunc) {
}
goto retry
}
w = p.workers[l]
p.workers[l] = nil
p.workers = p.workers[:l]
p.lock.Unlock()
}
return
@ -437,10 +405,14 @@ func (p *PoolWithFunc) revertWorker(worker *goWorkerWithFunc) bool {
return false
}
p.workers = append(p.workers, worker)
if err := p.workers.insert(worker); err != nil {
p.lock.Unlock()
return false
}
// Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue.
p.cond.Signal()
p.lock.Unlock()
return true
}

View File

@ -74,3 +74,19 @@ func (w *goWorker) run() {
}
}()
}
func (w *goWorker) finish() {
w.task <- nil
}
func (w *goWorker) when() time.Time {
return w.recycleTime
}
func (w *goWorker) inputFunc(fn func()) {
w.task <- fn
}
func (w *goWorker) inputParam(interface{}) {
panic("unreachable")
}

View File

@ -74,3 +74,19 @@ func (w *goWorkerWithFunc) run() {
}
}()
}
func (w *goWorkerWithFunc) finish() {
w.args <- nil
}
func (w *goWorkerWithFunc) when() time.Time {
return w.recycleTime
}
func (w *goWorkerWithFunc) inputFunc(func()) {
panic("unreachable")
}
func (w *goWorkerWithFunc) inputParam(arg interface{}) {
w.args <- arg
}

View File

@ -3,8 +3,8 @@ package ants
import "time"
type loopQueue struct {
items []*goWorker
expiry []*goWorker
items []worker
expiry []worker
head int
tail int
size int
@ -13,7 +13,7 @@ type loopQueue struct {
func newWorkerLoopQueue(size int) *loopQueue {
return &loopQueue{
items: make([]*goWorker, size),
items: make([]worker, size),
size: size,
}
}
@ -41,7 +41,7 @@ func (wq *loopQueue) isEmpty() bool {
return wq.head == wq.tail && !wq.isFull
}
func (wq *loopQueue) insert(worker *goWorker) error {
func (wq *loopQueue) insert(w worker) error {
if wq.size == 0 {
return errQueueIsReleased
}
@ -49,7 +49,7 @@ func (wq *loopQueue) insert(worker *goWorker) error {
if wq.isFull {
return errQueueIsFull
}
wq.items[wq.tail] = worker
wq.items[wq.tail] = w
wq.tail++
if wq.tail == wq.size {
@ -62,7 +62,7 @@ func (wq *loopQueue) insert(worker *goWorker) error {
return nil
}
func (wq *loopQueue) detach() *goWorker {
func (wq *loopQueue) detach() worker {
if wq.isEmpty() {
return nil
}
@ -78,7 +78,7 @@ func (wq *loopQueue) detach() *goWorker {
return w
}
func (wq *loopQueue) retrieveExpiry(duration time.Duration) []*goWorker {
func (wq *loopQueue) staleWorkers(duration time.Duration) []worker {
expiryTime := time.Now().Add(-duration)
index := wq.binarySearch(expiryTime)
if index == -1 {
@ -115,7 +115,7 @@ func (wq *loopQueue) binarySearch(expiryTime time.Time) int {
nlen = len(wq.items)
// if no need to remove work, return -1
if wq.isEmpty() || expiryTime.Before(wq.items[wq.head].recycleTime) {
if wq.isEmpty() || expiryTime.Before(wq.items[wq.head].when()) {
return -1
}
@ -137,7 +137,7 @@ func (wq *loopQueue) binarySearch(expiryTime time.Time) int {
mid = l + ((r - l) >> 1)
// calculate true mid position from mapped mid position
tmid = (mid + basel + nlen) % nlen
if expiryTime.Before(wq.items[tmid].recycleTime) {
if expiryTime.Before(wq.items[tmid].when()) {
r = mid - 1
} else {
l = mid + 1
@ -152,10 +152,10 @@ func (wq *loopQueue) reset() {
return
}
Releasing:
retry:
if w := wq.detach(); w != nil {
w.task <- nil
goto Releasing
w.finish()
goto retry
}
wq.items = wq.items[:0]
wq.size = 0

View File

@ -45,7 +45,7 @@ func TestLoopQueue(t *testing.T) {
err := q.insert(&goWorker{recycleTime: time.Now()})
assert.Error(t, err, "Enqueue, error")
q.retrieveExpiry(time.Second)
q.staleWorkers(time.Second)
assert.EqualValuesf(t, 6, q.len(), "Len error: %d", q.len())
}
@ -118,14 +118,14 @@ func TestRotatedArraySearch(t *testing.T) {
// [expiry4, time, time, time, time, expiry5, time, time, time, time/head/tail]
assert.EqualValues(t, -1, q.binarySearch(expiry2), "index should be -1")
assert.EqualValues(t, 9, q.binarySearch(q.items[9].recycleTime), "index should be 9")
assert.EqualValues(t, 9, q.binarySearch(q.items[9].when()), "index should be 9")
assert.EqualValues(t, 8, q.binarySearch(time.Now()), "index should be 8")
}
func TestRetrieveExpiry(t *testing.T) {
size := 10
q := newWorkerLoopQueue(size)
expirew := make([]*goWorker, 0)
expirew := make([]worker, 0)
u, _ := time.ParseDuration("1s")
// test [ time+1s, time+1s, time+1s, time+1s, time+1s, time, time, time, time, time]
@ -138,7 +138,7 @@ func TestRetrieveExpiry(t *testing.T) {
for i := 0; i < size/2; i++ {
_ = q.insert(&goWorker{recycleTime: time.Now()})
}
workers := q.retrieveExpiry(u)
workers := q.staleWorkers(u)
assert.EqualValues(t, expirew, workers, "expired workers aren't right")
@ -151,7 +151,7 @@ func TestRetrieveExpiry(t *testing.T) {
expirew = expirew[:0]
expirew = append(expirew, q.items[size/2:]...)
workers2 := q.retrieveExpiry(u)
workers2 := q.staleWorkers(u)
assert.EqualValues(t, expirew, workers2, "expired workers aren't right")
@ -171,7 +171,7 @@ func TestRetrieveExpiry(t *testing.T) {
expirew = append(expirew, q.items[0:3]...)
expirew = append(expirew, q.items[size/2:]...)
workers3 := q.retrieveExpiry(u)
workers3 := q.staleWorkers(u)
assert.EqualValues(t, expirew, workers3, "expired workers aren't right")
}

View File

@ -13,27 +13,35 @@ var (
errQueueIsReleased = errors.New("the queue length is zero")
)
type workerArray interface {
type worker interface {
run()
finish()
when() time.Time
inputFunc(func())
inputParam(interface{})
}
type workerQueue interface {
len() int
isEmpty() bool
insert(worker *goWorker) error
detach() *goWorker
retrieveExpiry(duration time.Duration) []*goWorker
insert(worker) error
detach() worker
staleWorkers(duration time.Duration) []worker
reset()
}
type arrayType int
type queueType int
const (
stackType arrayType = 1 << iota
loopQueueType
queueTypeStack queueType = 1 << iota
queueTypeLoopQueue
)
func newWorkerArray(aType arrayType, size int) workerArray {
switch aType {
case stackType:
func newWorkerArray(qType queueType, size int) workerQueue {
switch qType {
case queueTypeStack:
return newWorkerStack(size)
case loopQueueType:
case queueTypeLoopQueue:
return newWorkerLoopQueue(size)
default:
return newWorkerStack(size)

View File

@ -3,13 +3,13 @@ package ants
import "time"
type workerStack struct {
items []*goWorker
expiry []*goWorker
items []worker
expiry []worker
}
func newWorkerStack(size int) *workerStack {
return &workerStack{
items: make([]*goWorker, 0, size),
items: make([]worker, 0, size),
}
}
@ -21,12 +21,12 @@ func (wq *workerStack) isEmpty() bool {
return len(wq.items) == 0
}
func (wq *workerStack) insert(worker *goWorker) error {
wq.items = append(wq.items, worker)
func (wq *workerStack) insert(w worker) error {
wq.items = append(wq.items, w)
return nil
}
func (wq *workerStack) detach() *goWorker {
func (wq *workerStack) detach() worker {
l := wq.len()
if l == 0 {
return nil
@ -39,7 +39,7 @@ func (wq *workerStack) detach() *goWorker {
return w
}
func (wq *workerStack) retrieveExpiry(duration time.Duration) []*goWorker {
func (wq *workerStack) staleWorkers(duration time.Duration) []worker {
n := wq.len()
if n == 0 {
return nil
@ -64,7 +64,7 @@ func (wq *workerStack) binarySearch(l, r int, expiryTime time.Time) int {
var mid int
for l <= r {
mid = (l + r) / 2
if expiryTime.Before(wq.items[mid].recycleTime) {
if expiryTime.Before(wq.items[mid].when()) {
r = mid - 1
} else {
l = mid + 1
@ -75,7 +75,7 @@ func (wq *workerStack) binarySearch(l, r int, expiryTime time.Time) int {
func (wq *workerStack) reset() {
for i := 0; i < wq.len(); i++ {
wq.items[i].task <- nil
wq.items[i].finish()
wq.items[i] = nil
}
wq.items = wq.items[:0]

View File

@ -19,7 +19,7 @@ func TestNewWorkerStack(t *testing.T) {
}
func TestWorkerStack(t *testing.T) {
q := newWorkerArray(arrayType(-1), 0)
q := newWorkerArray(queueType(-1), 0)
for i := 0; i < 5; i++ {
err := q.insert(&goWorker{recycleTime: time.Now()})
@ -45,7 +45,7 @@ func TestWorkerStack(t *testing.T) {
}
}
assert.EqualValues(t, 12, q.len(), "Len error")
q.retrieveExpiry(time.Second)
q.staleWorkers(time.Second)
assert.EqualValues(t, 6, q.len(), "Len error")
}