diff --git a/Makefile b/Makefile index 42c86f2..b7867b4 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ all: testdeps - go test ./... -test.v -test.cpu=1,2,4 - go test ./... -test.v -test.short -test.race + go test ./... -test.cpu=1,2,4 + go test ./... -test.short -test.race testdeps: testdata/redis/src/redis-server diff --git a/cluster.go b/cluster.go index 6e79cb5..0cd59d9 100644 --- a/cluster.go +++ b/cluster.go @@ -7,6 +7,7 @@ import ( "time" "gopkg.in/redis.v3/internal/hashtag" + "gopkg.in/redis.v3/internal/pool" ) // ClusterClient is a Redis Cluster client representing a pool of zero @@ -80,7 +81,7 @@ func (c *ClusterClient) Close() error { c.clientsMx.Lock() if c.closed { - return errClosed + return pool.ErrClosed } c.closed = true c.resetClients() @@ -105,7 +106,7 @@ func (c *ClusterClient) getClient(addr string) (*Client, error) { c.clientsMx.Lock() if c.closed { c.clientsMx.Unlock() - return nil, errClosed + return nil, pool.ErrClosed } client, ok = c.clients[addr] diff --git a/cluster_pipeline.go b/cluster_pipeline.go index 7fa721c..8641d3d 100644 --- a/cluster_pipeline.go +++ b/cluster_pipeline.go @@ -34,7 +34,7 @@ func (pipe *ClusterPipeline) process(cmd Cmder) { // Discard resets the pipeline and discards queued commands. func (pipe *ClusterPipeline) Discard() error { if pipe.closed { - return errClosed + return pool.ErrClosed } pipe.cmds = pipe.cmds[:0] return nil @@ -42,7 +42,7 @@ func (pipe *ClusterPipeline) Discard() error { func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { if pipe.closed { - return nil, errClosed + return nil, pool.ErrClosed } if len(pipe.cmds) == 0 { return []Cmder{}, nil diff --git a/command_test.go b/command_test.go index 088b178..064e734 100644 --- a/command_test.go +++ b/command_test.go @@ -5,11 +5,13 @@ import ( "strconv" "sync" "testing" + "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "gopkg.in/redis.v3" + "gopkg.in/redis.v3/internal/pool" ) var _ = Describe("Command", func() { @@ -17,7 +19,8 @@ var _ = Describe("Command", func() { connect := func() *redis.Client { return redis.NewClient(&redis.Options{ - Addr: redisAddr, + Addr: redisAddr, + PoolTimeout: time.Minute, }) } @@ -62,19 +65,19 @@ var _ = Describe("Command", func() { }) It("should handle big vals", func() { - val := string(bytes.Repeat([]byte{'*'}, 1<<16)) - set := client.Set("key", val, 0) - Expect(set.Err()).NotTo(HaveOccurred()) - Expect(set.Val()).To(Equal("OK")) + bigVal := string(bytes.Repeat([]byte{'*'}, 1<<16)) + + err := client.Set("key", bigVal, 0).Err() + Expect(err).NotTo(HaveOccurred()) // Reconnect to get new connection. Expect(client.Close()).To(BeNil()) client = connect() - get := client.Get("key") - Expect(get.Err()).NotTo(HaveOccurred()) - Expect(len(get.Val())).To(Equal(len(val))) - Expect(get.Val()).To(Equal(val)) + got, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(got)).To(Equal(len(bigVal))) + Expect(got).To(Equal(bigVal)) }) It("should handle many keys #1", func() { @@ -140,48 +143,111 @@ var _ = Describe("Command", func() { } It("should echo", func() { - wg := &sync.WaitGroup{} - for i := 0; i < C; i++ { - wg.Add(1) - - go func(i int) { - defer GinkgoRecover() - defer wg.Done() - - for j := 0; j < N; j++ { - msg := "echo" + strconv.Itoa(i) - echo := client.Echo(msg) - Expect(echo.Err()).NotTo(HaveOccurred()) - Expect(echo.Val()).To(Equal(msg)) - } - }(i) - } - wg.Wait() + perform(C, func() { + for i := 0; i < N; i++ { + msg := "echo" + strconv.Itoa(i) + echo, err := client.Echo(msg).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(echo).To(Equal(msg)) + } + }) }) It("should incr", func() { key := "TestIncrFromGoroutines" - wg := &sync.WaitGroup{} - for i := 0; i < C; i++ { - wg.Add(1) - go func() { - defer GinkgoRecover() - defer wg.Done() - - for j := 0; j < N; j++ { - err := client.Incr(key).Err() - Expect(err).NotTo(HaveOccurred()) - } - }() - } - wg.Wait() + perform(C, func() { + for i := 0; i < N; i++ { + err := client.Incr(key).Err() + Expect(err).NotTo(HaveOccurred()) + } + }) val, err := client.Get(key).Int64() Expect(err).NotTo(HaveOccurred()) Expect(val).To(Equal(int64(C * N))) }) + It("should handle big vals", func() { + client2 := connect() + defer client2.Close() + + bigVal := string(bytes.Repeat([]byte{'*'}, 1<<16)) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + perform(C, func() { + for i := 0; i < N; i++ { + got, err := client.Get("key").Result() + if err == redis.Nil { + continue + } + Expect(got).To(Equal(bigVal)) + } + }) + }() + + go func() { + defer wg.Done() + perform(C, func() { + for i := 0; i < N; i++ { + err := client2.Set("key", bigVal, 0).Err() + Expect(err).NotTo(HaveOccurred()) + } + }) + }() + + wg.Wait() + }) + + It("should PubSub", func() { + connPool := client.Pool() + connPool.(*pool.ConnPool).DialLimiter = nil + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + perform(C, func() { + for i := 0; i < N; i++ { + pubsub, err := client.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + + go func() { + defer GinkgoRecover() + + time.Sleep(time.Millisecond) + err := pubsub.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + _, err = pubsub.ReceiveMessage() + Expect(err.Error()).To(ContainSubstring("closed")) + } + }) + }() + + go func() { + defer wg.Done() + perform(C, func() { + for i := 0; i < N; i++ { + val := "echo" + strconv.Itoa(i) + echo, err := client.Echo(val).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(echo).To(Equal(val)) + } + }) + }() + + wg.Wait() + + Expect(connPool.Len()).To(Equal(connPool.FreeLen())) + Expect(connPool.Len()).To(BeNumerically("<=", 10)) + }) }) }) diff --git a/error.go b/error.go index e2430b4..3f2a560 100644 --- a/error.go +++ b/error.go @@ -1,15 +1,12 @@ package redis import ( - "errors" "fmt" "io" "net" "strings" ) -var errClosed = errors.New("redis: client is closed") - // Redis nil reply, .e.g. when key does not exist. var Nil = errorf("redis: nil") diff --git a/example_test.go b/example_test.go index c93e0b7..3034d86 100644 --- a/example_test.go +++ b/example_test.go @@ -220,13 +220,13 @@ func ExampleClient_Watch() { } func ExamplePubSub() { - pubsub, err := client.Subscribe("mychannel") + pubsub, err := client.Subscribe("mychannel1") if err != nil { panic(err) } defer pubsub.Close() - err = client.Publish("mychannel", "hello").Err() + err = client.Publish("mychannel1", "hello").Err() if err != nil { panic(err) } @@ -237,17 +237,17 @@ func ExamplePubSub() { } fmt.Println(msg.Channel, msg.Payload) - // Output: mychannel hello + // Output: mychannel1 hello } func ExamplePubSub_Receive() { - pubsub, err := client.Subscribe("mychannel") + pubsub, err := client.Subscribe("mychannel2") if err != nil { panic(err) } defer pubsub.Close() - err = client.Publish("mychannel", "hello").Err() + err = client.Publish("mychannel2", "hello").Err() if err != nil { panic(err) } @@ -269,8 +269,8 @@ func ExamplePubSub_Receive() { } } - // Output: subscribe mychannel - // mychannel hello + // Output: subscribe mychannel2 + // mychannel2 hello } func ExampleScript() { diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 1f1d4a9..cbe379b 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -3,6 +3,7 @@ package pool import ( "bufio" "net" + "sync/atomic" "time" ) @@ -11,9 +12,9 @@ const defaultBufSize = 4096 var noDeadline = time.Time{} type Conn struct { - idx int + idx int32 - netConn net.Conn + NetConn net.Conn Rd *bufio.Reader Buf []byte @@ -26,7 +27,7 @@ func NewConn(netConn net.Conn) *Conn { cn := &Conn{ idx: -1, - netConn: netConn, + NetConn: netConn, Buf: make([]byte, defaultBufSize), UsedAt: time.Now(), @@ -35,39 +36,47 @@ func NewConn(netConn net.Conn) *Conn { return cn } -func (cn *Conn) IsStale(timeout time.Duration) bool { - return timeout > 0 && time.Since(cn.UsedAt) > timeout +func (cn *Conn) Index() int { + return int(atomic.LoadInt32(&cn.idx)) } -func (cn *Conn) SetNetConn(netConn net.Conn) { - cn.netConn = netConn - cn.UsedAt = time.Now() +func (cn *Conn) SetIndex(idx int) { + atomic.StoreInt32(&cn.idx, int32(idx)) +} + +func (cn *Conn) IsStale(timeout time.Duration) bool { + return timeout > 0 && time.Since(cn.UsedAt) > timeout } func (cn *Conn) Read(b []byte) (int, error) { cn.UsedAt = time.Now() if cn.ReadTimeout != 0 { - cn.netConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) + cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) } else { - cn.netConn.SetReadDeadline(noDeadline) + cn.NetConn.SetReadDeadline(noDeadline) } - return cn.netConn.Read(b) + return cn.NetConn.Read(b) } func (cn *Conn) Write(b []byte) (int, error) { cn.UsedAt = time.Now() if cn.WriteTimeout != 0 { - cn.netConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) + cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) } else { - cn.netConn.SetWriteDeadline(noDeadline) + cn.NetConn.SetWriteDeadline(noDeadline) } - return cn.netConn.Write(b) + return cn.NetConn.Write(b) } func (cn *Conn) RemoteAddr() net.Addr { - return cn.netConn.RemoteAddr() + return cn.NetConn.RemoteAddr() } -func (cn *Conn) Close() error { - return cn.netConn.Close() +func (cn *Conn) Close() int { + idx := cn.Index() + if !atomic.CompareAndSwapInt32(&cn.idx, int32(idx), -1) { + return -1 + } + _ = cn.NetConn.Close() + return idx } diff --git a/internal/pool/conn_list.go b/internal/pool/conn_list.go index e72dc91..7e43ee7 100644 --- a/internal/pool/conn_list.go +++ b/internal/pool/conn_list.go @@ -24,7 +24,7 @@ func (l *connList) Len() int { } // Reserve reserves place in the list and returns true on success. -// The caller must add or remove connection if place was reserved. +// The caller must add connection or cancel reservation if it was reserved. func (l *connList) Reserve() bool { len := atomic.AddInt32(&l.len, 1) reserved := len <= l.size @@ -34,12 +34,16 @@ func (l *connList) Reserve() bool { return reserved } +func (l *connList) CancelReservation() { + atomic.AddInt32(&l.len, -1) +} + // Add adds connection to the list. The caller must reserve place first. func (l *connList) Add(cn *Conn) { l.mu.Lock() for i, c := range l.cns { if c == nil { - cn.idx = i + cn.SetIndex(i) l.cns[i] = cn l.mu.Unlock() return @@ -48,37 +52,34 @@ func (l *connList) Add(cn *Conn) { panic("not reached") } -// Remove closes connection and removes it from the list. -func (l *connList) Remove(cn *Conn) error { - atomic.AddInt32(&l.len, -1) - - if cn == nil { // free reserved place - return nil - } - +func (l *connList) Replace(cn *Conn) { l.mu.Lock() if l.cns != nil { - l.cns[cn.idx] = nil - cn.idx = -1 + l.cns[cn.idx] = cn } l.mu.Unlock() +} - return nil +// Remove closes connection and removes it from the list. +func (l *connList) Remove(idx int) { + l.mu.Lock() + if l.cns != nil { + l.cns[idx] = nil + l.len -= 1 + } + l.mu.Unlock() } func (l *connList) Close() error { - var retErr error l.mu.Lock() for _, c := range l.cns { if c == nil { continue } - if err := c.Close(); err != nil && retErr == nil { - retErr = err - } + c.Close() } l.cns = nil - atomic.StoreInt32(&l.len, 0) + l.len = 0 l.mu.Unlock() - return retErr + return nil } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 243ebea..4f2b217 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -14,7 +14,8 @@ import ( var Logger *log.Logger var ( - errClosed = errors.New("redis: client is closed") + ErrClosed = errors.New("redis: client is closed") + errConnClosed = errors.New("redis: connection is closed") ErrPoolTimeout = errors.New("redis: connection pool timeout") ) @@ -36,8 +37,9 @@ type Pooler interface { Replace(*Conn, error) error Len() int FreeLen() int - Close() error Stats() *PoolStats + Close() error + Closed() bool } type dialer func() (net.Conn, error) @@ -58,6 +60,8 @@ type ConnPool struct { lastErr atomic.Value } +var _ Pooler = (*ConnPool)(nil) + func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool { p := &ConnPool{ _dial: dial, @@ -75,7 +79,7 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Durati return p } -func (p *ConnPool) closed() bool { +func (p *ConnPool) Closed() bool { return atomic.LoadInt32(&p._closed) == 1 } @@ -152,8 +156,8 @@ func (p *ConnPool) newConn() (*Conn, error) { // Get returns existed connection from the pool or creates a new one. func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { - if p.closed() { - err = errClosed + if p.Closed() { + err = ErrClosed return } @@ -171,7 +175,7 @@ func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { cn, err = p.newConn() if err != nil { - p.conns.Remove(nil) + p.conns.CancelReservation() return } p.conns.Add(cn) @@ -201,14 +205,20 @@ func (p *ConnPool) Put(cn *Conn) error { } func (p *ConnPool) replace(cn *Conn) (*Conn, error) { - _ = cn.Close() + idx := cn.Close() + if idx == -1 { + return nil, errConnClosed + } netConn, err := p.dial() if err != nil { - _ = p.conns.Remove(cn) + p.conns.Remove(idx) return nil, err } - cn.SetNetConn(netConn) + + cn = NewConn(netConn) + cn.SetIndex(idx) + p.conns.Replace(cn) return cn, nil } @@ -226,9 +236,14 @@ func (p *ConnPool) Replace(cn *Conn, reason error) error { } func (p *ConnPool) Remove(cn *Conn, reason error) error { + idx := cn.Close() + if idx == -1 { + return errConnClosed + } + p.storeLastErr(reason.Error()) - _ = cn.Close() - return p.conns.Remove(cn) + p.conns.Remove(idx) + return nil } // Len returns total number of connections. @@ -253,7 +268,7 @@ func (p *ConnPool) Stats() *PoolStats { func (p *ConnPool) Close() (retErr error) { if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { - return errClosed + return ErrClosed } // Wait for app to free connections, but don't close them immediately. for i := 0; i < p.Len(); i++ { @@ -287,7 +302,7 @@ func (p *ConnPool) reaper() { defer ticker.Stop() for _ = range ticker.C { - if p.closed() { + if p.Closed() { break } n, err := p.ReapStaleConns() diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index e0ea868..f9ebfa6 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -4,6 +4,8 @@ type SingleConnPool struct { cn *Conn } +var _ Pooler = (*SingleConnPool)(nil) + func NewSingleConnPool(cn *Conn) *SingleConnPool { return &SingleConnPool{ cn: cn, @@ -40,8 +42,14 @@ func (p *SingleConnPool) FreeLen() int { return 0 } -func (p *SingleConnPool) Stats() *PoolStats { return nil } +func (p *SingleConnPool) Stats() *PoolStats { + return nil +} func (p *SingleConnPool) Close() error { return nil } + +func (p *SingleConnPool) Closed() bool { + return false +} diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 8f4c324..11a7ee4 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -14,6 +14,8 @@ type StickyConnPool struct { mx sync.Mutex } +var _ Pooler = (*StickyConnPool)(nil) + func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool { return &StickyConnPool{ pool: pool, @@ -33,7 +35,7 @@ func (p *StickyConnPool) Get() (cn *Conn, isNew bool, err error) { p.mx.Lock() if p.closed { - err = errClosed + err = ErrClosed return } if p.cn != nil { @@ -59,7 +61,7 @@ func (p *StickyConnPool) Put(cn *Conn) error { defer p.mx.Unlock() p.mx.Lock() if p.closed { - return errClosed + return ErrClosed } if p.cn != cn { panic("p.cn != cn") @@ -77,7 +79,7 @@ func (p *StickyConnPool) Replace(cn *Conn, reason error) error { defer p.mx.Unlock() p.mx.Lock() if p.closed { - return errClosed + return nil } if p.cn == nil { panic("p.cn == nil") @@ -112,7 +114,7 @@ func (p *StickyConnPool) Close() error { defer p.mx.Unlock() p.mx.Lock() if p.closed { - return errClosed + return ErrClosed } p.closed = true var err error @@ -126,3 +128,10 @@ func (p *StickyConnPool) Close() error { } return err } + +func (p *StickyConnPool) Closed() bool { + p.mx.Lock() + closed := p.closed + p.mx.Unlock() + return closed +} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 07d3a52..5dd7784 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -31,12 +31,14 @@ var _ = Describe("conns reapser", func() { cn := pool.NewConn(&net.TCPConn{}) cn.UsedAt = time.Now().Add(-2 * time.Minute) Expect(connPool.Add(cn)).To(BeTrue()) + Expect(cn.Index()).To(Equal(i)) } // add fresh connections for i := 0; i < 3; i++ { cn := pool.NewConn(&net.TCPConn{}) Expect(connPool.Add(cn)).To(BeTrue()) + Expect(cn.Index()).To(Equal(3 + i)) } Expect(connPool.Len()).To(Equal(6)) diff --git a/main_test.go b/main_test.go index e1d65c6..e3e747f 100644 --- a/main_test.go +++ b/main_test.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "path/filepath" + "sync" "sync/atomic" "testing" "time" @@ -98,6 +99,20 @@ func TestGinkgoSuite(t *testing.T) { //------------------------------------------------------------------------------ +func perform(n int, cb func()) { + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + + cb() + }() + } + wg.Wait() +} + func eventually(fn func() error, timeout time.Duration) error { done := make(chan struct{}) var exit int32 @@ -138,7 +153,7 @@ func connectTo(port string) (*redis.Client, error) { err := eventually(func() error { return client.Ping().Err() - }, 10*time.Second) + }, 30*time.Second) if err != nil { return nil, err } diff --git a/multi.go b/multi.go index 6b43591..a049821 100644 --- a/multi.go +++ b/multi.go @@ -109,7 +109,7 @@ func (c *Multi) Discard() error { // failed command or nil. func (c *Multi) Exec(f func() error) ([]Cmder, error) { if c.closed { - return nil, errClosed + return nil, pool.ErrClosed } c.cmds = []Cmder{NewStatusCmd("MULTI")} diff --git a/multi_test.go b/multi_test.go index 459d0a6..fa532d1 100644 --- a/multi_test.go +++ b/multi_test.go @@ -145,7 +145,7 @@ var _ = Describe("Multi", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.SetNetConn(&badConn{}) + cn.NetConn = &badConn{} err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) @@ -172,7 +172,7 @@ var _ = Describe("Multi", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.SetNetConn(&badConn{}) + cn.NetConn = &badConn{} err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) diff --git a/pipeline.go b/pipeline.go index 098207c..842fad7 100644 --- a/pipeline.go +++ b/pipeline.go @@ -62,7 +62,7 @@ func (pipe *Pipeline) Discard() error { defer pipe.mu.Unlock() pipe.mu.Lock() if pipe.isClosed() { - return errClosed + return pool.ErrClosed } pipe.cmds = pipe.cmds[:0] return nil @@ -75,7 +75,7 @@ func (pipe *Pipeline) Discard() error { // command if any. func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { if pipe.isClosed() { - return nil, errClosed + return nil, pool.ErrClosed } defer pipe.mu.Unlock() diff --git a/pool_test.go b/pool_test.go index a5b0721..225ad6a 100644 --- a/pool_test.go +++ b/pool_test.go @@ -2,7 +2,6 @@ package redis_test import ( "errors" - "sync" "time" . "github.com/onsi/ginkgo" @@ -14,20 +13,6 @@ import ( var _ = Describe("pool", func() { var client *redis.Client - var perform = func(n int, cb func()) { - wg := &sync.WaitGroup{} - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - defer GinkgoRecover() - defer wg.Done() - - cb() - }() - } - wg.Wait() - } - BeforeEach(func() { client = redis.NewClient(&redis.Options{ Addr: redisAddr, @@ -108,12 +93,11 @@ var _ = Describe("pool", func() { It("should remove broken connections", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - Expect(cn.Close()).NotTo(HaveOccurred()) + cn.NetConn = &badConn{} Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) err = client.Ping().Err() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("use of closed network connection")) + Expect(err).To(MatchError("bad connection")) val, err := client.Ping().Result() Expect(err).NotTo(HaveOccurred()) diff --git a/pubsub.go b/pubsub.go index f1c93c8..68b2aeb 100644 --- a/pubsub.go +++ b/pubsub.go @@ -54,6 +54,7 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { if err != nil { return err } + c.putConn(cn, err) args := make([]interface{}, 1+len(channels)) args[0] = redisCmd @@ -306,6 +307,9 @@ func (c *PubSub) putConn(cn *pool.Conn, err error) { } func (c *PubSub) resubscribe() { + if c.base.closed() { + return + } if len(c.channels) > 0 { if err := c.Subscribe(c.channels...); err != nil { Logger.Printf("Subscribe failed: %s", err) diff --git a/pubsub_test.go b/pubsub_test.go index 669c073..a8bb610 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -291,10 +291,10 @@ var _ = Describe("PubSub", func() { expectReceiveMessageOnError := func(pubsub *redis.PubSub) { cn1, _, err := pubsub.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn1.SetNetConn(&badConn{ + cn1.NetConn = &badConn{ readErr: io.EOF, writeErr: io.EOF, - }) + } done := make(chan bool, 1) go func() { diff --git a/redis.go b/redis.go index 55f4757..aab5ba6 100644 --- a/redis.go +++ b/redis.go @@ -45,17 +45,11 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) { func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { if isBadConn(err, allowTimeout) { - err = c.connPool.Replace(cn, err) - if err != nil { - Logger.Printf("pool.Remove failed: %s", err) - } + _ = c.connPool.Replace(cn, err) return false } - err = c.connPool.Put(cn) - if err != nil { - Logger.Printf("pool.Put failed: %s", err) - } + _ = c.connPool.Put(cn) return true } @@ -121,6 +115,10 @@ func (c *baseClient) process(cmd Cmder) { } } +func (c *baseClient) closed() bool { + return c.connPool.Closed() +} + // Close closes the client, releasing any open resources. // // It is rare to Close a Client, as the Client is meant to be diff --git a/redis_test.go b/redis_test.go index 1435b7a..23c3900 100644 --- a/redis_test.go +++ b/redis_test.go @@ -160,7 +160,7 @@ var _ = Describe("Client", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.SetNetConn(&badConn{}) + cn.NetConn = &badConn{} err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) diff --git a/ring.go b/ring.go index 3b88d7c..c66a5bc 100644 --- a/ring.go +++ b/ring.go @@ -149,7 +149,7 @@ func (ring *Ring) getClient(key string) (*Client, error) { ring.mx.RLock() if ring.closed { - return nil, errClosed + return nil, pool.ErrClosed } name := ring.hash.Get(hashtag.Key(key)) @@ -277,7 +277,7 @@ func (pipe *RingPipeline) process(cmd Cmder) { // Discard resets the pipeline and discards queued commands. func (pipe *RingPipeline) Discard() error { if pipe.closed { - return errClosed + return pool.ErrClosed } pipe.cmds = pipe.cmds[:0] return nil @@ -287,7 +287,7 @@ func (pipe *RingPipeline) Discard() error { // command if any. func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { if pipe.closed { - return nil, errClosed + return nil, pool.ErrClosed } if len(pipe.cmds) == 0 { return pipe.cmds, nil