diff --git a/cluster.go b/cluster.go index a784e28a..bd16fe1e 100644 --- a/cluster.go +++ b/cluster.go @@ -809,7 +809,7 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { continue } - if isRetryableError(err, true) { + if isRetryableError(err, cmd.readTimeout() == nil) { // First retry the same node. if attempt == 0 { continue diff --git a/error.go b/error.go index d9cfd760..c0e561dd 100644 --- a/error.go +++ b/error.go @@ -6,13 +6,12 @@ import ( "net" "strings" - "github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/proto" ) func isRetryableError(err error, retryTimeout bool) bool { switch err { - case nil, context.Canceled, context.DeadlineExceeded, pool.ErrBadConn: + case nil, context.Canceled, context.DeadlineExceeded: return false case io.EOF: return true @@ -49,8 +48,6 @@ func isBadConn(err error, allowTimeout bool) bool { switch err { case nil: return false - case pool.ErrBadConn: - return true } if isRedisError(err) { return isReadOnlyError(err) // #790 diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index a461c72f..197604cb 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -86,7 +86,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { if err != nil { b.Fatal(err) } - connPool.Remove(cn) + connPool.Remove(cn, nil) } }) }) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 1400a8cf..1ac34fbe 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -39,7 +39,7 @@ type Pooler interface { Get(context.Context) (*Conn, error) Put(*Conn) - Remove(*Conn) + Remove(*Conn, error) Len() int IdleLen() int @@ -311,7 +311,7 @@ func (p *ConnPool) popIdle() *Conn { func (p *ConnPool) Put(cn *Conn) { if !cn.pooled { - p.Remove(cn) + p.Remove(cn, nil) return } @@ -322,7 +322,7 @@ func (p *ConnPool) Put(cn *Conn) { p.freeTurn() } -func (p *ConnPool) Remove(cn *Conn) { +func (p *ConnPool) Remove(cn *Conn, reason error) { p.removeConnWithLock(cn) p.freeTurn() _ = p.closeConn(cn) diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 54d6c4e5..3cee769b 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -12,7 +12,17 @@ const ( stateClosed = 2 ) -var ErrBadConn = fmt.Errorf("pg: Conn is in a bad state") +type BadConnError struct { + wrapped error +} + +func (e BadConnError) Error() string { + return "pg: Conn is in a bad state" +} + +func (e BadConnError) Unwrap() error { + return e.wrapped +} type SingleConnPool struct { pool Pooler @@ -20,8 +30,8 @@ type SingleConnPool struct { state uint32 // atomic ch chan *Conn - level int32 // atomic - _hasBadConn uint32 // atomic + level int32 // atomic + _badConnError atomic.Value } var _ Pooler = (*SingleConnPool)(nil) @@ -66,10 +76,10 @@ func (p *SingleConnPool) Get(c context.Context) (*Conn, error) { if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { return cn, nil } - p.pool.Remove(cn) + p.pool.Remove(cn, ErrClosed) case stateInited: - if p.hasBadConn() { - return nil, ErrBadConn + if err := p.badConnError(); err != nil { + return nil, err } cn, ok := <-p.ch if !ok { @@ -95,20 +105,20 @@ func (p *SingleConnPool) Put(cn *Conn) { } func (p *SingleConnPool) freeConn(cn *Conn) { - if p.hasBadConn() { - p.pool.Remove(cn) + if err := p.badConnError(); err != nil { + p.pool.Remove(cn, err) } else { p.pool.Put(cn) } } -func (p *SingleConnPool) Remove(cn *Conn) { +func (p *SingleConnPool) Remove(cn *Conn, reason error) { defer func() { if recover() != nil { - p.pool.Remove(cn) + p.pool.Remove(cn, ErrClosed) } }() - atomic.StoreUint32(&p._hasBadConn, 1) + p._badConnError.Store(BadConnError{wrapped: reason}) p.ch <- cn } @@ -158,7 +168,7 @@ func (p *SingleConnPool) Close() error { } func (p *SingleConnPool) Reset() error { - if !atomic.CompareAndSwapUint32(&p._hasBadConn, 1, 0) { + if p.badConnError() == nil { return nil } @@ -167,7 +177,8 @@ func (p *SingleConnPool) Reset() error { if !ok { return ErrClosed } - p.pool.Remove(cn) + p.pool.Remove(cn, ErrClosed) + p._badConnError.Store(nil) default: return fmt.Errorf("pg: SingleConnPool does not have a Conn") } @@ -180,6 +191,9 @@ func (p *SingleConnPool) Reset() error { return nil } -func (p *SingleConnPool) hasBadConn() bool { - return atomic.LoadUint32(&p._hasBadConn) == 1 +func (p *SingleConnPool) badConnError() error { + if v := p._badConnError.Load(); v != nil { + return v.(BadConnError) + } + return nil } diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index d2074d23..d4a355a4 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -58,13 +58,13 @@ func (p *StickyConnPool) putUpstream() { func (p *StickyConnPool) Put(cn *Conn) {} -func (p *StickyConnPool) removeUpstream() { - p.pool.Remove(p.cn) +func (p *StickyConnPool) removeUpstream(reason error) { + p.pool.Remove(p.cn, reason) p.cn = nil } -func (p *StickyConnPool) Remove(cn *Conn) { - p.removeUpstream() +func (p *StickyConnPool) Remove(cn *Conn, reason error) { + p.removeUpstream(reason) } func (p *StickyConnPool) Len() int { @@ -104,7 +104,7 @@ func (p *StickyConnPool) Close() error { if p.reusable { p.putUpstream() } else { - p.removeUpstream() + p.removeUpstream(ErrClosed) } } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index e023348c..158c17cd 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -65,7 +65,7 @@ var _ = Describe("ConnPool", func() { // ok } - connPool.Remove(cn) + connPool.Remove(cn, nil) // Check that Get is unblocked. select { @@ -128,7 +128,7 @@ var _ = Describe("MinIdleConns", func() { Context("after Remove", func() { BeforeEach(func() { - connPool.Remove(cn) + connPool.Remove(cn, nil) }) It("has idle connections", func() { @@ -205,7 +205,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { perform(len(cns), func(i int) { mu.RLock() - connPool.Remove(cns[i]) + connPool.Remove(cns[i], nil) mu.RUnlock() }) @@ -355,7 +355,7 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(4)) Expect(connPool.IdleLen()).To(Equal(0)) - connPool.Remove(cn) + connPool.Remove(cn, nil) Expect(connPool.Len()).To(Equal(3)) Expect(connPool.IdleLen()).To(Equal(0)) @@ -413,7 +413,7 @@ var _ = Describe("race", func() { cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) if err == nil { - connPool.Remove(cn) + connPool.Remove(cn, nil) } } }) diff --git a/internal/util.go b/internal/util.go index 6e471400..30223fb2 100644 --- a/internal/util.go +++ b/internal/util.go @@ -44,3 +44,13 @@ func isLower(s string) bool { } return true } + +func Unwrap(err error) error { + u, ok := err.(interface { + Unwrap() error + }) + if !ok { + return nil + } + return u.Unwrap() +} diff --git a/redis.go b/redis.go index a20c5568..c7b12ec0 100644 --- a/redis.go +++ b/redis.go @@ -171,7 +171,10 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { err = c.initConn(ctx, cn) if err != nil { - c.connPool.Remove(cn) + c.connPool.Remove(cn, err) + if err := internal.Unwrap(err); err != nil { + return nil, err + } return nil, err } @@ -226,7 +229,7 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) { } if isBadConn(err, false) { - c.connPool.Remove(cn) + c.connPool.Remove(cn, err) } else { c.connPool.Put(cn) } @@ -240,7 +243,7 @@ func (c *baseClient) releaseConnStrict(cn *pool.Conn, err error) { if err == nil || isRedisError(err) { c.connPool.Put(cn) } else { - c.connPool.Remove(cn) + c.connPool.Remove(cn, err) } }