diff --git a/bench_test.go b/bench_test.go index 5d6fa37f..f4cffc1d 100644 --- a/bench_test.go +++ b/bench_test.go @@ -278,8 +278,8 @@ func BenchmarkZAdd(b *testing.B) { } func benchmarkPoolGetPut(b *testing.B, poolSize int) { - dial := func() (*pool.Conn, error) { - return pool.NewConn(&net.TCPConn{}), nil + dial := func() (net.Conn, error) { + return &net.TCPConn{}, nil } pool := pool.NewConnPool(dial, poolSize, time.Second, 0) @@ -311,8 +311,8 @@ func BenchmarkPoolGetPut1000Conns(b *testing.B) { } func benchmarkPoolGetRemove(b *testing.B, poolSize int) { - dial := func() (*pool.Conn, error) { - return pool.NewConn(&net.TCPConn{}), nil + dial := func() (net.Conn, error) { + return &net.TCPConn{}, nil } pool := pool.NewConnPool(dial, poolSize, time.Second, 0) removeReason := errors.New("benchmark") @@ -325,7 +325,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) { if err != nil { b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) } - if err = pool.Remove(conn, removeReason); err != nil { + if err = pool.Replace(conn, removeReason); err != nil { b.Fatalf("no error expected on pool.Remove but received: %s", err.Error()) } } diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 1e1e8611..c5a539bc 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -8,10 +8,12 @@ import ( const defaultBufSize = 4096 -var noTimeout = time.Time{} +var noDeadline = time.Time{} type Conn struct { - NetConn net.Conn + idx int + + netConn net.Conn Rd *bufio.Reader Buf []byte @@ -22,7 +24,9 @@ type Conn struct { func NewConn(netConn net.Conn) *Conn { cn := &Conn{ - NetConn: netConn, + idx: -1, + + netConn: netConn, Buf: make([]byte, defaultBufSize), UsedAt: time.Now(), @@ -31,30 +35,35 @@ func NewConn(netConn net.Conn) *Conn { return cn } +func (cn *Conn) SetNetConn(netConn net.Conn) { + cn.netConn = netConn + cn.UsedAt = time.Now() +} + func (cn *Conn) Read(b []byte) (int, error) { cn.UsedAt = time.Now() if cn.ReadTimeout != 0 { - cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) + cn.netConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) } else { - cn.NetConn.SetReadDeadline(noTimeout) + cn.netConn.SetReadDeadline(noDeadline) } - return cn.NetConn.Read(b) + return cn.netConn.Read(b) } func (cn *Conn) Write(b []byte) (int, error) { cn.UsedAt = time.Now() if cn.WriteTimeout != 0 { - cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) + cn.netConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) } else { - cn.NetConn.SetWriteDeadline(noTimeout) + cn.netConn.SetWriteDeadline(noDeadline) } - return cn.NetConn.Write(b) + return cn.netConn.Write(b) } func (cn *Conn) RemoteAddr() net.Addr { - return cn.NetConn.RemoteAddr() + return cn.netConn.RemoteAddr() } func (cn *Conn) Close() error { - return cn.NetConn.Close() + return cn.netConn.Close() } diff --git a/internal/pool/conn_list.go b/internal/pool/conn_list.go index f8b82ab2..e72dc91d 100644 --- a/internal/pool/conn_list.go +++ b/internal/pool/conn_list.go @@ -7,14 +7,14 @@ import ( type connList struct { cns []*Conn - mx sync.Mutex + mu sync.Mutex len int32 // atomic size int32 } func newConnList(size int) *connList { return &connList{ - cns: make([]*Conn, 0, size), + cns: make([]*Conn, size), size: int32(size), } } @@ -23,8 +23,8 @@ func (l *connList) Len() int { return int(atomic.LoadInt32(&l.len)) } -// Reserve reserves place in the list and returns true on success. The -// caller must add or remove connection if place was reserved. +// Reserve reserves place in the list and returns true on success. +// The caller must add or remove connection if place was reserved. func (l *connList) Reserve() bool { len := atomic.AddInt32(&l.len, 1) reserved := len <= l.size @@ -36,65 +36,49 @@ func (l *connList) Reserve() bool { // Add adds connection to the list. The caller must reserve place first. func (l *connList) Add(cn *Conn) { - l.mx.Lock() - l.cns = append(l.cns, cn) - l.mx.Unlock() + l.mu.Lock() + for i, c := range l.cns { + if c == nil { + cn.idx = i + l.cns[i] = cn + l.mu.Unlock() + return + } + } + panic("not reached") } // Remove closes connection and removes it from the list. func (l *connList) Remove(cn *Conn) error { - defer l.mx.Unlock() - l.mx.Lock() + atomic.AddInt32(&l.len, -1) - if cn == nil { - atomic.AddInt32(&l.len, -1) + if cn == nil { // free reserved place return nil } - for i, c := range l.cns { - if c == cn { - l.cns = append(l.cns[:i], l.cns[i+1:]...) - atomic.AddInt32(&l.len, -1) - return cn.Close() - } + l.mu.Lock() + if l.cns != nil { + l.cns[cn.idx] = nil + cn.idx = -1 } + l.mu.Unlock() - if l.closed() { - return nil - } - panic("conn not found in the list") + return nil } -func (l *connList) Replace(cn, newcn *Conn) error { - defer l.mx.Unlock() - l.mx.Lock() - - for i, c := range l.cns { - if c == cn { - l.cns[i] = newcn - return cn.Close() - } - } - - if l.closed() { - return newcn.Close() - } - panic("conn not found in the list") -} - -func (l *connList) Close() (retErr error) { - l.mx.Lock() +func (l *connList) Close() error { + var retErr error + l.mu.Lock() for _, c := range l.cns { - if err := c.Close(); err != nil { + if c == nil { + continue + } + if err := c.Close(); err != nil && retErr == nil { retErr = err } } l.cns = nil atomic.StoreInt32(&l.len, 0) - l.mx.Unlock() + l.mu.Unlock() return retErr } - -func (l *connList) closed() bool { - return l.cns == nil -} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index ab03195d..bed6b468 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "log" + "net" "sync/atomic" "time" @@ -32,17 +33,17 @@ type Pooler interface { First() *Conn Get() (*Conn, bool, error) Put(*Conn) error - Remove(*Conn, error) error + Replace(*Conn, error) error Len() int FreeLen() int Close() error Stats() *PoolStats } -type dialer func() (*Conn, error) +type dialer func() (net.Conn, error) type ConnPool struct { - dial dialer + _dial dialer poolTimeout time.Duration idleTimeout time.Duration @@ -59,7 +60,7 @@ type ConnPool struct { func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool { p := &ConnPool{ - dial: dial, + _dial: dial, poolTimeout: poolTimeout, idleTimeout: idleTimeout, @@ -126,8 +127,7 @@ func (p *ConnPool) wait() *Conn { panic("not reached") } -// Establish a new connection -func (p *ConnPool) new() (*Conn, error) { +func (p *ConnPool) dial() (net.Conn, error) { if p.rl.Limit() { err := fmt.Errorf( "redis: you open connections too fast (last_error=%q)", @@ -136,15 +136,22 @@ func (p *ConnPool) new() (*Conn, error) { return nil, err } - cn, err := p.dial() + cn, err := p._dial() if err != nil { p.storeLastErr(err.Error()) return nil, err } - return cn, nil } +func (p *ConnPool) newConn() (*Conn, error) { + netConn, err := p.dial() + if err != nil { + return nil, err + } + return NewConn(netConn), nil +} + // Get returns existed connection from the pool or creates a new one. func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { if p.closed() { @@ -164,7 +171,7 @@ func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { if p.conns.Reserve() { isNew = true - cn, err = p.new() + cn, err = p.newConn() if err != nil { p.conns.Remove(nil) return @@ -189,23 +196,26 @@ func (p *ConnPool) Put(cn *Conn) error { b, _ := cn.Rd.Peek(cn.Rd.Buffered()) err := fmt.Errorf("connection has unread data: %q", b) Logger.Print(err) - return p.Remove(cn, err) + return p.Replace(cn, err) } p.freeConns <- cn return nil } func (p *ConnPool) replace(cn *Conn) (*Conn, error) { - newcn, err := p.new() + _ = cn.Close() + + netConn, err := p.dial() if err != nil { _ = p.conns.Remove(cn) return nil, err } - _ = p.conns.Replace(cn, newcn) - return newcn, nil + cn.SetNetConn(netConn) + + return cn, nil } -func (p *ConnPool) Remove(cn *Conn, reason error) error { +func (p *ConnPool) Replace(cn *Conn, reason error) error { p.storeLastErr(reason.Error()) // Replace existing connection with new one and unblock waiter. diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index f2d58cf7..e0ea8689 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -25,7 +25,7 @@ func (p *SingleConnPool) Put(cn *Conn) error { return nil } -func (p *SingleConnPool) Remove(cn *Conn, _ error) error { +func (p *SingleConnPool) Replace(cn *Conn, _ error) error { if p.cn != cn { panic("p.cn != cn") } diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index c611c4b4..8f4c324b 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -67,13 +67,13 @@ func (p *StickyConnPool) Put(cn *Conn) error { return nil } -func (p *StickyConnPool) remove(reason error) error { - err := p.pool.Remove(p.cn, reason) +func (p *StickyConnPool) replace(reason error) error { + err := p.pool.Replace(p.cn, reason) p.cn = nil return err } -func (p *StickyConnPool) Remove(cn *Conn, reason error) error { +func (p *StickyConnPool) Replace(cn *Conn, reason error) error { defer p.mx.Unlock() p.mx.Lock() if p.closed { @@ -85,7 +85,7 @@ func (p *StickyConnPool) Remove(cn *Conn, reason error) error { if cn != nil && p.cn != cn { panic("p.cn != cn") } - return p.remove(reason) + return p.replace(reason) } func (p *StickyConnPool) Len() int { @@ -121,7 +121,7 @@ func (p *StickyConnPool) Close() error { err = p.put() } else { reason := errors.New("redis: sticky not reusable connection") - err = p.remove(reason) + err = p.replace(reason) } } return err diff --git a/multi_test.go b/multi_test.go index fa532d1a..459d0a62 100644 --- a/multi_test.go +++ b/multi_test.go @@ -145,7 +145,7 @@ var _ = Describe("Multi", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.NetConn = &badConn{} + cn.SetNetConn(&badConn{}) err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) @@ -172,7 +172,7 @@ var _ = Describe("Multi", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.NetConn = &badConn{} + cn.SetNetConn(&badConn{}) err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) diff --git a/options.go b/options.go index 95a26663..b0572124 100644 --- a/options.go +++ b/options.go @@ -67,18 +67,6 @@ func (opt *Options) getDialer() func() (net.Conn, error) { return opt.Dialer } -func (opt *Options) getPoolDialer() func() (*pool.Conn, error) { - dial := opt.getDialer() - return func() (*pool.Conn, error) { - netcn, err := dial() - if err != nil { - return nil, err - } - cn := pool.NewConn(netcn) - return cn, opt.initConn(cn) - } -} - func (opt *Options) getPoolSize() int { if opt.PoolSize == 0 { return 10 @@ -104,32 +92,9 @@ func (opt *Options) getIdleTimeout() time.Duration { return opt.IdleTimeout } -func (opt *Options) initConn(cn *pool.Conn) error { - if opt.Password == "" && opt.DB == 0 { - return nil - } - - // Temp client for Auth and Select. - client := newClient(opt, pool.NewSingleConnPool(cn)) - - if opt.Password != "" { - if err := client.Auth(opt.Password).Err(); err != nil { - return err - } - } - - if opt.DB > 0 { - if err := client.Select(opt.DB).Err(); err != nil { - return err - } - } - - return nil -} - func newConnPool(opt *Options) *pool.ConnPool { return pool.NewConnPool( - opt.getPoolDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout()) + opt.getDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout()) } // PoolStats contains pool state information and accumulated stats. diff --git a/pool_test.go b/pool_test.go index 2494e56e..a5b07216 100644 --- a/pool_test.go +++ b/pool_test.go @@ -179,7 +179,7 @@ var _ = Describe("pool", func() { // ok } - err = pool.Remove(cn, errors.New("test")) + err = pool.Replace(cn, errors.New("test")) Expect(err).NotTo(HaveOccurred()) // Check that Ping is unblocked. @@ -203,7 +203,7 @@ var _ = Describe("pool", func() { break } - _ = pool.Remove(cn, errors.New("test")) + _ = pool.Replace(cn, errors.New("test")) } Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`)) diff --git a/pubsub_test.go b/pubsub_test.go index a8bb610b..669c0737 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -291,10 +291,10 @@ var _ = Describe("PubSub", func() { expectReceiveMessageOnError := func(pubsub *redis.PubSub) { cn1, _, err := pubsub.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn1.NetConn = &badConn{ + cn1.SetNetConn(&badConn{ readErr: io.EOF, writeErr: io.EOF, - } + }) done := make(chan bool, 1) go func() { diff --git a/redis.go b/redis.go index da4b41b3..55f47572 100644 --- a/redis.go +++ b/redis.go @@ -33,12 +33,19 @@ func (c *baseClient) String() string { } func (c *baseClient) conn() (*pool.Conn, bool, error) { - return c.connPool.Get() + cn, isNew, err := c.connPool.Get() + if err == nil && isNew { + err = c.initConn(cn) + if err != nil { + c.putConn(cn, err, false) + } + } + return cn, isNew, err } func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { if isBadConn(err, allowTimeout) { - err = c.connPool.Remove(cn, err) + err = c.connPool.Replace(cn, err) if err != nil { Logger.Printf("pool.Remove failed: %s", err) } @@ -52,6 +59,29 @@ func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { return true } +func (c *baseClient) initConn(cn *pool.Conn) error { + if c.opt.Password == "" && c.opt.DB == 0 { + return nil + } + + // Temp client for Auth and Select. + client := newClient(c.opt, pool.NewSingleConnPool(cn)) + + if c.opt.Password != "" { + if err := client.Auth(c.opt.Password).Err(); err != nil { + return err + } + } + + if c.opt.DB > 0 { + if err := client.Select(c.opt.DB).Err(); err != nil { + return err + } + } + + return nil +} + func (c *baseClient) process(cmd Cmder) { for i := 0; i <= c.opt.MaxRetries; i++ { if i > 0 { diff --git a/redis_test.go b/redis_test.go index 23c39009..1435b7a4 100644 --- a/redis_test.go +++ b/redis_test.go @@ -160,7 +160,7 @@ var _ = Describe("Client", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.NetConn = &badConn{} + cn.SetNetConn(&badConn{}) err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) diff --git a/sentinel.go b/sentinel.go index 5575e73e..694dd602 100644 --- a/sentinel.go +++ b/sentinel.go @@ -267,7 +267,7 @@ func (d *sentinelFailover) closeOldConns(newMaster string) { cn.RemoteAddr(), ) Logger.Print(err) - d.pool.Remove(cn, err) + d.pool.Replace(cn, err) } else { cnsToPut = append(cnsToPut, cn) }