diff --git a/example_test.go b/example_test.go index 58b7dfe..cb4b8f6 100644 --- a/example_test.go +++ b/example_test.go @@ -219,14 +219,31 @@ func ExamplePubSub() { panic(err) } - for i := 0; i < 4; i++ { + msg, err := pubsub.ReceiveMessage() + if err != nil { + panic(err) + } + + fmt.Println(msg.Channel, msg.Payload) + // Output: mychannel hello +} + +func ExamplePubSub_Receive() { + pubsub, err := client.Subscribe("mychannel") + if err != nil { + panic(err) + } + defer pubsub.Close() + + err = client.Publish("mychannel", "hello").Err() + if err != nil { + panic(err) + } + + for i := 0; i < 2; i++ { msgi, err := pubsub.ReceiveTimeout(100 * time.Millisecond) if err != nil { - err := pubsub.Ping("") - if err != nil { - panic(err) - } - continue + panic(err) } switch msg := msgi.(type) { @@ -234,8 +251,6 @@ func ExamplePubSub() { fmt.Println(msg.Kind, msg.Channel) case *redis.Message: fmt.Println(msg.Channel, msg.Payload) - case *redis.Pong: - fmt.Println(msg) default: panic(fmt.Sprintf("unknown message: %#v", msgi)) } @@ -243,7 +258,6 @@ func ExamplePubSub() { // Output: subscribe mychannel // mychannel hello - // Pong } func ExampleScript() { diff --git a/main_test.go b/main_test.go index c4b5a59..d2f8d2a 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" "sync/atomic" + "syscall" "testing" "time" @@ -231,20 +232,33 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) { //------------------------------------------------------------------------------ -type badNetConn struct { +var errTimeout = syscall.ETIMEDOUT + +type badConn struct { net.TCPConn + + readDelay, writeDelay time.Duration + readErr, writeErr error } -var _ net.Conn = &badNetConn{} +var _ net.Conn = &badConn{} -func newBadNetConn() net.Conn { - return &badNetConn{} +func (cn *badConn) Read([]byte) (int, error) { + if cn.readDelay != 0 { + time.Sleep(cn.readDelay) + } + if cn.readErr != nil { + return 0, cn.readErr + } + return 0, net.UnknownNetworkError("badConn") } -func (badNetConn) Read([]byte) (int, error) { - return 0, net.UnknownNetworkError("badNetConn") -} - -func (badNetConn) Write([]byte) (int, error) { - return 0, net.UnknownNetworkError("badNetConn") +func (cn *badConn) Write([]byte) (int, error) { + if cn.writeDelay != 0 { + time.Sleep(cn.writeDelay) + } + if cn.writeErr != nil { + return 0, cn.writeErr + } + return 0, net.UnknownNetworkError("badConn") } diff --git a/pool.go b/pool.go index f52eb6f..bd494d8 100644 --- a/pool.go +++ b/pool.go @@ -396,8 +396,8 @@ func (p *singleConnPool) Remove(cn *conn) error { if p.cn == nil { panic("p.cn == nil") } - if p.cn != cn { - panic("p.cn != cn") + if cn != nil && cn != p.cn { + panic("cn != p.cn") } if p.closed { return errClosed diff --git a/pubsub.go b/pubsub.go index be36caa..b85e475 100644 --- a/pubsub.go +++ b/pubsub.go @@ -2,6 +2,8 @@ package redis import ( "fmt" + "log" + "net" "time" ) @@ -16,6 +18,9 @@ func (c *Client) Publish(channel, message string) *IntCmd { // http://redis.io/topics/pubsub. type PubSub struct { *baseClient + + channels []string + patterns []string } // Deprecated. Use Subscribe/PSubscribe instead. @@ -40,6 +45,71 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) { return pubsub, pubsub.PSubscribe(channels...) } +func (c *PubSub) subscribe(cmd string, channels ...string) error { + cn, err := c.conn() + if err != nil { + return err + } + + args := make([]interface{}, 1+len(channels)) + args[0] = cmd + for i, channel := range channels { + args[1+i] = channel + } + req := NewSliceCmd(args...) + return cn.writeCmds(req) +} + +// Subscribes the client to the specified channels. +func (c *PubSub) Subscribe(channels ...string) error { + err := c.subscribe("SUBSCRIBE", channels...) + if err == nil { + c.channels = append(c.channels, channels...) + } + return err +} + +// Subscribes the client to the given patterns. +func (c *PubSub) PSubscribe(patterns ...string) error { + err := c.subscribe("PSUBSCRIBE", patterns...) + if err == nil { + c.channels = append(c.channels, patterns...) + } + return err +} + +func remove(ss []string, es ...string) []string { + for _, e := range es { + for i, s := range ss { + if s == e { + ss = append(ss[:i], ss[i+1:]...) + break + } + } + } + return ss +} + +// Unsubscribes the client from the given channels, or from all of +// them if none is given. +func (c *PubSub) Unsubscribe(channels ...string) error { + err := c.subscribe("UNSUBSCRIBE", channels...) + if err == nil { + c.channels = remove(c.channels, channels...) + } + return err +} + +// Unsubscribes the client from the given patterns, or from all of +// them if none is given. +func (c *PubSub) PUnsubscribe(patterns ...string) error { + err := c.subscribe("PUNSUBSCRIBE", patterns...) + if err == nil { + c.patterns = remove(c.patterns, patterns...) + } + return err +} + func (c *PubSub) Ping(payload string) error { cn, err := c.conn() if err != nil { @@ -71,6 +141,7 @@ func (m *Subscription) String() string { // Message received as result of a PUBLISH command issued by another client. type Message struct { Channel string + Pattern string Payload string } @@ -78,6 +149,8 @@ func (m *Message) String() string { return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload) } +// TODO: remove PMessage if favor of Message + // Message matching a pattern-matching subscription received as result // of a PUBLISH command issued by another client. type PMessage struct { @@ -102,12 +175,6 @@ func (p *Pong) String() string { return "Pong" } -// Returns a message as a Subscription, Message, PMessage, Pong or -// error. See PubSub example for details. -func (c *PubSub) Receive() (interface{}, error) { - return c.ReceiveTimeout(0) -} - func newMessage(reply []interface{}) (interface{}, error) { switch kind := reply[0].(string); kind { case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": @@ -137,7 +204,8 @@ func newMessage(reply []interface{}) (interface{}, error) { } // ReceiveTimeout acts like Receive but returns an error if message -// is not received in time. +// is not received in time. This is low-level API and most clients +// should use ReceiveMessage. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { cn, err := c.conn() if err != nil { @@ -152,39 +220,74 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { return newMessage(cmd.Val()) } -func (c *PubSub) subscribe(cmd string, channels ...string) error { - cn, err := c.conn() - if err != nil { - return err +// Receive returns a message as a Subscription, Message, PMessage, +// Pong or error. See PubSub example for details. This is low-level +// API and most clients should use ReceiveMessage. +func (c *PubSub) Receive() (interface{}, error) { + return c.ReceiveTimeout(0) +} + +func (c *PubSub) reconnect() { + c.connPool.Remove(nil) // close current connection + if len(c.channels) > 0 { + if err := c.Subscribe(c.channels...); err != nil { + log.Printf("redis: Subscribe failed: %s", err) + } } - - args := make([]interface{}, 1+len(channels)) - args[0] = cmd - for i, channel := range channels { - args[1+i] = channel + if len(c.patterns) > 0 { + if err := c.PSubscribe(c.patterns...); err != nil { + log.Printf("redis: Subscribe failed: %s", err) + } } - req := NewSliceCmd(args...) - return cn.writeCmds(req) } -// Subscribes the client to the specified channels. -func (c *PubSub) Subscribe(channels ...string) error { - return c.subscribe("SUBSCRIBE", channels...) -} +// ReceiveMessage returns a message or error. It automatically +// reconnects to Redis in case of network errors. +func (c *PubSub) ReceiveMessage() (*Message, error) { + var badConn bool + for { + msgi, err := c.ReceiveTimeout(5 * time.Second) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + if badConn { + c.reconnect() + badConn = false + continue + } -// Subscribes the client to the given patterns. -func (c *PubSub) PSubscribe(patterns ...string) error { - return c.subscribe("PSUBSCRIBE", patterns...) -} + err := c.Ping("") + if err != nil { + c.reconnect() + } else { + badConn = true + } + continue + } -// Unsubscribes the client from the given channels, or from all of -// them if none is given. -func (c *PubSub) Unsubscribe(channels ...string) error { - return c.subscribe("UNSUBSCRIBE", channels...) -} + if isNetworkError(err) { + c.reconnect() + continue + } -// Unsubscribes the client from the given patterns, or from all of -// them if none is given. -func (c *PubSub) PUnsubscribe(patterns ...string) error { - return c.subscribe("PUNSUBSCRIBE", patterns...) + return nil, err + } + + switch msg := msgi.(type) { + case *Subscription: + // Ignore. + case *Pong: + badConn = false + // Ignore. + case *Message: + return msg, nil + case *PMessage: + return &Message{ + Channel: msg.Channel, + Pattern: msg.Pattern, + Payload: msg.Payload, + }, nil + default: + return nil, fmt.Errorf("redis: unknown message: %T", msgi) + } + } } diff --git a/pubsub_test.go b/pubsub_test.go index ac1d629..5a7b0da 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -12,10 +12,12 @@ import ( var _ = Describe("PubSub", func() { var client *redis.Client + readTimeout := 3 * time.Second BeforeEach(func() { client = redis.NewClient(&redis.Options{ - Addr: redisAddr, + Addr: redisAddr, + ReadTimeout: readTimeout, }) Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) @@ -227,4 +229,51 @@ var _ = Describe("PubSub", func() { Expect(pong.Payload).To(Equal("hello")) }) + It("should ReceiveMessage", func() { + pubsub, err := client.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + go func() { + defer GinkgoRecover() + + time.Sleep(readTimeout + 100*time.Millisecond) + n, err := client.Publish("mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + }() + + msg, err := pubsub.ReceiveMessage() + Expect(err).NotTo(HaveOccurred()) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + }) + + It("should reconnect on ReceiveMessage error", func() { + pubsub, err := client.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + cn, err := pubsub.Pool().Get() + Expect(err).NotTo(HaveOccurred()) + cn.SetNetConn(&badConn{ + readErr: errTimeout, + writeErr: errTimeout, + }) + + go func() { + defer GinkgoRecover() + + time.Sleep(100 * time.Millisecond) + n, err := client.Publish("mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(2))) + }() + + msg, err := pubsub.ReceiveMessage() + Expect(err).NotTo(HaveOccurred()) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + }) + }) diff --git a/redis_test.go b/redis_test.go index b1a2547..acc8ca1 100644 --- a/redis_test.go +++ b/redis_test.go @@ -159,7 +159,8 @@ var _ = Describe("Client", func() { // Put bad connection in the pool. cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.SetNetConn(newBadNetConn()) + + cn.SetNetConn(&badConn{}) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) err = client.Ping().Err()