diff --git a/.travis.yml b/.travis.yml index 1d3148f7..dc4191ac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,7 +5,6 @@ services: - redis-server go: - - 1.3 - 1.4 - 1.5 - tip diff --git a/multi.go b/multi.go index 320e78eb..0e0281d3 100644 --- a/multi.go +++ b/multi.go @@ -46,16 +46,15 @@ func (c *Client) Multi() *Multi { return multi } -func (c *Multi) putConn(cn *conn, ei error) { - var err error - if isBadConn(cn, ei) { +func (c *Multi) putConn(cn *conn, err error) { + if isBadConn(cn, err) { // Close current connection. - c.base.connPool.(*stickyConnPool).Reset() + c.base.connPool.(*stickyConnPool).Reset(err) } else { - err = c.base.connPool.Put(cn) - } - if err != nil { - log.Printf("redis: putConn failed: %s", err) + err := c.base.connPool.Put(cn) + if err != nil { + log.Printf("redis: putConn failed: %s", err) + } } } diff --git a/pool.go b/pool.go index 9968bd7e..5ed72097 100644 --- a/pool.go +++ b/pool.go @@ -20,7 +20,7 @@ type pool interface { First() *conn Get() (*conn, bool, error) Put(*conn) error - Remove(*conn) error + Remove(*conn, error) error Len() int FreeLen() int Close() error @@ -130,7 +130,7 @@ type connPool struct { _closed int32 - lastDialErr error + lastErr atomic.Value } func newConnPool(opt *Options) *connPool { @@ -204,15 +204,15 @@ func (p *connPool) wait() *conn { func (p *connPool) new() (*conn, error) { if p.rl.Limit() { err := fmt.Errorf( - "redis: you open connections too fast (last error: %v)", - p.lastDialErr, + "redis: you open connections too fast (last_error=%q)", + p.loadLastErr(), ) return nil, err } cn, err := p.dialer() if err != nil { - p.lastDialErr = err + p.storeLastErr(err.Error()) return nil, err } @@ -255,8 +255,9 @@ func (p *connPool) Get() (cn *conn, isNew bool, err error) { func (p *connPool) Put(cn *conn) error { if cn.rd.Buffered() != 0 { b, _ := cn.rd.Peek(cn.rd.Buffered()) - log.Printf("redis: connection has unread data: %q", b) - return p.Remove(cn) + err := fmt.Errorf("redis: connection has unread data: %q", b) + log.Print(err) + return p.Remove(cn, err) } if p.opt.getIdleTimeout() > 0 { cn.usedAt = time.Now() @@ -275,7 +276,9 @@ func (p *connPool) replace(cn *conn) (*conn, error) { return newcn, nil } -func (p *connPool) Remove(cn *conn) error { +func (p *connPool) Remove(cn *conn, reason error) error { + p.storeLastErr(reason.Error()) + // Replace existing connection with new one and unblock waiter. newcn, err := p.replace(cn) if err != nil { @@ -330,6 +333,17 @@ func (p *connPool) reaper() { } } +func (p *connPool) storeLastErr(err string) { + p.lastErr.Store(err) +} + +func (p *connPool) loadLastErr() string { + if v := p.lastErr.Load(); v != nil { + return v.(string) + } + return "" +} + //------------------------------------------------------------------------------ type singleConnPool struct { @@ -357,7 +371,7 @@ func (p *singleConnPool) Put(cn *conn) error { return nil } -func (p *singleConnPool) Remove(cn *conn) error { +func (p *singleConnPool) Remove(cn *conn, _ error) error { if p.cn != cn { panic("p.cn != cn") } @@ -440,13 +454,13 @@ func (p *stickyConnPool) Put(cn *conn) error { return nil } -func (p *stickyConnPool) remove() (err error) { - err = p.pool.Remove(p.cn) +func (p *stickyConnPool) remove(reason error) (err error) { + err = p.pool.Remove(p.cn, reason) p.cn = nil return err } -func (p *stickyConnPool) Remove(cn *conn) error { +func (p *stickyConnPool) Remove(cn *conn, _ error) error { defer p.mx.Unlock() p.mx.Lock() if p.closed { @@ -479,10 +493,10 @@ func (p *stickyConnPool) FreeLen() int { return 0 } -func (p *stickyConnPool) Reset() (err error) { +func (p *stickyConnPool) Reset(reason error) (err error) { p.mx.Lock() if p.cn != nil { - err = p.remove() + err = p.remove(reason) } p.mx.Unlock() return err @@ -500,7 +514,8 @@ func (p *stickyConnPool) Close() error { if p.reusable { err = p.put() } else { - err = p.remove() + reason := errors.New("redis: sticky not reusable connection") + err = p.remove(reason) } } return err diff --git a/pool_test.go b/pool_test.go index 9eb2c990..4d787a68 100644 --- a/pool_test.go +++ b/pool_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "errors" "sync" "testing" "time" @@ -36,7 +37,6 @@ var _ = Describe("pool", func() { }) AfterEach(func() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred()) }) @@ -141,12 +141,12 @@ var _ = Describe("pool", func() { pool := client.Pool() // Reserve one connection. - cn, _, err := client.Pool().Get() + cn, _, err := pool.Get() Expect(err).NotTo(HaveOccurred()) // Reserve the rest of connections. for i := 0; i < 9; i++ { - _, _, err := client.Pool().Get() + _, _, err := pool.Get() Expect(err).NotTo(HaveOccurred()) } @@ -168,7 +168,8 @@ var _ = Describe("pool", func() { // ok } - Expect(pool.Remove(cn)).NotTo(HaveOccurred()) + err = pool.Remove(cn, errors.New("test")) + Expect(err).NotTo(HaveOccurred()) // Check that Ping is unblocked. select { @@ -179,6 +180,23 @@ var _ = Describe("pool", func() { } Expect(ping.Err()).NotTo(HaveOccurred()) }) + + It("should rate limit dial", func() { + pool := client.Pool() + + var rateErr error + for i := 0; i < 1000; i++ { + cn, _, err := pool.Get() + if err != nil { + rateErr = err + break + } + + _ = pool.Remove(cn, errors.New("test")) + } + + Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`)) + }) }) func BenchmarkPool(b *testing.B) { diff --git a/pubsub.go b/pubsub.go index 3e20fe72..aea2bed7 100644 --- a/pubsub.go +++ b/pubsub.go @@ -233,9 +233,9 @@ func (c *PubSub) Receive() (interface{}, error) { return c.ReceiveTimeout(0) } -func (c *PubSub) reconnect() { +func (c *PubSub) reconnect(reason error) { // Close current connection. - c.connPool.(*stickyConnPool).Reset() + c.connPool.(*stickyConnPool).Reset(reason) if len(c.channels) > 0 { if err := c.Subscribe(c.channels...); err != nil { @@ -276,7 +276,7 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { if errNum > 2 { time.Sleep(time.Second) } - c.reconnect() + c.reconnect(err) continue } diff --git a/redis.go b/redis.go index cd88cefd..087564fd 100644 --- a/redis.go +++ b/redis.go @@ -20,10 +20,9 @@ func (c *baseClient) conn() (*conn, bool, error) { return c.connPool.Get() } -func (c *baseClient) putConn(cn *conn, ei error) { - var err error - if isBadConn(cn, ei) { - err = c.connPool.Remove(cn) +func (c *baseClient) putConn(cn *conn, err error) { + if isBadConn(cn, err) { + err = c.connPool.Remove(cn, err) } else { err = c.connPool.Put(cn) } diff --git a/sentinel.go b/sentinel.go index 63c011d4..7edd75f2 100644 --- a/sentinel.go +++ b/sentinel.go @@ -2,6 +2,7 @@ package redis import ( "errors" + "fmt" "log" "net" "strings" @@ -227,11 +228,12 @@ func (d *sentinelFailover) closeOldConns(newMaster string) { break } if cn.RemoteAddr().String() != newMaster { - log.Printf( + err := fmt.Errorf( "redis-sentinel: closing connection to the old master %s", cn.RemoteAddr(), ) - d.pool.Remove(cn) + log.Print(err) + d.pool.Remove(cn, err) } else { cnsToPut = append(cnsToPut, cn) }