From f3f44aefbe902201ec3de33bc1ae8c0bbda08745 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 2 May 2015 16:11:18 +0300 Subject: [PATCH] Fix pool to close all connections when client is closed. --- conn.go | 96 ++++++++++++++++ pool.go | 298 +++++++++++++++++++++++++------------------------- pool_test.go | 18 +-- redis.go | 59 ++++++---- redis_test.go | 2 +- sentinel.go | 10 +- 6 files changed, 294 insertions(+), 189 deletions(-) create mode 100644 conn.go diff --git a/conn.go b/conn.go new file mode 100644 index 00000000..6ce55805 --- /dev/null +++ b/conn.go @@ -0,0 +1,96 @@ +package redis + +import ( + "net" + "time" + + "gopkg.in/bufio.v1" +) + +type conn struct { + netcn net.Conn + rd *bufio.Reader + buf []byte + + usedAt time.Time + readTimeout time.Duration + writeTimeout time.Duration +} + +func newConnDialer(opt *options) func() (*conn, error) { + return func() (*conn, error) { + netcn, err := opt.Dialer() + if err != nil { + return nil, err + } + cn := &conn{ + netcn: netcn, + buf: make([]byte, 0, 64), + } + cn.rd = bufio.NewReader(cn) + return cn, cn.init(opt) + } +} + +func (cn *conn) init(opt *options) error { + if opt.Password == "" && opt.DB == 0 { + return nil + } + + // Use connection to connect to redis + pool := newSingleConnPool(nil, false) + pool.SetConn(cn) + + // Client is not closed because we want to reuse underlying connection. + client := newClient(opt, pool) + + if opt.Password != "" { + if err := client.Auth(opt.Password).Err(); err != nil { + return err + } + } + + if opt.DB > 0 { + if err := client.Select(opt.DB).Err(); err != nil { + return err + } + } + + return nil +} + +func (cn *conn) writeCmds(cmds ...Cmder) error { + buf := cn.buf[:0] + for _, cmd := range cmds { + buf = appendArgs(buf, cmd.args()) + } + + _, err := cn.Write(buf) + return err +} + +func (cn *conn) Read(b []byte) (int, error) { + if cn.readTimeout != 0 { + cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout)) + } else { + cn.netcn.SetReadDeadline(zeroTime) + } + return cn.netcn.Read(b) +} + +func (cn *conn) Write(b []byte) (int, error) { + if cn.writeTimeout != 0 { + cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout)) + } else { + cn.netcn.SetWriteDeadline(zeroTime) + } + return cn.netcn.Write(b) +} + +func (cn *conn) RemoteAddr() net.Addr { + return cn.netcn.RemoteAddr() +} + +func (cn *conn) Close() error { + return cn.netcn.Close() +} diff --git a/pool.go b/pool.go index 96007fbd..5194bc80 100644 --- a/pool.go +++ b/pool.go @@ -4,13 +4,11 @@ import ( "errors" "fmt" "log" - "net" "sync" "sync/atomic" "time" "gopkg.in/bsm/ratelimit.v1" - "gopkg.in/bufio.v1" ) var ( @@ -28,103 +26,132 @@ type pool interface { Put(*conn) error Remove(*conn) error Len() int - Size() int + FreeLen() int Close() error } -//------------------------------------------------------------------------------ - -type conn struct { - netcn net.Conn - rd *bufio.Reader - buf []byte - - usedAt time.Time - readTimeout time.Duration - writeTimeout time.Duration +type connList struct { + cns []*conn + mx sync.Mutex + len int32 // atomic + size int32 } -func newConnFunc(dial func() (net.Conn, error)) func() (*conn, error) { - return func() (*conn, error) { - netcn, err := dial() - if err != nil { - return nil, err +func newConnList(size int) *connList { + return &connList{ + cns: make([]*conn, 0, size), + size: int32(size), + } +} + +func (l *connList) Len() int { + return int(atomic.LoadInt32(&l.len)) +} + +// Reserve reserves place in the list and returns true on success. The +// caller must add or remove connection if place was reserved. +func (l *connList) Reserve() bool { + len := atomic.AddInt32(&l.len, 1) + reserved := len <= l.size + if !reserved { + atomic.AddInt32(&l.len, -1) + } + return reserved +} + +// Add adds connection to the list. The caller must reserve place first. +func (l *connList) Add(cn *conn) { + l.mx.Lock() + l.cns = append(l.cns, cn) + l.mx.Unlock() +} + +func (l *connList) Remove(cn *conn) error { + defer l.mx.Unlock() + l.mx.Lock() + + if cn == nil { + atomic.AddInt32(&l.len, -1) + return nil + } + + for i, c := range l.cns { + if c == cn { + l.cns = append(l.cns[:i], l.cns[i+1:]...) + atomic.AddInt32(&l.len, -1) + return cn.Close() } - cn := &conn{ - netcn: netcn, - buf: make([]byte, 0, 64), + } + + panic("conn not found in the list") +} + +func (l *connList) Replace(cn, newcn *conn) error { + defer l.mx.Unlock() + l.mx.Lock() + + for i, c := range l.cns { + if c == cn { + l.cns[i] = newcn + return cn.Close() } - cn.rd = bufio.NewReader(cn) - return cn, nil - } -} - -func (cn *conn) writeCmds(cmds ...Cmder) error { - buf := cn.buf[:0] - for _, cmd := range cmds { - buf = appendArgs(buf, cmd.args()) } - _, err := cn.Write(buf) - return err + panic("conn not found in the list") } -func (cn *conn) Read(b []byte) (int, error) { - if cn.readTimeout != 0 { - cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout)) - } else { - cn.netcn.SetReadDeadline(zeroTime) +func (l *connList) Close() (retErr error) { + l.mx.Lock() + for _, c := range l.cns { + if err := c.Close(); err != nil { + retErr = err + } } - return cn.netcn.Read(b) + l.cns = nil + atomic.StoreInt32(&l.len, 0) + l.mx.Unlock() + return retErr } -func (cn *conn) Write(b []byte) (int, error) { - if cn.writeTimeout != 0 { - cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout)) - } else { - cn.netcn.SetWriteDeadline(zeroTime) - } - return cn.netcn.Write(b) +type connPoolOptions struct { + Dialer func() (*conn, error) + PoolSize int + PoolTimeout time.Duration + IdleTimeout time.Duration + IdleCheckFrequency time.Duration } -func (cn *conn) RemoteAddr() net.Addr { - return cn.netcn.RemoteAddr() -} - -func (cn *conn) Close() error { - return cn.netcn.Close() -} - -func (cn *conn) isIdle(timeout time.Duration) bool { - return timeout > 0 && time.Since(cn.usedAt) > timeout -} - -//------------------------------------------------------------------------------ - type connPool struct { - dial func() (*conn, error) - rl *ratelimit.RateLimiter - - opt *options + rl *ratelimit.RateLimiter + opt *connPoolOptions + conns *connList freeConns chan *conn - size int32 - closed int32 + _closed int32 lastDialErr error } -func newConnPool(dial func() (*conn, error), opt *options) *connPool { - return &connPool{ - dial: dial, - rl: ratelimit.New(2*opt.PoolSize, time.Second), - +func newConnPool(opt *connPoolOptions) *connPool { + p := &connPool{ + rl: ratelimit.New(2*opt.PoolSize, time.Second), opt: opt, + conns: newConnList(opt.PoolSize), freeConns: make(chan *conn, opt.PoolSize), } + if p.opt.IdleTimeout > 0 && p.opt.IdleCheckFrequency > 0 { + go p.reaper() + } + return p } -func (p *connPool) isClosed() bool { return atomic.LoadInt32(&p.closed) > 0 } +func (p *connPool) closed() bool { + return atomic.LoadInt32(&p._closed) == 1 +} + +func (p *connPool) isIdle(cn *conn) bool { + return p.opt.IdleTimeout > 0 && time.Since(cn.usedAt) > p.opt.IdleTimeout +} // First returns first non-idle connection from the pool or nil if // there are no connections. @@ -132,8 +159,8 @@ func (p *connPool) First() *conn { for { select { case cn := <-p.freeConns: - if cn.isIdle(p.opt.IdleTimeout) { - p.Remove(cn) + if p.isIdle(cn) { + p.conns.Remove(cn) continue } return cn @@ -150,7 +177,7 @@ func (p *connPool) wait(timeout time.Duration) *conn { for { select { case cn := <-p.freeConns: - if cn.isIdle(p.opt.IdleTimeout) { + if p.isIdle(cn) { p.Remove(cn) continue } @@ -172,52 +199,19 @@ func (p *connPool) new() (*conn, error) { return nil, err } - cn, err := p.dial() + cn, err := p.opt.Dialer() if err != nil { p.lastDialErr = err return nil, err } - if err := p.initConn(cn); err != nil { - cn.Close() - return nil, err - } - return cn, nil } -// Initialize connection -func (p *connPool) initConn(cn *conn) error { - if p.opt.Password == "" && p.opt.DB == 0 { - return nil - } - - // Use connection to connect to redis - pool := newSingleConnPool(p, false) - pool.SetConn(cn) - - // Client is not closed because we want to reuse underlying connection. - client := newClient(p.opt, pool) - - if p.opt.Password != "" { - if err := client.Auth(p.opt.Password).Err(); err != nil { - return err - } - } - - if p.opt.DB > 0 { - if err := client.Select(p.opt.DB).Err(); err != nil { - return err - } - } - - return nil -} - // Get returns existed connection from the pool or creates a new one // if needed. func (p *connPool) Get() (*conn, error) { - if p.isClosed() { + if p.closed() { return nil, errClosed } @@ -226,16 +220,16 @@ func (p *connPool) Get() (*conn, error) { return cn, nil } - // Try to create a new one - if ref := atomic.AddInt32(&p.size, 1); int(ref) <= p.opt.PoolSize { + // Try to create a new one. + if p.conns.Reserve() { cn, err := p.new() if err != nil { - atomic.AddInt32(&p.size, -1) // Undo ref increment + p.conns.Remove(nil) return nil, err } + p.conns.Add(cn) return cn, nil } - atomic.AddInt32(&p.size, -1) // Otherwise, wait for the available connection if cn := p.wait(p.opt.PoolTimeout); cn != nil { @@ -259,49 +253,53 @@ func (p *connPool) Put(cn *conn) error { } func (p *connPool) Remove(cn *conn) error { - if p.isClosed() { - atomic.AddInt32(&p.size, -1) - return cn.Close() - } - - // Replace existing connection with new one and unblock `wait`. - newcn, err := p.new() - if err != nil { - atomic.AddInt32(&p.size, -1) - } else { - p.Put(newcn) - } - - return cn.Close() -} - -// Len returns number of idle connections. -func (p *connPool) Len() int { - return len(p.freeConns) -} - -// Size returns number of connections in the pool. -func (p *connPool) Size() int { - return int(atomic.LoadInt32(&p.size)) -} - -func (p *connPool) Close() (retErr error) { - if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) { + if p.closed() { + // Close already closed all connections. return nil } - // Wait until pool has no connections - for p.Size() > 0 { - cn := p.wait(p.opt.PoolTimeout) - if cn == nil { + // Replace existing connection with new one and unblock waiter. + newcn, err := p.new() + if err != nil { + return p.conns.Remove(cn) + } + p.freeConns <- newcn + return p.conns.Replace(cn, newcn) +} + +// Len returns total number of connections. +func (p *connPool) Len() int { + return p.conns.Len() +} + +// FreeLen returns number of free connections. +func (p *connPool) FreeLen() int { + return len(p.freeConns) +} + +func (p *connPool) Close() error { + if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { + return errClosed + } + return p.conns.Close() +} + +func (p *connPool) reaper() { + ticker := time.NewTicker(p.opt.IdleCheckFrequency) + defer ticker.Stop() + + for _ = range ticker.C { + if p.closed() { break } - if err := p.Remove(cn); err != nil { - retErr = err + + // pool.First removes idle connections from the pool and + // returns first non-idle connection. So just put returned + // connection back. + if cn := p.First(); cn != nil { + p.Put(cn) } } - - return retErr } //------------------------------------------------------------------------------ @@ -404,7 +402,7 @@ func (p *singleConnPool) Len() int { return 1 } -func (p *singleConnPool) Size() int { +func (p *singleConnPool) FreeLen() int { defer p.cnMtx.Unlock() p.cnMtx.Lock() if p.cn == nil { diff --git a/pool_test.go b/pool_test.go index 890ca2cf..fae6120d 100644 --- a/pool_test.go +++ b/pool_test.go @@ -48,9 +48,9 @@ var _ = Describe("Pool", func() { }) pool := client.Pool() - Expect(pool.Size()).To(BeNumerically("<=", 10)) Expect(pool.Len()).To(BeNumerically("<=", 10)) - Expect(pool.Size()).To(Equal(pool.Len())) + Expect(pool.FreeLen()).To(BeNumerically("<=", 10)) + Expect(pool.Len()).To(Equal(pool.FreeLen())) }) It("should respect max on multi", func() { @@ -70,9 +70,9 @@ var _ = Describe("Pool", func() { }) pool := client.Pool() - Expect(pool.Size()).To(BeNumerically("<=", 10)) Expect(pool.Len()).To(BeNumerically("<=", 10)) - Expect(pool.Size()).To(Equal(pool.Len())) + Expect(pool.FreeLen()).To(BeNumerically("<=", 10)) + Expect(pool.Len()).To(Equal(pool.FreeLen())) }) It("should respect max on pipelines", func() { @@ -88,9 +88,9 @@ var _ = Describe("Pool", func() { }) pool := client.Pool() - Expect(pool.Size()).To(BeNumerically("<=", 10)) Expect(pool.Len()).To(BeNumerically("<=", 10)) - Expect(pool.Size()).To(Equal(pool.Len())) + Expect(pool.FreeLen()).To(BeNumerically("<=", 10)) + Expect(pool.Len()).To(Equal(pool.FreeLen())) }) It("should respect max on pubsub", func() { @@ -101,8 +101,8 @@ var _ = Describe("Pool", func() { }) pool := client.Pool() - Expect(pool.Size()).To(Equal(10)) Expect(pool.Len()).To(Equal(10)) + Expect(pool.FreeLen()).To(Equal(10)) }) It("should remove broken connections", func() { @@ -120,8 +120,8 @@ var _ = Describe("Pool", func() { Expect(val).To(Equal("PONG")) pool := client.Pool() - Expect(pool.Size()).To(Equal(1)) Expect(pool.Len()).To(Equal(1)) + Expect(pool.FreeLen()).To(Equal(1)) }) It("should reuse connections", func() { @@ -132,8 +132,8 @@ var _ = Describe("Pool", func() { } pool := client.Pool() - Expect(pool.Size()).To(Equal(1)) Expect(pool.Len()).To(Equal(1)) + Expect(pool.FreeLen()).To(Equal(1)) }) It("should unblock client when connection is removed", func() { diff --git a/redis.go b/redis.go index d03425cb..fbaff50a 100644 --- a/redis.go +++ b/redis.go @@ -67,19 +67,6 @@ func (c *baseClient) Close() error { //------------------------------------------------------------------------------ -type options struct { - Password string - DB int64 - - DialTimeout time.Duration - ReadTimeout time.Duration - WriteTimeout time.Duration - - PoolSize int - PoolTimeout time.Duration - IdleTimeout time.Duration -} - type Options struct { // The network type, either "tcp" or "unix". // Default: "tcp" @@ -120,6 +107,15 @@ type Options struct { IdleTimeout time.Duration } +func (opt *Options) getDialer() func() (net.Conn, error) { + if opt.Dialer == nil { + return func() (net.Conn, error) { + return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout()) + } + } + return opt.Dialer +} + func (opt *Options) getNetwork() string { if opt.Network == "" { return "tcp" @@ -150,15 +146,39 @@ func (opt *Options) getPoolTimeout() time.Duration { func (opt *Options) options() *options { return &options{ + Dialer: opt.getDialer(), + PoolSize: opt.getPoolSize(), + PoolTimeout: opt.getPoolTimeout(), + IdleTimeout: opt.IdleTimeout, + DB: opt.DB, Password: opt.Password, DialTimeout: opt.getDialTimeout(), ReadTimeout: opt.ReadTimeout, WriteTimeout: opt.WriteTimeout, + } +} - PoolSize: opt.getPoolSize(), - PoolTimeout: opt.getPoolTimeout(), +type options struct { + Dialer func() (net.Conn, error) + PoolSize int + PoolTimeout time.Duration + IdleTimeout time.Duration + + Password string + DB int64 + + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func (opt *options) connPoolOptions() *connPoolOptions { + return &connPoolOptions{ + Dialer: newConnDialer(opt), + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, IdleTimeout: opt.IdleTimeout, } } @@ -180,13 +200,8 @@ func newClient(opt *options, pool pool) *Client { func NewClient(clOpt *Options) *Client { opt := clOpt.options() - dialer := clOpt.Dialer - if dialer == nil { - dialer = func() (net.Conn, error) { - return net.DialTimeout(clOpt.getNetwork(), clOpt.Addr, opt.DialTimeout) - } - } - return newClient(opt, newConnPool(newConnFunc(dialer), opt)) + pool := newConnPool(opt.connPoolOptions()) + return newClient(opt, pool) } // Deprecated. Use NewClient instead. diff --git a/redis_test.go b/redis_test.go index 62068d06..855b1cb8 100644 --- a/redis_test.go +++ b/redis_test.go @@ -34,7 +34,7 @@ var _ = Describe("Client", func() { }) AfterEach(func() { - Expect(client.Close()).NotTo(HaveOccurred()) + client.Close() }) It("should ping", func() { diff --git a/sentinel.go b/sentinel.go index 496e2cb4..717f4122 100644 --- a/sentinel.go +++ b/sentinel.go @@ -104,14 +104,9 @@ type sentinelClient struct { func newSentinel(clOpt *Options) *sentinelClient { opt := clOpt.options() - opt.Password = "" - opt.DB = 0 - dialer := func() (net.Conn, error) { - return net.DialTimeout("tcp", clOpt.Addr, opt.DialTimeout) - } base := &baseClient{ opt: opt, - connPool: newConnPool(newConnFunc(dialer), opt), + connPool: newConnPool(opt.connPoolOptions()), } return &sentinelClient{ baseClient: base, @@ -163,7 +158,8 @@ func (d *sentinelFailover) dial() (net.Conn, error) { func (d *sentinelFailover) Pool() pool { d.poolOnce.Do(func() { - d.pool = newConnPool(newConnFunc(d.dial), d.opt) + d.opt.Dialer = d.dial + d.pool = newConnPool(d.opt.connPoolOptions()) }) return d.pool }