Check context.Done while waiting for a connection

This commit is contained in:
Vladimir Mihailenco 2019-06-08 15:36:00 +03:00
parent 530e66a66e
commit 35932b7961
3 changed files with 20 additions and 11 deletions

View File

@ -130,8 +130,8 @@ func (p *ConnPool) NewConn() (*Conn, error) {
return p._NewConn(nil, false)
}
func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
cn, err := p.newConn(c, pooled)
func (p *ConnPool) _NewConn(ctx context.Context, pooled bool) (*Conn, error) {
cn, err := p.newConn(ctx, pooled)
if err != nil {
return nil, err
}
@ -149,7 +149,7 @@ func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
return cn, nil
}
func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
@ -158,7 +158,7 @@ func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}
netConn, err := p.opt.Dialer(c)
netConn, err := p.opt.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
@ -205,12 +205,12 @@ func (p *ConnPool) getLastDialError() error {
}
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(c context.Context) (*Conn, error) {
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
err := p.waitTurn()
err := p.waitTurn(ctx)
if err != nil {
return nil, err
}
@ -235,7 +235,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) {
atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p._NewConn(c, true)
newcn, err := p._NewConn(ctx, true)
if err != nil {
p.freeTurn()
return nil, err
@ -248,8 +248,15 @@ func (p *ConnPool) getTurn() {
p.queue <- struct{}{}
}
func (p *ConnPool) waitTurn() error {
func (p *ConnPool) waitTurn(ctx context.Context) error {
var done <-chan struct{}
if ctx != nil {
done = ctx.Done()
}
select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}:
return nil
default:
@ -257,6 +264,8 @@ func (p *ConnPool) waitTurn() error {
timer.Reset(p.opt.PoolTimeout)
select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C

View File

@ -22,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented")
}
func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
return p.cn, nil
}

View File

@ -31,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented")
}
func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
p.mu.Lock()
defer p.mu.Unlock()
@ -42,7 +42,7 @@ func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
return p.cn, nil
}
cn, err := p.pool.Get(c)
cn, err := p.pool.Get(ctx)
if err != nil {
return nil, err
}