diff --git a/error.go b/error.go index 1365ca19..dce10a37 100644 --- a/error.go +++ b/error.go @@ -33,14 +33,14 @@ func isNetworkError(err error) bool { return ok } -func isBadConn(cn *conn, ei error) bool { - if cn.rd.Buffered() > 0 { - return true - } - if ei == nil { +func isBadConn(err error) bool { + if err == nil { return false } - if _, ok := ei.(redisError); ok { + if _, ok := err.(redisError); ok { + return false + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { return false } return true diff --git a/export_test.go b/export_test.go index e7b4b056..4a6de2c6 100644 --- a/export_test.go +++ b/export_test.go @@ -6,6 +6,10 @@ func (c *baseClient) Pool() pool { return c.connPool } +func (c *PubSub) Pool() pool { + return c.base.connPool +} + var NewConnDialer = newConnDialer func (cn *conn) SetNetConn(netcn net.Conn) { diff --git a/main_test.go b/main_test.go index 471c8ee9..b9b3e218 100644 --- a/main_test.go +++ b/main_test.go @@ -7,7 +7,6 @@ import ( "os/exec" "path/filepath" "sync/atomic" - "syscall" "testing" "time" @@ -243,10 +242,6 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) { //------------------------------------------------------------------------------ -var ( - errTimeout = syscall.ETIMEDOUT -) - type badConnError string func (e badConnError) Error() string { return string(e) } diff --git a/multi.go b/multi.go index 236bd30c..7ffc7e05 100644 --- a/multi.go +++ b/multi.go @@ -45,18 +45,6 @@ func (c *Client) Multi() *Multi { return multi } -func (c *Multi) putConn(cn *conn, err error) { - if isBadConn(cn, err) { - // Close current connection. - c.base.connPool.(*stickyConnPool).Reset(err) - } else { - err := c.base.connPool.Put(cn) - if err != nil { - Logger.Printf("pool.Put failed: %s", err) - } - } -} - func (c *Multi) process(cmd Cmder) { if c.cmds == nil { c.base.process(cmd) @@ -145,7 +133,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) { } err = c.execCmds(cn, cmds) - c.putConn(cn, err) + c.base.putConn(cn, err) return retCmds, err } diff --git a/multi_test.go b/multi_test.go index 1e6f3603..459d0a62 100644 --- a/multi_test.go +++ b/multi_test.go @@ -166,4 +166,31 @@ var _ = Describe("Multi", func() { }) Expect(err).NotTo(HaveOccurred()) }) + + It("should recover from bad connection when there are no commands", func() { + // Put bad connection in the pool. + cn, _, err := client.Pool().Get() + Expect(err).NotTo(HaveOccurred()) + + cn.SetNetConn(&badConn{}) + err = client.Pool().Put(cn) + Expect(err).NotTo(HaveOccurred()) + + { + tx, err := client.Watch("key") + Expect(err).To(MatchError("bad connection")) + Expect(tx).To(BeNil()) + } + + { + tx, err := client.Watch("key") + Expect(err).NotTo(HaveOccurred()) + + err = tx.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = tx.Close() + Expect(err).NotTo(HaveOccurred()) + } + }) }) diff --git a/pool.go b/pool.go index 87eb5ed6..e2f9f2b3 100644 --- a/pool.go +++ b/pool.go @@ -246,13 +246,14 @@ func (p *connPool) Get() (cn *conn, isNew bool, err error) { // Try to create a new one. if p.conns.Reserve() { + isNew = true + cn, err = p.new() if err != nil { p.conns.Remove(nil) return } p.conns.Add(cn) - isNew = true return } @@ -481,13 +482,13 @@ func (p *stickyConnPool) Put(cn *conn) error { return nil } -func (p *stickyConnPool) remove(reason error) (err error) { - err = p.pool.Remove(p.cn, reason) +func (p *stickyConnPool) remove(reason error) error { + err := p.pool.Remove(p.cn, reason) p.cn = nil return err } -func (p *stickyConnPool) Remove(cn *conn, _ error) error { +func (p *stickyConnPool) Remove(cn *conn, reason error) error { defer p.mx.Unlock() p.mx.Lock() if p.closed { @@ -499,7 +500,7 @@ func (p *stickyConnPool) Remove(cn *conn, _ error) error { if cn != nil && p.cn != cn { panic("p.cn != cn") } - return nil + return p.remove(reason) } func (p *stickyConnPool) Len() int { @@ -522,15 +523,6 @@ func (p *stickyConnPool) FreeLen() int { func (p *stickyConnPool) Stats() *PoolStats { return nil } -func (p *stickyConnPool) Reset(reason error) (err error) { - p.mx.Lock() - if p.cn != nil { - err = p.remove(reason) - } - p.mx.Unlock() - return err -} - func (p *stickyConnPool) Close() error { defer p.mx.Unlock() p.mx.Lock() diff --git a/pubsub.go b/pubsub.go index 5c2d2e8b..bde81b5e 100644 --- a/pubsub.go +++ b/pubsub.go @@ -17,16 +17,18 @@ func (c *Client) Publish(channel, message string) *IntCmd { // http://redis.io/topics/pubsub. It's NOT safe for concurrent use by // multiple goroutines. type PubSub struct { - *baseClient + base *baseClient channels []string patterns []string + + nsub int // number of active subscriptions } // Deprecated. Use Subscribe/PSubscribe instead. func (c *Client) PubSub() *PubSub { return &PubSub{ - baseClient: &baseClient{ + base: &baseClient{ opt: c.opt, connPool: newStickyConnPool(c.connPool, false), }, @@ -46,7 +48,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) { } func (c *PubSub) subscribe(cmd string, channels ...string) error { - cn, _, err := c.conn() + cn, _, err := c.base.conn() if err != nil { return err } @@ -65,6 +67,7 @@ func (c *PubSub) Subscribe(channels ...string) error { err := c.subscribe("SUBSCRIBE", channels...) if err == nil { c.channels = append(c.channels, channels...) + c.nsub += len(channels) } return err } @@ -74,6 +77,7 @@ func (c *PubSub) PSubscribe(patterns ...string) error { err := c.subscribe("PSUBSCRIBE", patterns...) if err == nil { c.patterns = append(c.patterns, patterns...) + c.nsub += len(patterns) } return err } @@ -113,8 +117,12 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { return err } +func (c *PubSub) Close() error { + return c.base.Close() +} + func (c *PubSub) Ping(payload string) error { - cn, _, err := c.conn() + cn, _, err := c.base.conn() if err != nil { return err } @@ -178,7 +186,7 @@ func (p *Pong) String() string { return "Pong" } -func newMessage(reply []interface{}) (interface{}, error) { +func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) { switch kind := reply[0].(string); kind { case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": return &Subscription{ @@ -210,7 +218,11 @@ func newMessage(reply []interface{}) (interface{}, error) { // is not received in time. This is low-level API and most clients // should use ReceiveMessage. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { - cn, _, err := c.conn() + if c.nsub == 0 { + c.resubscribe() + } + + cn, _, err := c.base.conn() if err != nil { return nil, err } @@ -222,7 +234,8 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { if err != nil { return nil, err } - return newMessage(cmd.Val()) + + return c.newMessage(cmd.Val()) } // Receive returns a message as a Subscription, Message, PMessage, @@ -232,22 +245,6 @@ func (c *PubSub) Receive() (interface{}, error) { return c.ReceiveTimeout(0) } -func (c *PubSub) reconnect(reason error) { - // Close current connection. - c.connPool.(*stickyConnPool).Reset(reason) - - if len(c.channels) > 0 { - if err := c.Subscribe(c.channels...); err != nil { - Logger.Printf("Subscribe failed: %s", err) - } - } - if len(c.patterns) > 0 { - if err := c.PSubscribe(c.patterns...); err != nil { - Logger.Printf("PSubscribe failed: %s", err) - } - } -} - // ReceiveMessage returns a message or error. It automatically // reconnects to Redis in case of network errors. func (c *PubSub) ReceiveMessage() (*Message, error) { @@ -259,10 +256,8 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { return nil, err } - goodConn := errNum == 0 errNum++ - - if goodConn { + if errNum < 3 { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { err := c.Ping("") if err == nil { @@ -270,16 +265,16 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { } Logger.Printf("PubSub.Ping failed: %s", err) } - } - - if errNum > 2 { + } else { + // 3 consequent errors - connection is bad + // and/or Redis Server is down. + // Sleep to not exceed max number of open connections. time.Sleep(time.Second) } - c.reconnect(err) continue } - // Reset error number. + // Reset error number, because we received a message. errNum = 0 switch msg := msgi.(type) { @@ -300,3 +295,22 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { } } } + +func (c *PubSub) putConn(cn *conn, err error) { + if !c.base.putConn(cn, err) { + c.nsub = 0 + } +} + +func (c *PubSub) resubscribe() { + if len(c.channels) > 0 { + if err := c.Subscribe(c.channels...); err != nil { + Logger.Printf("Subscribe failed: %s", err) + } + } + if len(c.patterns) > 0 { + if err := c.PSubscribe(c.patterns...); err != nil { + Logger.Printf("PSubscribe failed: %s", err) + } + } +} diff --git a/pubsub_test.go b/pubsub_test.go index bf940d47..36c75c38 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "io" "net" "sync" "time" @@ -230,18 +231,41 @@ var _ = Describe("PubSub", func() { Expect(pong.Payload).To(Equal("hello")) }) - It("should ReceiveMessage", func() { + It("should multi-ReceiveMessage", func() { pubsub, err := client.Subscribe("mychannel") Expect(err).NotTo(HaveOccurred()) defer pubsub.Close() - var wg sync.WaitGroup - wg.Add(1) + err = client.Publish("mychannel", "hello").Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.Publish("mychannel", "world").Err() + Expect(err).NotTo(HaveOccurred()) + + msg, err := pubsub.ReceiveMessage() + Expect(err).NotTo(HaveOccurred()) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + + msg, err = pubsub.ReceiveMessage() + Expect(err).NotTo(HaveOccurred()) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("world")) + }) + + It("should ReceiveMessage after timeout", func() { + pubsub, err := client.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + done := make(chan bool, 1) go func() { defer GinkgoRecover() - defer wg.Done() + defer func() { + done <- true + }() - time.Sleep(readTimeout + 100*time.Millisecond) + time.Sleep(5*time.Second + 100*time.Millisecond) n, err := client.Publish("mychannel", "hello").Result() Expect(err).NotTo(HaveOccurred()) Expect(n).To(Equal(int64(1))) @@ -252,22 +276,23 @@ var _ = Describe("PubSub", func() { Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal("hello")) - wg.Wait() + Eventually(done).Should(Receive()) }) - expectReceiveMessage := func(pubsub *redis.PubSub) { + expectReceiveMessageOnError := func(pubsub *redis.PubSub) { cn1, _, err := pubsub.Pool().Get() Expect(err).NotTo(HaveOccurred()) cn1.SetNetConn(&badConn{ - readErr: errTimeout, - writeErr: errTimeout, + readErr: io.EOF, + writeErr: io.EOF, }) - var wg sync.WaitGroup - wg.Add(1) + done := make(chan bool, 1) go func() { defer GinkgoRecover() - defer wg.Done() + defer func() { + done <- true + }() time.Sleep(100 * time.Millisecond) err := client.Publish("mychannel", "hello").Err() @@ -279,7 +304,7 @@ var _ = Describe("PubSub", func() { Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal("hello")) - wg.Wait() + Eventually(done).Should(Receive()) } It("Subscribe should reconnect on ReceiveMessage error", func() { @@ -287,7 +312,7 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) defer pubsub.Close() - expectReceiveMessage(pubsub) + expectReceiveMessageOnError(pubsub) }) It("PSubscribe should reconnect on ReceiveMessage error", func() { @@ -295,7 +320,7 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) defer pubsub.Close() - expectReceiveMessage(pubsub) + expectReceiveMessageOnError(pubsub) }) It("should return on Close", func() { diff --git a/redis.go b/redis.go index 8372963a..f37c97b8 100644 --- a/redis.go +++ b/redis.go @@ -23,15 +23,20 @@ func (c *baseClient) conn() (*conn, bool, error) { return c.connPool.Get() } -func (c *baseClient) putConn(cn *conn, err error) { - if isBadConn(cn, err) { +func (c *baseClient) putConn(cn *conn, err error) bool { + if isBadConn(err) { err = c.connPool.Remove(cn, err) - } else { - err = c.connPool.Put(cn) + if err != nil { + log.Printf("pool.Remove failed: %s", err) + } + return false } + + err = c.connPool.Put(cn) if err != nil { - Logger.Printf("pool.Put failed: %s", err) + log.Printf("pool.Put failed: %s", err) } + return true } func (c *baseClient) process(cmd Cmder) { diff --git a/sentinel.go b/sentinel.go index bb950646..175c57e8 100644 --- a/sentinel.go +++ b/sentinel.go @@ -88,7 +88,7 @@ func newSentinel(opt *Options) *sentinelClient { func (c *sentinelClient) PubSub() *PubSub { return &PubSub{ - baseClient: &baseClient{ + base: &baseClient{ opt: c.opt, connPool: newStickyConnPool(c.connPool, false), },