forked from mirror/redis
Check context.Done while waiting for a connection
This commit is contained in:
parent
530e66a66e
commit
35932b7961
|
@ -130,8 +130,8 @@ func (p *ConnPool) NewConn() (*Conn, error) {
|
|||
return p._NewConn(nil, false)
|
||||
}
|
||||
|
||||
func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
|
||||
cn, err := p.newConn(c, pooled)
|
||||
func (p *ConnPool) _NewConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
cn, err := p.newConn(ctx, pooled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
|
|||
return cn, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
|
||||
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
@ -158,7 +158,7 @@ func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
|
|||
return nil, p.getLastDialError()
|
||||
}
|
||||
|
||||
netConn, err := p.opt.Dialer(c)
|
||||
netConn, err := p.opt.Dialer(ctx)
|
||||
if err != nil {
|
||||
p.setLastDialError(err)
|
||||
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
|
||||
|
@ -205,12 +205,12 @@ func (p *ConnPool) getLastDialError() error {
|
|||
}
|
||||
|
||||
// Get returns existed connection from the pool or creates a new one.
|
||||
func (p *ConnPool) Get(c context.Context) (*Conn, error) {
|
||||
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
if p.closed() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
err := p.waitTurn()
|
||||
err := p.waitTurn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -235,7 +235,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) {
|
|||
|
||||
atomic.AddUint32(&p.stats.Misses, 1)
|
||||
|
||||
newcn, err := p._NewConn(c, true)
|
||||
newcn, err := p._NewConn(ctx, true)
|
||||
if err != nil {
|
||||
p.freeTurn()
|
||||
return nil, err
|
||||
|
@ -248,8 +248,15 @@ func (p *ConnPool) getTurn() {
|
|||
p.queue <- struct{}{}
|
||||
}
|
||||
|
||||
func (p *ConnPool) waitTurn() error {
|
||||
func (p *ConnPool) waitTurn(ctx context.Context) error {
|
||||
var done <-chan struct{}
|
||||
if ctx != nil {
|
||||
done = ctx.Done()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case p.queue <- struct{}{}:
|
||||
return nil
|
||||
default:
|
||||
|
@ -257,6 +264,8 @@ func (p *ConnPool) waitTurn() error {
|
|||
timer.Reset(p.opt.PoolTimeout)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case p.queue <- struct{}{}:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
|
|
|
@ -22,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error {
|
|||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {
|
||||
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
return p.cn, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error {
|
|||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
|
||||
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
|
@ -42,7 +42,7 @@ func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
|
|||
return p.cn, nil
|
||||
}
|
||||
|
||||
cn, err := p.pool.Get(c)
|
||||
cn, err := p.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue