Check context.Done while waiting for a connection

This commit is contained in:
Vladimir Mihailenco 2019-06-08 15:36:00 +03:00
parent 530e66a66e
commit 35932b7961
3 changed files with 20 additions and 11 deletions

View File

@ -130,8 +130,8 @@ func (p *ConnPool) NewConn() (*Conn, error) {
return p._NewConn(nil, false) return p._NewConn(nil, false)
} }
func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) { func (p *ConnPool) _NewConn(ctx context.Context, pooled bool) (*Conn, error) {
cn, err := p.newConn(c, pooled) cn, err := p.newConn(ctx, pooled)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -149,7 +149,7 @@ func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
return cn, nil return cn, nil
} }
func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) { func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
if p.closed() { if p.closed() {
return nil, ErrClosed return nil, ErrClosed
} }
@ -158,7 +158,7 @@ func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
return nil, p.getLastDialError() return nil, p.getLastDialError()
} }
netConn, err := p.opt.Dialer(c) netConn, err := p.opt.Dialer(ctx)
if err != nil { if err != nil {
p.setLastDialError(err) p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
@ -205,12 +205,12 @@ func (p *ConnPool) getLastDialError() error {
} }
// Get returns existed connection from the pool or creates a new one. // Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(c context.Context) (*Conn, error) { func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
if p.closed() { if p.closed() {
return nil, ErrClosed return nil, ErrClosed
} }
err := p.waitTurn() err := p.waitTurn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -235,7 +235,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) {
atomic.AddUint32(&p.stats.Misses, 1) atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p._NewConn(c, true) newcn, err := p._NewConn(ctx, true)
if err != nil { if err != nil {
p.freeTurn() p.freeTurn()
return nil, err return nil, err
@ -248,8 +248,15 @@ func (p *ConnPool) getTurn() {
p.queue <- struct{}{} p.queue <- struct{}{}
} }
func (p *ConnPool) waitTurn() error { func (p *ConnPool) waitTurn(ctx context.Context) error {
var done <-chan struct{}
if ctx != nil {
done = ctx.Done()
}
select { select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}: case p.queue <- struct{}{}:
return nil return nil
default: default:
@ -257,6 +264,8 @@ func (p *ConnPool) waitTurn() error {
timer.Reset(p.opt.PoolTimeout) timer.Reset(p.opt.PoolTimeout)
select { select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}: case p.queue <- struct{}{}:
if !timer.Stop() { if !timer.Stop() {
<-timer.C <-timer.C

View File

@ -22,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented") panic("not implemented")
} }
func (p *SingleConnPool) Get(c context.Context) (*Conn, error) { func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
return p.cn, nil return p.cn, nil
} }

View File

@ -31,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented") panic("not implemented")
} }
func (p *StickyConnPool) Get(c context.Context) (*Conn, error) { func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
@ -42,7 +42,7 @@ func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
return p.cn, nil return p.cn, nil
} }
cn, err := p.pool.Get(c) cn, err := p.pool.Get(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }