diff --git a/cluster.go b/cluster.go index 15b17aa..402447e 100644 --- a/cluster.go +++ b/cluster.go @@ -516,7 +516,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { } } - cn, err := node.Client.conn() + cn, _, err := node.Client.conn() if err != nil { setCmdsErr(cmds, err) setRetErr(err) diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index 878b202..663abc0 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -16,7 +16,7 @@ func benchmarkPoolGetPut(b *testing.B, poolSize int) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - cn, err := connPool.Get() + cn, _, err := connPool.Get() if err != nil { b.Fatal(err) } @@ -48,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - cn, err := connPool.Get() + cn, _, err := connPool.Get() if err != nil { b.Fatal(err) } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 5c3fa06..55c1f9f 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -36,7 +36,7 @@ type PoolStats struct { } type Pooler interface { - Get() (*Conn, error) + Get() (*Conn, bool, error) Put(*Conn) error Remove(*Conn, error) error Len() int @@ -152,9 +152,9 @@ func (p *ConnPool) popFree() *Conn { } // Get returns existed connection from the pool or creates a new one. -func (p *ConnPool) Get() (*Conn, error) { +func (p *ConnPool) Get() (*Conn, bool, error) { if p.Closed() { - return nil, ErrClosed + return nil, false, ErrClosed } atomic.AddUint32(&p.stats.Requests, 1) @@ -170,7 +170,7 @@ func (p *ConnPool) Get() (*Conn, error) { case <-timer.C: timers.Put(timer) atomic.AddUint32(&p.stats.Timeouts, 1) - return nil, ErrPoolTimeout + return nil, false, ErrPoolTimeout } p.freeConnsMu.Lock() @@ -180,7 +180,7 @@ func (p *ConnPool) Get() (*Conn, error) { if cn != nil { atomic.AddUint32(&p.stats.Hits, 1) if !cn.IsStale(p.idleTimeout) { - return cn, nil + return cn, false, nil } _ = p.closeConn(cn, errConnStale) } @@ -188,7 +188,7 @@ func (p *ConnPool) Get() (*Conn, error) { newcn, err := p.NewConn() if err != nil { <-p.queue - return nil, err + return nil, false, err } p.connsMu.Lock() @@ -198,7 +198,7 @@ func (p *ConnPool) Get() (*Conn, error) { p.conns = append(p.conns, newcn) p.connsMu.Unlock() - return newcn, nil + return newcn, true, nil } func (p *ConnPool) Put(cn *Conn) error { diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index cb1863e..0cf6c7c 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -16,8 +16,8 @@ func (p *SingleConnPool) First() *Conn { return p.cn } -func (p *SingleConnPool) Get() (*Conn, error) { - return p.cn, nil +func (p *SingleConnPool) Get() (*Conn, bool, error) { + return p.cn, false, nil } func (p *SingleConnPool) Put(cn *Conn) error { diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 24a4f75..a2649e5 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -30,23 +30,23 @@ func (p *StickyConnPool) First() *Conn { return cn } -func (p *StickyConnPool) Get() (*Conn, error) { +func (p *StickyConnPool) Get() (*Conn, bool, error) { defer p.mx.Unlock() p.mx.Lock() if p.closed { - return nil, ErrClosed + return nil, false, ErrClosed } if p.cn != nil { - return p.cn, nil + return p.cn, false, nil } - cn, err := p.pool.Get() + cn, _, err := p.pool.Get() if err != nil { - return nil, err + return nil, false, err } p.cn = cn - return cn, nil + return cn, true, nil } func (p *StickyConnPool) put() (err error) { diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 5fe7e8d..425ce92 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -26,7 +26,7 @@ var _ = Describe("ConnPool", func() { It("rate limits dial", func() { var rateErr error for i := 0; i < 1000; i++ { - cn, err := connPool.Get() + cn, _, err := connPool.Get() if err != nil { rateErr = err break @@ -40,13 +40,13 @@ var _ = Describe("ConnPool", func() { It("should unblock client when conn is removed", func() { // Reserve one connection. - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) // Reserve all other connections. var cns []*pool.Conn for i := 0; i < 9; i++ { - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) cns = append(cns, cn) } @@ -57,7 +57,7 @@ var _ = Describe("ConnPool", func() { defer GinkgoRecover() started <- true - _, err := connPool.Get() + _, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) done <- true @@ -113,7 +113,7 @@ var _ = Describe("conns reaper", func() { // add stale connections idleConns = nil for i := 0; i < 3; i++ { - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) cn.UsedAt = time.Now().Add(-2 * idleTimeout) conns = append(conns, cn) @@ -122,7 +122,7 @@ var _ = Describe("conns reaper", func() { // add fresh connections for i := 0; i < 3; i++ { - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) conns = append(conns, cn) } @@ -167,7 +167,7 @@ var _ = Describe("conns reaper", func() { for j := 0; j < 3; j++ { var freeCns []*pool.Conn for i := 0; i < 3; i++ { - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) freeCns = append(freeCns, cn) @@ -176,7 +176,7 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(3)) Expect(connPool.FreeLen()).To(Equal(0)) - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) conns = append(conns, cn) @@ -224,7 +224,7 @@ var _ = Describe("race", func() { perform(C, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) if err == nil { Expect(connPool.Put(cn)).NotTo(HaveOccurred()) @@ -232,7 +232,7 @@ var _ = Describe("race", func() { } }, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) if err == nil { Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred()) @@ -248,7 +248,7 @@ var _ = Describe("race", func() { perform(C, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get() + cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) if err == nil { Expect(connPool.Put(cn)).NotTo(HaveOccurred()) diff --git a/pool_test.go b/pool_test.go index c1d2f68..13edd70 100644 --- a/pool_test.go +++ b/pool_test.go @@ -91,7 +91,7 @@ var _ = Describe("pool", func() { }) It("should remove broken connections", func() { - cn, err := client.Pool().Get() + cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.NetConn = &badConn{} Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) diff --git a/pubsub.go b/pubsub.go index de8d35c..796bd4a 100644 --- a/pubsub.go +++ b/pubsub.go @@ -18,12 +18,25 @@ type PubSub struct { channels []string patterns []string +} - nsub int // number of active subscriptions +func (c *PubSub) conn() (*pool.Conn, bool, error) { + cn, isNew, err := c.base.conn() + if err != nil { + return nil, false, err + } + if isNew { + c.resubscribe() + } + return cn, isNew, nil +} + +func (c *PubSub) putConn(cn *pool.Conn, err error) { + c.base.putConn(cn, err, true) } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { - cn, err := c.base.conn() + cn, _, err := c.conn() if err != nil { return err } @@ -44,7 +57,6 @@ func (c *PubSub) Subscribe(channels ...string) error { err := c.subscribe("SUBSCRIBE", channels...) if err == nil { c.channels = appendIfNotExists(c.channels, channels...) - c.nsub += len(channels) } return err } @@ -54,43 +66,10 @@ func (c *PubSub) PSubscribe(patterns ...string) error { err := c.subscribe("PSUBSCRIBE", patterns...) if err == nil { c.patterns = appendIfNotExists(c.patterns, patterns...) - c.nsub += len(patterns) } return err } -func remove(ss []string, es ...string) []string { - if len(es) == 0 { - return ss[:0] - } - for _, e := range es { - for i, s := range ss { - if s == e { - ss = append(ss[:i], ss[i+1:]...) - break - } - } - } - return ss -} - -func appendIfNotExists(ss []string, es ...string) []string { - for _, e := range es { - found := false - for _, s := range ss { - if s == e { - found = true - break - } - } - - if !found { - ss = append(ss, e) - } - } - return ss -} - // Unsubscribes the client from the given channels, or from all of // them if none is given. func (c *PubSub) Unsubscribe(channels ...string) error { @@ -116,7 +95,7 @@ func (c *PubSub) Close() error { } func (c *PubSub) Ping(payload string) error { - cn, err := c.base.conn() + cn, _, err := c.conn() if err != nil { return err } @@ -198,11 +177,7 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) { // is not received in time. This is low-level API and most clients // should use ReceiveMessage. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { - if c.nsub == 0 { - c.resubscribe() - } - - cn, err := c.base.conn() + cn, _, err := c.conn() if err != nil { return nil, err } @@ -274,12 +249,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { } } -func (c *PubSub) putConn(cn *pool.Conn, err error) { - if !c.base.putConn(cn, err, true) { - c.nsub = 0 - } -} - func (c *PubSub) resubscribe() { if c.base.closed() { return @@ -295,3 +264,31 @@ func (c *PubSub) resubscribe() { } } } + +func remove(ss []string, es ...string) []string { + if len(es) == 0 { + return ss[:0] + } + for _, e := range es { + for i, s := range ss { + if s == e { + ss = append(ss[:i], ss[i+1:]...) + break + } + } + } + return ss +} + +func appendIfNotExists(ss []string, es ...string) []string { +loop: + for _, e := range es { + for _, s := range ss { + if s == e { + continue loop + } + } + ss = append(ss, e) + } + return ss +} diff --git a/pubsub_test.go b/pubsub_test.go index 116099b..957d303 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -288,7 +288,7 @@ var _ = Describe("PubSub", func() { }) expectReceiveMessageOnError := func(pubsub *redis.PubSub) { - cn1, err := pubsub.Pool().Get() + cn1, _, err := pubsub.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn1.NetConn = &badConn{ readErr: io.EOF, diff --git a/redis.go b/redis.go index ea0a2b4..4b535ff 100644 --- a/redis.go +++ b/redis.go @@ -27,18 +27,18 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB) } -func (c *baseClient) conn() (*pool.Conn, error) { - cn, err := c.connPool.Get() +func (c *baseClient) conn() (*pool.Conn, bool, error) { + cn, isNew, err := c.connPool.Get() if err != nil { - return nil, err + return nil, false, err } if !cn.Inited { if err := c.initConn(cn); err != nil { _ = c.connPool.Remove(cn, err) - return nil, err + return nil, false, err } } - return cn, err + return cn, isNew, nil } func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { @@ -84,7 +84,7 @@ func (c *baseClient) Process(cmd Cmder) error { cmd.reset() } - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { cmd.setErr(err) return err @@ -197,7 +197,7 @@ func (c *Client) pipelineExec(cmds []Cmder) error { var retErr error failedCmds := cmds for i := 0; i <= c.opt.MaxRetries; i++ { - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { setCmdsErr(failedCmds, err) return err diff --git a/redis_test.go b/redis_test.go index 4b62f3e..0b59547 100644 --- a/redis_test.go +++ b/redis_test.go @@ -144,7 +144,7 @@ var _ = Describe("Client", func() { }) // Put bad connection in the pool. - cn, err := client.Pool().Get() + cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.NetConn = &badConn{} @@ -156,7 +156,7 @@ var _ = Describe("Client", func() { }) It("should update conn.UsedAt on read/write", func() { - cn, err := client.Pool().Get() + cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) Expect(cn.UsedAt).NotTo(BeZero()) createdAt := cn.UsedAt @@ -168,7 +168,7 @@ var _ = Describe("Client", func() { err = client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - cn, err = client.Pool().Get() + cn, _, err = client.Pool().Get() Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) diff --git a/ring.go b/ring.go index 5b2fdde..570c535 100644 --- a/ring.go +++ b/ring.go @@ -318,7 +318,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) error { for name, cmds := range cmdsMap { client := c.shards[name].Client - cn, err := client.conn() + cn, _, err := client.conn() if err != nil { setCmdsErr(cmds, err) if retErr == nil { diff --git a/tx.go b/tx.go index ca425eb..62e7701 100644 --- a/tx.go +++ b/tx.go @@ -139,7 +139,7 @@ func (c *Tx) MultiExec(fn func() error) ([]Cmder, error) { // Strip MULTI and EXEC commands. retCmds := cmds[1 : len(cmds)-1] - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { setCmdsErr(retCmds, err) return retCmds, err diff --git a/tx_test.go b/tx_test.go index 7ff84dd..ddb9ddf 100644 --- a/tx_test.go +++ b/tx_test.go @@ -126,7 +126,7 @@ var _ = Describe("Tx", func() { It("should recover from bad connection", func() { // Put bad connection in the pool. - cn, err := client.Pool().Get() + cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.NetConn = &badConn{} @@ -153,7 +153,7 @@ var _ = Describe("Tx", func() { It("should recover from bad connection when there are no commands", func() { // Put bad connection in the pool. - cn, err := client.Pool().Get() + cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.NetConn = &badConn{}