diff --git a/internal/pool/pool.go b/internal/pool/pool.go index fa6855c..013e29b 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -130,8 +130,8 @@ func (p *ConnPool) NewConn() (*Conn, error) { return p._NewConn(nil, false) } -func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) { - cn, err := p.newConn(c, pooled) +func (p *ConnPool) _NewConn(ctx context.Context, pooled bool) (*Conn, error) { + cn, err := p.newConn(ctx, pooled) if err != nil { return nil, err } @@ -149,7 +149,7 @@ func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) { return cn, nil } -func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) { +func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { if p.closed() { return nil, ErrClosed } @@ -158,7 +158,7 @@ func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) { return nil, p.getLastDialError() } - netConn, err := p.opt.Dialer(c) + netConn, err := p.opt.Dialer(ctx) if err != nil { p.setLastDialError(err) if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { @@ -205,12 +205,12 @@ func (p *ConnPool) getLastDialError() error { } // Get returns existed connection from the pool or creates a new one. -func (p *ConnPool) Get(c context.Context) (*Conn, error) { +func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { if p.closed() { return nil, ErrClosed } - err := p.waitTurn() + err := p.waitTurn(ctx) if err != nil { return nil, err } @@ -235,7 +235,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) { atomic.AddUint32(&p.stats.Misses, 1) - newcn, err := p._NewConn(c, true) + newcn, err := p._NewConn(ctx, true) if err != nil { p.freeTurn() return nil, err @@ -248,8 +248,15 @@ func (p *ConnPool) getTurn() { p.queue <- struct{}{} } -func (p *ConnPool) waitTurn() error { +func (p *ConnPool) waitTurn(ctx context.Context) error { + var done <-chan struct{} + if ctx != nil { + done = ctx.Done() + } + select { + case <-done: + return ctx.Err() case p.queue <- struct{}{}: return nil default: @@ -257,6 +264,8 @@ func (p *ConnPool) waitTurn() error { timer.Reset(p.opt.PoolTimeout) select { + case <-done: + return ctx.Err() case p.queue <- struct{}{}: if !timer.Stop() { <-timer.C diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 6112d8f..778b124 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -22,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error { panic("not implemented") } -func (p *SingleConnPool) Get(c context.Context) (*Conn, error) { +func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { return p.cn, nil } diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 1e632ec..174dc9c 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -31,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error { panic("not implemented") } -func (p *StickyConnPool) Get(c context.Context) (*Conn, error) { +func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { p.mu.Lock() defer p.mu.Unlock() @@ -42,7 +42,7 @@ func (p *StickyConnPool) Get(c context.Context) (*Conn, error) { return p.cn, nil } - cn, err := p.pool.Get(c) + cn, err := p.pool.Get(ctx) if err != nil { return nil, err }