diff --git a/cluster_pipeline.go b/cluster_pipeline.go index 8641d3df..6d552659 100644 --- a/cluster_pipeline.go +++ b/cluster_pipeline.go @@ -69,7 +69,7 @@ func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { continue } - cn, _, err := client.conn() + cn, err := client.conn() if err != nil { setCmdsErr(cmds, err) retErr = err diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index ba6df0c0..5acc5e2c 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -20,7 +20,7 @@ func benchmarkPoolGetPut(b *testing.B, poolSize int) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - conn, _, err := pool.Get() + conn, err := pool.Get() if err != nil { b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) } @@ -56,7 +56,7 @@ func benchmarkPoolGetReplace(b *testing.B, poolSize int) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - conn, _, err := pool.Get() + conn, err := pool.Get() if err != nil { b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) } diff --git a/internal/pool/conn.go b/internal/pool/conn.go index cbe379b1..c3768862 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -18,7 +18,9 @@ type Conn struct { Rd *bufio.Reader Buf []byte - UsedAt time.Time + Inited bool + UsedAt time.Time + ReadTimeout time.Duration WriteTimeout time.Duration } @@ -40,8 +42,12 @@ func (cn *Conn) Index() int { return int(atomic.LoadInt32(&cn.idx)) } -func (cn *Conn) SetIndex(idx int) { - atomic.StoreInt32(&cn.idx, int32(idx)) +func (cn *Conn) SetIndex(newIdx int) int { + oldIdx := cn.Index() + if !atomic.CompareAndSwapInt32(&cn.idx, int32(oldIdx), int32(newIdx)) { + return -1 + } + return oldIdx } func (cn *Conn) IsStale(timeout time.Duration) bool { @@ -72,11 +78,6 @@ func (cn *Conn) RemoteAddr() net.Addr { return cn.NetConn.RemoteAddr() } -func (cn *Conn) Close() int { - idx := cn.Index() - if !atomic.CompareAndSwapInt32(&cn.idx, int32(idx), -1) { - return -1 - } - _ = cn.NetConn.Close() - return idx +func (cn *Conn) Close() error { + return cn.NetConn.Close() } diff --git a/internal/pool/conn_list.go b/internal/pool/conn_list.go index 7e43ee74..b3f58704 100644 --- a/internal/pool/conn_list.go +++ b/internal/pool/conn_list.go @@ -43,7 +43,7 @@ func (l *connList) Add(cn *Conn) { l.mu.Lock() for i, c := range l.cns { if c == nil { - cn.SetIndex(i) + cn.idx = int32(i) l.cns[i] = cn l.mu.Unlock() return @@ -76,6 +76,7 @@ func (l *connList) Close() error { if c == nil { continue } + c.idx = -1 c.Close() } l.cns = nil diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 4f2b2175..4de11fc6 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -32,7 +32,7 @@ type PoolStats struct { type Pooler interface { First() *Conn - Get() (*Conn, bool, error) + Get() (*Conn, error) Put(*Conn) error Replace(*Conn, error) error Len() int @@ -146,7 +146,7 @@ func (p *ConnPool) dial() (net.Conn, error) { return cn, nil } -func (p *ConnPool) newConn() (*Conn, error) { +func (p *ConnPool) NewConn() (*Conn, error) { netConn, err := p.dial() if err != nil { return nil, err @@ -155,42 +155,38 @@ func (p *ConnPool) newConn() (*Conn, error) { } // Get returns existed connection from the pool or creates a new one. -func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { +func (p *ConnPool) Get() (*Conn, error) { if p.Closed() { - err = ErrClosed - return + return nil, ErrClosed } atomic.AddUint32(&p.stats.Requests, 1) // Fetch first non-idle connection, if available. - if cn = p.First(); cn != nil { + if cn := p.First(); cn != nil { atomic.AddUint32(&p.stats.Hits, 1) - return + return cn, nil } // Try to create a new one. if p.conns.Reserve() { - isNew = true - - cn, err = p.newConn() + cn, err := p.NewConn() if err != nil { p.conns.CancelReservation() - return + return nil, err } p.conns.Add(cn) - return + return cn, nil } // Otherwise, wait for the available connection. atomic.AddUint32(&p.stats.Waits, 1) - if cn = p.wait(); cn != nil { - return + if cn := p.wait(); cn != nil { + return cn, nil } atomic.AddUint32(&p.stats.Timeouts, 1) - err = ErrPoolTimeout - return + return nil, ErrPoolTimeout } func (p *ConnPool) Put(cn *Conn) error { @@ -205,7 +201,9 @@ func (p *ConnPool) Put(cn *Conn) error { } func (p *ConnPool) replace(cn *Conn) (*Conn, error) { - idx := cn.Close() + _ = cn.Close() + + idx := cn.SetIndex(-1) if idx == -1 { return nil, errConnClosed } @@ -236,7 +234,9 @@ func (p *ConnPool) Replace(cn *Conn, reason error) error { } func (p *ConnPool) Remove(cn *Conn, reason error) error { - idx := cn.Close() + _ = cn.Close() + + idx := cn.SetIndex(-1) if idx == -1 { return errConnClosed } diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index f9ebfa6d..39362c02 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -16,8 +16,8 @@ func (p *SingleConnPool) First() *Conn { return p.cn } -func (p *SingleConnPool) Get() (*Conn, bool, error) { - return p.cn, false, nil +func (p *SingleConnPool) Get() (*Conn, error) { + return p.cn, nil } func (p *SingleConnPool) Put(cn *Conn) error { diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 11a7ee49..8b76b6f6 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -30,25 +30,23 @@ func (p *StickyConnPool) First() *Conn { return cn } -func (p *StickyConnPool) Get() (cn *Conn, isNew bool, err error) { +func (p *StickyConnPool) Get() (*Conn, error) { defer p.mx.Unlock() p.mx.Lock() if p.closed { - err = ErrClosed - return + return nil, ErrClosed } if p.cn != nil { - cn = p.cn - return + return p.cn, nil } - cn, isNew, err = p.pool.Get() + cn, err := p.pool.Get() if err != nil { - return + return nil, err } p.cn = cn - return + return cn, nil } func (p *StickyConnPool) put() (err error) { diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 5dd7784e..1c591924 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -69,9 +69,8 @@ var _ = Describe("conns reapser", func() { cn := connPool.First() Expect(cn).To(BeNil()) - cn, isNew, err := connPool.Get() + cn, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) - Expect(isNew).To(BeTrue()) Expect(cn).NotTo(BeNil()) Expect(connPool.Len()).To(Equal(4)) diff --git a/multi.go b/multi.go index a0498211..79b7cb6d 100644 --- a/multi.go +++ b/multi.go @@ -128,7 +128,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) { // Strip MULTI and EXEC commands. retCmds := cmds[1 : len(cmds)-1] - cn, _, err := c.base.conn() + cn, err := c.base.conn() if err != nil { setCmdsErr(retCmds, err) return retCmds, err diff --git a/multi_test.go b/multi_test.go index fa532d1a..a82a347a 100644 --- a/multi_test.go +++ b/multi_test.go @@ -142,7 +142,7 @@ var _ = Describe("Multi", func() { It("should recover from bad connection", func() { // Put bad connection in the pool. - cn, _, err := client.Pool().Get() + cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.NetConn = &badConn{} @@ -169,7 +169,7 @@ var _ = Describe("Multi", func() { It("should recover from bad connection when there are no commands", func() { // Put bad connection in the pool. - cn, _, err := client.Pool().Get() + cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.NetConn = &badConn{} diff --git a/pipeline.go b/pipeline.go index 842fad7b..888d8c40 100644 --- a/pipeline.go +++ b/pipeline.go @@ -90,7 +90,7 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { failedCmds := cmds for i := 0; i <= pipe.client.opt.MaxRetries; i++ { - cn, _, err := pipe.client.conn() + cn, err := pipe.client.conn() if err != nil { setCmdsErr(failedCmds, err) return cmds, err diff --git a/pool_test.go b/pool_test.go index 225ad6ad..006ab0be 100644 --- a/pool_test.go +++ b/pool_test.go @@ -91,7 +91,7 @@ var _ = Describe("pool", func() { }) It("should remove broken connections", func() { - cn, _, err := client.Pool().Get() + cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.NetConn = &badConn{} Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) @@ -136,12 +136,12 @@ var _ = Describe("pool", func() { pool := client.Pool() // Reserve one connection. - cn, _, err := pool.Get() + cn, err := pool.Get() Expect(err).NotTo(HaveOccurred()) // Reserve the rest of connections. for i := 0; i < 9; i++ { - _, _, err := pool.Get() + _, err := pool.Get() Expect(err).NotTo(HaveOccurred()) } @@ -181,7 +181,7 @@ var _ = Describe("pool", func() { var rateErr error for i := 0; i < 1000; i++ { - cn, _, err := pool.Get() + cn, err := pool.Get() if err != nil { rateErr = err break diff --git a/pubsub.go b/pubsub.go index 68b2aeb5..05e59213 100644 --- a/pubsub.go +++ b/pubsub.go @@ -50,7 +50,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) { } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { - cn, _, err := c.base.conn() + cn, err := c.base.conn() if err != nil { return err } @@ -126,7 +126,7 @@ func (c *PubSub) Close() error { } func (c *PubSub) Ping(payload string) error { - cn, _, err := c.base.conn() + cn, err := c.base.conn() if err != nil { return err } @@ -226,7 +226,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { c.resubscribe() } - cn, _, err := c.base.conn() + cn, err := c.base.conn() if err != nil { return nil, err } diff --git a/pubsub_test.go b/pubsub_test.go index a8bb610b..835d7c1a 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -289,7 +289,7 @@ var _ = Describe("PubSub", func() { }) expectReceiveMessageOnError := func(pubsub *redis.PubSub) { - cn1, _, err := pubsub.Pool().Get() + cn1, err := pubsub.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn1.NetConn = &badConn{ readErr: io.EOF, diff --git a/redis.go b/redis.go index aab5ba63..dc55572e 100644 --- a/redis.go +++ b/redis.go @@ -32,15 +32,18 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB) } -func (c *baseClient) conn() (*pool.Conn, bool, error) { - cn, isNew, err := c.connPool.Get() - if err == nil && isNew { - err = c.initConn(cn) - if err != nil { - c.putConn(cn, err, false) +func (c *baseClient) conn() (*pool.Conn, error) { + cn, err := c.connPool.Get() + if err != nil { + return nil, err + } + if !cn.Inited { + if err := c.initConn(cn); err != nil { + _ = c.connPool.Replace(cn, err) + return nil, err } } - return cn, isNew, err + return cn, err } func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { @@ -54,6 +57,8 @@ func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { } func (c *baseClient) initConn(cn *pool.Conn) error { + cn.Inited = true + if c.opt.Password == "" && c.opt.DB == 0 { return nil } @@ -82,7 +87,7 @@ func (c *baseClient) process(cmd Cmder) { cmd.reset() } - cn, _, err := c.conn() + cn, err := c.conn() if err != nil { cmd.setErr(err) return diff --git a/redis_test.go b/redis_test.go index 23c39009..8b3d8dbd 100644 --- a/redis_test.go +++ b/redis_test.go @@ -157,7 +157,7 @@ var _ = Describe("Client", func() { }) // Put bad connection in the pool. - cn, _, err := client.Pool().Get() + cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn.NetConn = &badConn{} @@ -169,7 +169,7 @@ var _ = Describe("Client", func() { }) It("should maintain conn.UsedAt", func() { - cn, _, err := client.Pool().Get() + cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) Expect(cn.UsedAt).NotTo(BeZero()) createdAt := cn.UsedAt diff --git a/ring.go b/ring.go index c66a5bc1..32212161 100644 --- a/ring.go +++ b/ring.go @@ -314,7 +314,7 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { for name, cmds := range cmdsMap { client := pipe.ring.shards[name].Client - cn, _, err := client.conn() + cn, err := client.conn() if err != nil { setCmdsErr(cmds, err) if retErr == nil {