diff --git a/cluster.go b/cluster.go index a784e28..d65fbff 100644 --- a/cluster.go +++ b/cluster.go @@ -809,7 +809,7 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { continue } - if isRetryableError(err, true) { + if isRetryableError(err, cmd.readTimeout() == nil) { // First retry the same node. if attempt == 0 { continue @@ -1075,7 +1075,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro } err = c.pipelineProcessCmds(ctx, node, cn, cmds, failedCmds) - node.Client.releaseConnStrict(cn, err) + node.Client.releaseConn(cn, err) }(node, cmds) } @@ -1282,7 +1282,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er } err = c.txPipelineProcessCmds(ctx, node, cn, cmds, failedCmds) - node.Client.releaseConnStrict(cn, err) + node.Client.releaseConn(cn, err) }(node, cmds) } diff --git a/error.go b/error.go index d9cfd76..0ab014d 100644 --- a/error.go +++ b/error.go @@ -6,13 +6,12 @@ import ( "net" "strings" - "github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/proto" ) func isRetryableError(err error, retryTimeout bool) bool { switch err { - case nil, context.Canceled, context.DeadlineExceeded, pool.ErrBadConn: + case nil, context.Canceled, context.DeadlineExceeded: return false case io.EOF: return true @@ -46,14 +45,13 @@ func isRedisError(err error) bool { } func isBadConn(err error, allowTimeout bool) bool { - switch err { - case nil: + if err == nil { return false - case pool.ErrBadConn: - return true } if isRedisError(err) { - return isReadOnlyError(err) // #790 + // Close connections in read only state in case domain addr is used + // and domain resolves to a different Redis Server. See #790. + return isReadOnlyError(err) } if allowTimeout { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index a461c72..197604c 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -86,7 +86,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { if err != nil { b.Fatal(err) } - connPool.Remove(cn) + connPool.Remove(cn, nil) } }) }) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 1400a8c..1ac34fb 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -39,7 +39,7 @@ type Pooler interface { Get(context.Context) (*Conn, error) Put(*Conn) - Remove(*Conn) + Remove(*Conn, error) Len() int IdleLen() int @@ -311,7 +311,7 @@ func (p *ConnPool) popIdle() *Conn { func (p *ConnPool) Put(cn *Conn) { if !cn.pooled { - p.Remove(cn) + p.Remove(cn, nil) return } @@ -322,7 +322,7 @@ func (p *ConnPool) Put(cn *Conn) { p.freeTurn() } -func (p *ConnPool) Remove(cn *Conn) { +func (p *ConnPool) Remove(cn *Conn, reason error) { p.removeConnWithLock(cn) p.freeTurn() _ = p.closeConn(cn) diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 54d6c4e..3cee769 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -12,7 +12,17 @@ const ( stateClosed = 2 ) -var ErrBadConn = fmt.Errorf("pg: Conn is in a bad state") +type BadConnError struct { + wrapped error +} + +func (e BadConnError) Error() string { + return "pg: Conn is in a bad state" +} + +func (e BadConnError) Unwrap() error { + return e.wrapped +} type SingleConnPool struct { pool Pooler @@ -20,8 +30,8 @@ type SingleConnPool struct { state uint32 // atomic ch chan *Conn - level int32 // atomic - _hasBadConn uint32 // atomic + level int32 // atomic + _badConnError atomic.Value } var _ Pooler = (*SingleConnPool)(nil) @@ -66,10 +76,10 @@ func (p *SingleConnPool) Get(c context.Context) (*Conn, error) { if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { return cn, nil } - p.pool.Remove(cn) + p.pool.Remove(cn, ErrClosed) case stateInited: - if p.hasBadConn() { - return nil, ErrBadConn + if err := p.badConnError(); err != nil { + return nil, err } cn, ok := <-p.ch if !ok { @@ -95,20 +105,20 @@ func (p *SingleConnPool) Put(cn *Conn) { } func (p *SingleConnPool) freeConn(cn *Conn) { - if p.hasBadConn() { - p.pool.Remove(cn) + if err := p.badConnError(); err != nil { + p.pool.Remove(cn, err) } else { p.pool.Put(cn) } } -func (p *SingleConnPool) Remove(cn *Conn) { +func (p *SingleConnPool) Remove(cn *Conn, reason error) { defer func() { if recover() != nil { - p.pool.Remove(cn) + p.pool.Remove(cn, ErrClosed) } }() - atomic.StoreUint32(&p._hasBadConn, 1) + p._badConnError.Store(BadConnError{wrapped: reason}) p.ch <- cn } @@ -158,7 +168,7 @@ func (p *SingleConnPool) Close() error { } func (p *SingleConnPool) Reset() error { - if !atomic.CompareAndSwapUint32(&p._hasBadConn, 1, 0) { + if p.badConnError() == nil { return nil } @@ -167,7 +177,8 @@ func (p *SingleConnPool) Reset() error { if !ok { return ErrClosed } - p.pool.Remove(cn) + p.pool.Remove(cn, ErrClosed) + p._badConnError.Store(nil) default: return fmt.Errorf("pg: SingleConnPool does not have a Conn") } @@ -180,6 +191,9 @@ func (p *SingleConnPool) Reset() error { return nil } -func (p *SingleConnPool) hasBadConn() bool { - return atomic.LoadUint32(&p._hasBadConn) == 1 +func (p *SingleConnPool) badConnError() error { + if v := p._badConnError.Load(); v != nil { + return v.(BadConnError) + } + return nil } diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index d2074d2..d4a355a 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -58,13 +58,13 @@ func (p *StickyConnPool) putUpstream() { func (p *StickyConnPool) Put(cn *Conn) {} -func (p *StickyConnPool) removeUpstream() { - p.pool.Remove(p.cn) +func (p *StickyConnPool) removeUpstream(reason error) { + p.pool.Remove(p.cn, reason) p.cn = nil } -func (p *StickyConnPool) Remove(cn *Conn) { - p.removeUpstream() +func (p *StickyConnPool) Remove(cn *Conn, reason error) { + p.removeUpstream(reason) } func (p *StickyConnPool) Len() int { @@ -104,7 +104,7 @@ func (p *StickyConnPool) Close() error { if p.reusable { p.putUpstream() } else { - p.removeUpstream() + p.removeUpstream(ErrClosed) } } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index e023348..158c17c 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -65,7 +65,7 @@ var _ = Describe("ConnPool", func() { // ok } - connPool.Remove(cn) + connPool.Remove(cn, nil) // Check that Get is unblocked. select { @@ -128,7 +128,7 @@ var _ = Describe("MinIdleConns", func() { Context("after Remove", func() { BeforeEach(func() { - connPool.Remove(cn) + connPool.Remove(cn, nil) }) It("has idle connections", func() { @@ -205,7 +205,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { perform(len(cns), func(i int) { mu.RLock() - connPool.Remove(cns[i]) + connPool.Remove(cns[i], nil) mu.RUnlock() }) @@ -355,7 +355,7 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(4)) Expect(connPool.IdleLen()).To(Equal(0)) - connPool.Remove(cn) + connPool.Remove(cn, nil) Expect(connPool.Len()).To(Equal(3)) Expect(connPool.IdleLen()).To(Equal(0)) @@ -413,7 +413,7 @@ var _ = Describe("race", func() { cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) if err == nil { - connPool.Remove(cn) + connPool.Remove(cn, nil) } } }) diff --git a/internal/util.go b/internal/util.go index 6e47140..30223fb 100644 --- a/internal/util.go +++ b/internal/util.go @@ -44,3 +44,13 @@ func isLower(s string) bool { } return true } + +func Unwrap(err error) error { + u, ok := err.(interface { + Unwrap() error + }) + if !ok { + return nil + } + return u.Unwrap() +} diff --git a/redis.go b/redis.go index a20c556..6695576 100644 --- a/redis.go +++ b/redis.go @@ -171,7 +171,10 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { err = c.initConn(ctx, cn) if err != nil { - c.connPool.Remove(cn) + c.connPool.Remove(cn, err) + if err := internal.Unwrap(err); err != nil { + return nil, err + } return nil, err } @@ -226,24 +229,12 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) { } if isBadConn(err, false) { - c.connPool.Remove(cn) + c.connPool.Remove(cn, err) } else { c.connPool.Put(cn) } } -func (c *baseClient) releaseConnStrict(cn *pool.Conn, err error) { - if c.limiter != nil { - c.limiter.ReportResult(err) - } - - if err == nil || isRedisError(err) { - c.connPool.Put(cn) - } else { - c.connPool.Remove(cn) - } -} - func (c *baseClient) process(ctx context.Context, cmd Cmder) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { @@ -348,7 +339,7 @@ func (c *baseClient) generalProcessPipeline( } canRetry, err := p(ctx, cn, cmds) - c.releaseConnStrict(cn, err) + c.releaseConn(cn, err) if !canRetry || !isRetryableError(err, true) { break diff --git a/ring.go b/ring.go index 30d87d2..5b7a7b9 100644 --- a/ring.go +++ b/ring.go @@ -660,7 +660,7 @@ func (c *Ring) generalProcessPipeline( } else { canRetry, err = shard.Client.pipelineProcessCmds(ctx, cn, cmds) } - shard.Client.releaseConnStrict(cn, err) + shard.Client.releaseConn(cn, err) if canRetry && isRetryableError(err, true) { mu.Lock()