diff --git a/cluster_pipeline.go b/cluster_pipeline.go index 73d272b..5299b5f 100644 --- a/cluster_pipeline.go +++ b/cluster_pipeline.go @@ -79,7 +79,7 @@ func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { if err != nil { retErr = err } - client.putConn(cn, err) + client.putConn(cn, err, false) } cmdsMap = failedCmds diff --git a/command.go b/command.go index 1986c66..31516f6 100644 --- a/command.go +++ b/command.go @@ -32,7 +32,6 @@ type Cmder interface { setErr(error) reset() - writeTimeout() *time.Duration readTimeout() *time.Duration clusterKey() string @@ -82,7 +81,7 @@ type baseCmd struct { _clusterKeyPos int - _writeTimeout, _readTimeout *time.Duration + _readTimeout *time.Duration } func (cmd *baseCmd) Err() error { @@ -104,10 +103,6 @@ func (cmd *baseCmd) setReadTimeout(d time.Duration) { cmd._readTimeout = &d } -func (cmd *baseCmd) writeTimeout() *time.Duration { - return cmd._writeTimeout -} - func (cmd *baseCmd) clusterKey() string { if cmd._clusterKeyPos > 0 && cmd._clusterKeyPos < len(cmd._args) { return fmt.Sprint(cmd._args[cmd._clusterKeyPos]) @@ -115,10 +110,6 @@ func (cmd *baseCmd) clusterKey() string { return "" } -func (cmd *baseCmd) setWriteTimeout(d time.Duration) { - cmd._writeTimeout = &d -} - func (cmd *baseCmd) setErr(e error) { cmd.err = e } diff --git a/commands_test.go b/commands_test.go index 334ddd1..49d8488 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1303,6 +1303,9 @@ var _ = Describe("Commands", func() { bLPop := client.BLPop(time.Second, "list1") Expect(bLPop.Val()).To(BeNil()) Expect(bLPop.Err()).To(Equal(redis.Nil)) + + stats := client.Pool().Stats() + Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(1))) }) It("should BRPop", func() { diff --git a/error.go b/error.go index dce10a3..3f2a560 100644 --- a/error.go +++ b/error.go @@ -33,15 +33,17 @@ func isNetworkError(err error) bool { return ok } -func isBadConn(err error) bool { +func isBadConn(err error, allowTimeout bool) bool { if err == nil { return false } if _, ok := err.(redisError); ok { return false } - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return false + if allowTimeout { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return false + } } return true } diff --git a/main_test.go b/main_test.go index b9b3e21..d298dd2 100644 --- a/main_test.go +++ b/main_test.go @@ -98,9 +98,10 @@ func TestGinkgoSuite(t *testing.T) { //------------------------------------------------------------------------------ -func eventually(fn func() error, timeout time.Duration) (err error) { +func eventually(fn func() error, timeout time.Duration) error { done := make(chan struct{}) var exit int32 + var err error go func() { for atomic.LoadInt32(&exit) == 0 { err = fn() diff --git a/multi.go b/multi.go index 7ffc7e0..1a13d04 100644 --- a/multi.go +++ b/multi.go @@ -133,7 +133,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) { } err = c.execCmds(cn, cmds) - c.base.putConn(cn, err) + c.base.putConn(cn, err, false) return retCmds, err } diff --git a/pipeline.go b/pipeline.go index 8c800be..8caae6b 100644 --- a/pipeline.go +++ b/pipeline.go @@ -98,7 +98,7 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { resetCmds(failedCmds) } failedCmds, err = execCmds(cn, failedCmds) - pipe.client.putConn(cn, err) + pipe.client.putConn(cn, err, false) if err != nil && retErr == nil { retErr = err } diff --git a/pool.go b/pool.go index bb713bf..3725c40 100644 --- a/pool.go +++ b/pool.go @@ -18,7 +18,8 @@ var ( // PoolStats contains pool state information and accumulated stats. type PoolStats struct { Requests uint32 // number of times a connection was requested by the pool - Waits uint32 // number of times our pool had to wait for a connection + Hits uint32 // number of times free connection was found in the pool + Waits uint32 // number of times the pool had to wait for a connection Timeouts uint32 // number of times a wait timeout occurred TotalConns uint32 // the number of total connections in the pool @@ -241,6 +242,7 @@ func (p *connPool) Get() (cn *conn, isNew bool, err error) { // Fetch first non-idle connection, if available. if cn = p.First(); cn != nil { + atomic.AddUint32(&p.stats.Hits, 1) return } diff --git a/pool_test.go b/pool_test.go index 4d787a6..bc88c5f 100644 --- a/pool_test.go +++ b/pool_test.go @@ -123,6 +123,12 @@ var _ = Describe("pool", func() { pool := client.Pool() Expect(pool.Len()).To(Equal(1)) Expect(pool.FreeLen()).To(Equal(1)) + + stats := pool.Stats() + Expect(stats.Requests).To(Equal(uint32(3))) + Expect(stats.Hits).To(Equal(uint32(2))) + Expect(stats.Waits).To(Equal(uint32(0))) + Expect(stats.Timeouts).To(Equal(uint32(0))) }) It("should reuse connections", func() { @@ -135,6 +141,12 @@ var _ = Describe("pool", func() { pool := client.Pool() Expect(pool.Len()).To(Equal(1)) Expect(pool.FreeLen()).To(Equal(1)) + + stats := pool.Stats() + Expect(stats.Requests).To(Equal(uint32(100))) + Expect(stats.Hits).To(Equal(uint32(99))) + Expect(stats.Waits).To(Equal(uint32(0))) + Expect(stats.Timeouts).To(Equal(uint32(0))) }) It("should unblock client when connection is removed", func() { diff --git a/pubsub.go b/pubsub.go index bde81b5..c1fb462 100644 --- a/pubsub.go +++ b/pubsub.go @@ -245,10 +245,11 @@ func (c *PubSub) Receive() (interface{}, error) { return c.ReceiveTimeout(0) } -// ReceiveMessage returns a message or error. It automatically -// reconnects to Redis in case of network errors. +// ReceiveMessage returns a Message or error ignoring Subscription or Pong +// messages. It automatically reconnects to Redis Server and resubscribes +// to channels in case of network errors. func (c *PubSub) ReceiveMessage() (*Message, error) { - var errNum int + var errNum uint for { msgi, err := c.ReceiveTimeout(5 * time.Second) if err != nil { @@ -260,10 +261,9 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { if errNum < 3 { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { err := c.Ping("") - if err == nil { - continue + if err != nil { + Logger.Printf("PubSub.Ping failed: %s", err) } - Logger.Printf("PubSub.Ping failed: %s", err) } } else { // 3 consequent errors - connection is bad @@ -297,7 +297,7 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { } func (c *PubSub) putConn(cn *conn, err error) { - if !c.base.putConn(cn, err) { + if !c.base.putConn(cn, err, true) { c.nsub = 0 } } diff --git a/pubsub_test.go b/pubsub_test.go index 36c75c3..669c073 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -33,12 +33,6 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) defer pubsub.Close() - n, err := client.Publish("mychannel1", "hello").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(n).To(Equal(int64(1))) - - Expect(pubsub.PUnsubscribe("mychannel*")).NotTo(HaveOccurred()) - { msgi, err := pubsub.ReceiveTimeout(time.Second) Expect(err).NotTo(HaveOccurred()) @@ -48,6 +42,18 @@ var _ = Describe("PubSub", func() { Expect(subscr.Count).To(Equal(1)) } + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err.(net.Error).Timeout()).To(Equal(true)) + Expect(msgi).To(BeNil()) + } + + n, err := client.Publish("mychannel1", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + Expect(pubsub.PUnsubscribe("mychannel*")).NotTo(HaveOccurred()) + { msgi, err := pubsub.ReceiveTimeout(time.Second) Expect(err).NotTo(HaveOccurred()) @@ -66,11 +72,8 @@ var _ = Describe("PubSub", func() { Expect(subscr.Count).To(Equal(0)) } - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - Expect(err.(net.Error).Timeout()).To(Equal(true)) - Expect(msgi).NotTo(HaveOccurred()) - } + stats := client.Pool().Stats() + Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2))) }) It("should pub/sub channels", func() { @@ -128,16 +131,6 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) defer pubsub.Close() - n, err := client.Publish("mychannel", "hello").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(n).To(Equal(int64(1))) - - n, err = client.Publish("mychannel2", "hello2").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(n).To(Equal(int64(1))) - - Expect(pubsub.Unsubscribe("mychannel", "mychannel2")).NotTo(HaveOccurred()) - { msgi, err := pubsub.ReceiveTimeout(time.Second) Expect(err).NotTo(HaveOccurred()) @@ -156,6 +149,22 @@ var _ = Describe("PubSub", func() { Expect(subscr.Count).To(Equal(2)) } + { + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err.(net.Error).Timeout()).To(Equal(true)) + Expect(msgi).NotTo(HaveOccurred()) + } + + n, err := client.Publish("mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + n, err = client.Publish("mychannel2", "hello2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + Expect(pubsub.Unsubscribe("mychannel", "mychannel2")).NotTo(HaveOccurred()) + { msgi, err := pubsub.ReceiveTimeout(time.Second) Expect(err).NotTo(HaveOccurred()) @@ -190,11 +199,8 @@ var _ = Describe("PubSub", func() { Expect(subscr.Count).To(Equal(0)) } - { - msgi, err := pubsub.ReceiveTimeout(time.Second) - Expect(err.(net.Error).Timeout()).To(Equal(true)) - Expect(msgi).NotTo(HaveOccurred()) - } + stats := client.Pool().Stats() + Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2))) }) It("should ping/pong", func() { @@ -277,6 +283,9 @@ var _ = Describe("PubSub", func() { Expect(msg.Payload).To(Equal("hello")) Eventually(done).Should(Receive()) + + stats := client.Pool().Stats() + Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2))) }) expectReceiveMessageOnError := func(pubsub *redis.PubSub) { @@ -305,6 +314,9 @@ var _ = Describe("PubSub", func() { Expect(msg.Payload).To(Equal("hello")) Eventually(done).Should(Receive()) + + stats := client.Pool().Stats() + Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2))) } It("Subscribe should reconnect on ReceiveMessage error", func() { diff --git a/redis.go b/redis.go index 488bfb9..5558ad1 100644 --- a/redis.go +++ b/redis.go @@ -13,6 +13,8 @@ var Logger = log.New(os.Stderr, "redis: ", log.LstdFlags) type baseClient struct { connPool pool opt *Options + + onClose func() error // hook called when client is closed } func (c *baseClient) String() string { @@ -23,8 +25,8 @@ func (c *baseClient) conn() (*conn, bool, error) { return c.connPool.Get() } -func (c *baseClient) putConn(cn *conn, err error) bool { - if isBadConn(err) { +func (c *baseClient) putConn(cn *conn, err error, allowTimeout bool) bool { + if isBadConn(err, allowTimeout) { err = c.connPool.Remove(cn, err) if err != nil { Logger.Printf("pool.Remove failed: %s", err) @@ -51,20 +53,16 @@ func (c *baseClient) process(cmd Cmder) { return } - if timeout := cmd.writeTimeout(); timeout != nil { - cn.WriteTimeout = *timeout - } else { - cn.WriteTimeout = c.opt.WriteTimeout - } - - if timeout := cmd.readTimeout(); timeout != nil { - cn.ReadTimeout = *timeout + readTimeout := cmd.readTimeout() + if readTimeout != nil { + cn.ReadTimeout = *readTimeout } else { cn.ReadTimeout = c.opt.ReadTimeout } + cn.WriteTimeout = c.opt.WriteTimeout if err := cn.writeCmds(cmd); err != nil { - c.putConn(cn, err) + c.putConn(cn, err, false) cmd.setErr(err) if shouldRetry(err) { continue @@ -73,7 +71,7 @@ func (c *baseClient) process(cmd Cmder) { } err = cmd.readReply(cn) - c.putConn(cn, err) + c.putConn(cn, err, readTimeout != nil) if shouldRetry(err) { continue } @@ -87,7 +85,16 @@ func (c *baseClient) process(cmd Cmder) { // It is rare to Close a Client, as the Client is meant to be // long-lived and shared between many goroutines. func (c *baseClient) Close() error { - return c.connPool.Close() + var retErr error + if c.onClose != nil { + if err := c.onClose(); err != nil && retErr == nil { + retErr = err + } + } + if err := c.connPool.Close(); err != nil && retErr == nil { + retErr = err + } + return retErr } //------------------------------------------------------------------------------ @@ -190,8 +197,10 @@ type Client struct { func newClient(opt *Options, pool pool) *Client { base := baseClient{opt: opt, connPool: pool} return &Client{ - baseClient: base, - commandable: commandable{process: base.process}, + baseClient: base, + commandable: commandable{ + process: base.process, + }, } } diff --git a/ring.go b/ring.go index 1d9b902..f1ae8ad 100644 --- a/ring.go +++ b/ring.go @@ -326,7 +326,7 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { resetCmds(cmds) } failedCmds, err := execCmds(cn, cmds) - client.putConn(cn, err) + client.putConn(cn, err, false) if err != nil && retErr == nil { retErr = err } diff --git a/sentinel.go b/sentinel.go index 175c57e..db5db64 100644 --- a/sentinel.go +++ b/sentinel.go @@ -65,18 +65,31 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt: opt, } - return newClient(opt, failover.Pool()) + base := baseClient{ + opt: opt, + connPool: failover.Pool(), + + onClose: func() error { + return failover.Close() + }, + } + return &Client{ + baseClient: base, + commandable: commandable{ + process: base.process, + }, + } } //------------------------------------------------------------------------------ type sentinelClient struct { + baseClient commandable - *baseClient } func newSentinel(opt *Options) *sentinelClient { - base := &baseClient{ + base := baseClient{ opt: opt, connPool: newConnPool(opt), } @@ -116,8 +129,12 @@ type sentinelFailover struct { pool pool poolOnce sync.Once - lock sync.RWMutex - _sentinel *sentinelClient + mu sync.RWMutex + sentinel *sentinelClient +} + +func (d *sentinelFailover) Close() error { + return d.resetSentinel() } func (d *sentinelFailover) dial() (net.Conn, error) { @@ -137,15 +154,15 @@ func (d *sentinelFailover) Pool() pool { } func (d *sentinelFailover) MasterAddr() (string, error) { - defer d.lock.Unlock() - d.lock.Lock() + defer d.mu.Unlock() + d.mu.Lock() // Try last working sentinel. - if d._sentinel != nil { - addr, err := d._sentinel.GetMasterAddrByName(d.masterName).Result() + if d.sentinel != nil { + addr, err := d.sentinel.GetMasterAddrByName(d.masterName).Result() if err != nil { Logger.Printf("sentinel: GetMasterAddrByName %q failed: %s", d.masterName, err) - d.resetSentinel() + d._resetSentinel() } else { addr := net.JoinHostPort(addr[0], addr[1]) Logger.Printf("sentinel: %q addr is %s", d.masterName, addr) @@ -186,10 +203,26 @@ func (d *sentinelFailover) MasterAddr() (string, error) { func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) { d.discoverSentinels(sentinel) - d._sentinel = sentinel + d.sentinel = sentinel go d.listen() } +func (d *sentinelFailover) resetSentinel() error { + d.mu.Lock() + err := d._resetSentinel() + d.mu.Unlock() + return err +} + +func (d *sentinelFailover) _resetSentinel() error { + var err error + if d.sentinel != nil { + err = d.sentinel.Close() + d.sentinel = nil + } + return err +} + func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { sentinels, err := sentinel.Sentinels(d.masterName).Result() if err != nil { @@ -247,55 +280,41 @@ func (d *sentinelFailover) listen() { var pubsub *PubSub for { if pubsub == nil { - pubsub = d._sentinel.PubSub() + pubsub = d.sentinel.PubSub() if err := pubsub.Subscribe("+switch-master"); err != nil { Logger.Printf("sentinel: Subscribe failed: %s", err) - d.lock.Lock() d.resetSentinel() - d.lock.Unlock() return } } - msg, err := pubsub.Receive() + msg, err := pubsub.ReceiveMessage() if err != nil { - Logger.Printf("sentinel: Receive failed: %s", err) + Logger.Printf("sentinel: ReceiveMessage failed: %s", err) pubsub.Close() + d.resetSentinel() return } - switch msg := msg.(type) { - case *Message: - switch msg.Channel { - case "+switch-master": - parts := strings.Split(msg.Payload, " ") - if parts[0] != d.masterName { - Logger.Printf("sentinel: ignore new %s addr", parts[0]) - continue - } - addr := net.JoinHostPort(parts[3], parts[4]) - Logger.Printf( - "sentinel: new %q addr is %s", - d.masterName, addr, - ) - - d.closeOldConns(addr) - default: - Logger.Printf("sentinel: unsupported message: %s", msg) + switch msg.Channel { + case "+switch-master": + parts := strings.Split(msg.Payload, " ") + if parts[0] != d.masterName { + Logger.Printf("sentinel: ignore new %s addr", parts[0]) + continue } - case *Subscription: - // Ignore. - default: - Logger.Printf("sentinel: unsupported message: %s", msg) + + addr := net.JoinHostPort(parts[3], parts[4]) + Logger.Printf( + "sentinel: new %q addr is %s", + d.masterName, addr, + ) + + d.closeOldConns(addr) } } } -func (d *sentinelFailover) resetSentinel() { - d._sentinel.Close() - d._sentinel = nil -} - func contains(slice []string, str string) bool { for _, s := range slice { if s == str {