diff --git a/command_test.go b/command_test.go index 1be2b945..c58528f0 100644 --- a/command_test.go +++ b/command_test.go @@ -150,7 +150,9 @@ var _ = Describe("Command", func() { wg.Add(n) for i := 0; i < n; i++ { go func() { + defer GinkgoRecover() defer wg.Done() + err := client.Incr(key).Err() Expect(err).NotTo(HaveOccurred()) }() diff --git a/commands_test.go b/commands_test.go index 2eb43002..c6e71ab1 100644 --- a/commands_test.go +++ b/commands_test.go @@ -17,7 +17,8 @@ var _ = Describe("Commands", func() { BeforeEach(func() { client = redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, + Addr: redisAddr, + PoolTimeout: 30 * time.Second, }) }) @@ -1116,6 +1117,8 @@ var _ = Describe("Commands", func() { started := make(chan bool) done := make(chan bool) go func() { + defer GinkgoRecover() + started <- true bLPop := client.BLPop(0, "list") Expect(bLPop.Err()).NotTo(HaveOccurred()) @@ -1161,6 +1164,8 @@ var _ = Describe("Commands", func() { started := make(chan bool) done := make(chan bool) go func() { + defer GinkgoRecover() + started <- true brpop := client.BRPop(0, "list") Expect(brpop.Err()).NotTo(HaveOccurred()) @@ -2190,7 +2195,9 @@ var _ = Describe("Commands", func() { wg.Add(1) go func() { + defer GinkgoRecover() defer wg.Done() + for { cmds, err := safeIncr() if err == redis.TxFailedErr { diff --git a/pool.go b/pool.go index ffedb3d5..08a17002 100644 --- a/pool.go +++ b/pool.go @@ -1,19 +1,20 @@ package redis import ( - "container/list" "errors" "fmt" "log" "net" "sync" + "sync/atomic" "time" "gopkg.in/bufio.v1" ) var ( - errClosed = errors.New("redis: client is closed") + errClosed = errors.New("redis: client is closed") + errPoolTimeout = errors.New("redis: connection pool timeout") ) var ( @@ -37,13 +38,9 @@ type conn struct { rd *bufio.Reader buf []byte - inUse bool - usedAt time.Time - + usedAt time.Time readTimeout time.Duration writeTimeout time.Duration - - elem *list.Element } func newConnFunc(dial func() (net.Conn, error)) func() (*conn, error) { @@ -87,19 +84,21 @@ func (cn *conn) Close() error { return cn.netcn.Close() } +func (cn *conn) isIdle(timeout time.Duration) bool { + return timeout > 0 && time.Since(cn.usedAt) > timeout +} + //------------------------------------------------------------------------------ type connPool struct { dial func() (*conn, error) rl *rateLimiter - opt *options + opt *options + conns chan *conn - cond *sync.Cond - conns *list.List - - idleNum int - closed bool + size int32 + closed int32 lastDialErr error } @@ -109,13 +108,47 @@ func newConnPool(dial func() (*conn, error), opt *options) *connPool { dial: dial, rl: newRateLimiter(time.Second, 2*opt.PoolSize), - opt: opt, - - cond: sync.NewCond(&sync.Mutex{}), - conns: list.New(), + opt: opt, + conns: make(chan *conn, opt.PoolSize), } } +func (p *connPool) isClosed() bool { return atomic.LoadInt32(&p.closed) > 0 } + +// First available connection, non-blocking +func (p *connPool) first() *conn { + for { + select { + case cn := <-p.conns: + if !cn.isIdle(p.opt.IdleTimeout) { + return cn + } + p.remove(cn) + default: + return nil + } + } + panic("not reached") +} + +// Wait for available connection, blocking +func (p *connPool) wait() (*conn, error) { + deadline := time.After(p.opt.PoolTimeout) + for { + select { + case cn := <-p.conns: + if !cn.isIdle(p.opt.IdleTimeout) { + return cn, nil + } + p.remove(cn) + case <-deadline: + return nil, errPoolTimeout + } + } + panic("not reached") +} + +// Establish a new connection func (p *connPool) new() (*conn, error) { if !p.rl.Check() { err := fmt.Errorf( @@ -132,60 +165,29 @@ func (p *connPool) new() (*conn, error) { } func (p *connPool) Get() (*conn, bool, error) { - p.cond.L.Lock() - - if p.closed { - p.cond.L.Unlock() + if p.isClosed() { return nil, false, errClosed } - if p.opt.IdleTimeout > 0 { - for el := p.conns.Front(); el != nil; el = el.Next() { - cn := el.Value.(*conn) - if cn.inUse { - break - } - if time.Since(cn.usedAt) > p.opt.IdleTimeout { - if err := p.remove(cn); err != nil { - log.Printf("remove failed: %s", err) - } - } - } - } - - for p.conns.Len() >= p.opt.PoolSize && p.idleNum == 0 { - p.cond.Wait() - } - - if p.idleNum > 0 { - elem := p.conns.Front() - cn := elem.Value.(*conn) - if cn.inUse { - panic("pool: precondition failed") - } - cn.inUse = true - p.conns.MoveToBack(elem) - p.idleNum-- - - p.cond.L.Unlock() + // Fetch first non-idle connection, if available + if cn := p.first(); cn != nil { return cn, false, nil } - if p.conns.Len() < p.opt.PoolSize { + // Try to create a new one + if ref := atomic.AddInt32(&p.size, 1); int(ref) <= p.opt.PoolSize { cn, err := p.new() if err != nil { - p.cond.L.Unlock() + atomic.AddInt32(&p.size, -1) // Undo ref increment return nil, false, err } - - cn.inUse = true - cn.elem = p.conns.PushBack(cn) - - p.cond.L.Unlock() return cn, true, nil } + atomic.AddInt32(&p.size, -1) - panic("not reached") + // Otherwise, wait for the available connection + cn, err := p.wait() + return cn, false, err } func (p *connPool) Put(cn *conn) error { @@ -195,92 +197,67 @@ func (p *connPool) Put(cn *conn) error { return p.Remove(cn) } + if p.isClosed() { + return errClosed + } if p.opt.IdleTimeout > 0 { cn.usedAt = time.Now() } - - p.cond.L.Lock() - if p.closed { - p.cond.L.Unlock() - return errClosed - } - cn.inUse = false - p.conns.MoveToFront(cn.elem) - p.idleNum++ - p.cond.Signal() - p.cond.L.Unlock() - + p.conns <- cn return nil } func (p *connPool) Remove(cn *conn) error { - p.cond.L.Lock() - if p.closed { - // Noop, connection is already closed. - p.cond.L.Unlock() + if p.isClosed() { return nil } - err := p.remove(cn) - p.cond.Signal() - p.cond.L.Unlock() - return err + return p.remove(cn) } func (p *connPool) remove(cn *conn) error { - p.conns.Remove(cn.elem) - cn.elem = nil - if !cn.inUse { - p.idleNum-- - } + atomic.AddInt32(&p.size, -1) return cn.Close() } // Len returns number of idle connections. func (p *connPool) Len() int { - defer p.cond.L.Unlock() - p.cond.L.Lock() - return p.idleNum + return len(p.conns) } // Size returns number of connections in the pool. func (p *connPool) Size() int { - defer p.cond.L.Unlock() - p.cond.L.Lock() - return p.conns.Len() + return int(atomic.LoadInt32(&p.size)) } func (p *connPool) Filter(f func(*conn) bool) { - p.cond.L.Lock() - for el, next := p.conns.Front(), p.conns.Front(); el != nil; el = next { - next = el.Next() - cn := el.Value.(*conn) - if !f(cn) { - p.remove(cn) + for { + select { + case cn := <-p.conns: + if !f(cn) { + p.remove(cn) + } + default: + return } } - p.cond.L.Unlock() + panic("not reached") } -func (p *connPool) Close() error { - defer p.cond.L.Unlock() - p.cond.L.Lock() - if p.closed { +func (p *connPool) Close() (err error) { + if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) { return nil } - p.closed = true p.rl.Close() - var retErr error + for { - e := p.conns.Front() - if e == nil { - break + if p.Size() < 1 { + return } - if err := p.remove(e.Value.(*conn)); err != nil { - log.Printf("cn.Close failed: %s", err) - retErr = err + if e := p.remove(<-p.conns); e != nil { + err = e } } - return retErr + panic("not reached") } //------------------------------------------------------------------------------ diff --git a/pool_test.go b/pool_test.go index 311bccd8..2960daad 100644 --- a/pool_test.go +++ b/pool_test.go @@ -17,7 +17,9 @@ var _ = Describe("Pool", func() { for i := 0; i < n; i++ { wg.Add(1) go func() { + defer GinkgoRecover() defer wg.Done() + cb() }() } diff --git a/redis.go b/redis.go index 3a3ebc55..a1045c36 100644 --- a/redis.go +++ b/redis.go @@ -148,6 +148,7 @@ type options struct { WriteTimeout time.Duration PoolSize int + PoolTimeout time.Duration IdleTimeout time.Duration } @@ -167,6 +168,7 @@ type Options struct { WriteTimeout time.Duration PoolSize int + PoolTimeout time.Duration IdleTimeout time.Duration } @@ -191,6 +193,13 @@ func (opt *Options) getDialTimeout() time.Duration { return opt.DialTimeout } +func (opt *Options) getPoolTimeout() time.Duration { + if opt.PoolTimeout == 0 { + return 5 * time.Second + } + return opt.PoolTimeout +} + func (opt *Options) options() *options { return &options{ DB: opt.DB, @@ -201,6 +210,7 @@ func (opt *Options) options() *options { WriteTimeout: opt.WriteTimeout, PoolSize: opt.getPoolSize(), + PoolTimeout: opt.getPoolTimeout(), IdleTimeout: opt.IdleTimeout, } } diff --git a/sentinel.go b/sentinel.go index d3ffeca9..1ed2e50b 100644 --- a/sentinel.go +++ b/sentinel.go @@ -23,6 +23,7 @@ type FailoverOptions struct { DialTimeout time.Duration ReadTimeout time.Duration WriteTimeout time.Duration + PoolTimeout time.Duration IdleTimeout time.Duration } @@ -33,6 +34,13 @@ func (opt *FailoverOptions) getPoolSize() int { return opt.PoolSize } +func (opt *FailoverOptions) getPoolTimeout() time.Duration { + if opt.PoolTimeout == 0 { + return 5 * time.Second + } + return opt.PoolTimeout +} + func (opt *FailoverOptions) getDialTimeout() time.Duration { if opt.DialTimeout == 0 { return 5 * time.Second @@ -50,6 +58,7 @@ func (opt *FailoverOptions) options() *options { WriteTimeout: opt.WriteTimeout, PoolSize: opt.getPoolSize(), + PoolTimeout: opt.getPoolTimeout(), IdleTimeout: opt.IdleTimeout, } } @@ -169,6 +178,7 @@ func (d *sentinelFailover) MasterAddr() (string, error) { WriteTimeout: d.opt.WriteTimeout, PoolSize: d.opt.PoolSize, + PoolTimeout: d.opt.PoolTimeout, IdleTimeout: d.opt.IdleTimeout, }) masterAddr, err := sentinel.GetMasterAddrByName(d.masterName).Result()