diff --git a/example_test.go b/example_test.go index e9484104..b54e043a 100644 --- a/example_test.go +++ b/example_test.go @@ -2,7 +2,6 @@ package redis_test import ( "fmt" - "net" "strconv" "sync" "time" @@ -179,14 +178,14 @@ func ExamplePubSub() { panic(err) } - for { + for i := 0; i < 4; i++ { msgi, err := pubsub.ReceiveTimeout(100 * time.Millisecond) if err != nil { - if neterr, ok := err.(net.Error); ok && neterr.Timeout() { - // There are no more messages to process. Stop. - break + err := pubsub.Ping("") + if err != nil { + panic(err) } - panic(err) + continue } switch msg := msgi.(type) { @@ -194,6 +193,8 @@ func ExamplePubSub() { fmt.Println(msg.Kind, msg.Channel) case *redis.Message: fmt.Println(msg.Channel, msg.Payload) + case *redis.Pong: + fmt.Println(msg) default: panic(fmt.Sprintf("unknown message: %#v", msgi)) } @@ -201,6 +202,7 @@ func ExamplePubSub() { // Output: subscribe mychannel // mychannel hello + // Pong } func ExampleScript() { diff --git a/pubsub.go b/pubsub.go index 26fa8528..e143448b 100644 --- a/pubsub.go +++ b/pubsub.go @@ -26,6 +26,20 @@ func (c *Client) Publish(channel, message string) *IntCmd { return req } +func (c *PubSub) Ping(payload string) error { + cn, err := c.conn() + if err != nil { + return err + } + + args := []interface{}{"PING"} + if payload != "" { + args = append(args, payload) + } + cmd := NewCmd(args...) + return cn.writeCmds(cmd) +} + // Message received as result of a PUBLISH command issued by another client. type Message struct { Channel string @@ -48,6 +62,18 @@ func (m *PMessage) String() string { return fmt.Sprintf("PMessage<%s: %s>", m.Channel, m.Payload) } +// Pong received as result of a PING command issued by another client. +type Pong struct { + Payload string +} + +func (p *Pong) String() string { + if p.Payload != "" { + return fmt.Sprintf("Pong<%s>", p.Payload) + } + return "Pong" +} + // Message received after a successful subscription to channel. type Subscription struct { // Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe". @@ -66,22 +92,8 @@ func (c *PubSub) Receive() (interface{}, error) { return c.ReceiveTimeout(0) } -func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { - cn, err := c.conn() - if err != nil { - return nil, err - } - cn.ReadTimeout = timeout - - cmd := NewSliceCmd() - if err := cmd.parseReply(cn.rd); err != nil { - return nil, err - } - - reply := cmd.Val() - - kind := reply[0].(string) - switch kind { +func newMessage(reply []interface{}) (interface{}, error) { + switch kind := reply[0].(string); kind { case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": return &Subscription{ Kind: kind, @@ -99,9 +111,27 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { Channel: reply[2].(string), Payload: reply[3].(string), }, nil + case "pong": + return &Pong{ + Payload: reply[1].(string), + }, nil + default: + return nil, fmt.Errorf("redis: unsupported pubsub notification: %q", kind) } +} - return nil, fmt.Errorf("redis: unsupported pubsub notification: %q", kind) +func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { + cn, err := c.conn() + if err != nil { + return nil, err + } + cn.ReadTimeout = timeout + + cmd := NewSliceCmd() + if err := cmd.parseReply(cn.rd); err != nil { + return nil, err + } + return newMessage(cmd.Val()) } func (c *PubSub) subscribe(cmd string, channels ...string) error { diff --git a/pubsub_test.go b/pubsub_test.go index 82c0ca49..bf59a2cd 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -12,24 +12,22 @@ import ( var _ = Describe("PubSub", func() { var client *redis.Client + var pubsub *redis.PubSub BeforeEach(func() { client = redis.NewClient(&redis.Options{ Addr: redisAddr, }) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + pubsub = client.PubSub() }) AfterEach(func() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(pubsub.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred()) }) It("should support pattern matching", func() { - pubsub := client.PubSub() - defer func() { - Expect(pubsub.Close()).NotTo(HaveOccurred()) - }() - Expect(pubsub.PSubscribe("mychannel*")).NotTo(HaveOccurred()) pub := client.Publish("mychannel1", "hello") @@ -77,8 +75,6 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) Expect(channels).To(BeEmpty()) - pubsub := client.PubSub() - defer pubsub.Close() Expect(pubsub.Subscribe("mychannel", "mychannel2")).NotTo(HaveOccurred()) channels, err = client.PubSubChannels("mychannel*").Result() @@ -95,8 +91,6 @@ var _ = Describe("PubSub", func() { }) It("should return the numbers of subscribers", func() { - pubsub := client.PubSub() - defer pubsub.Close() Expect(pubsub.Subscribe("mychannel", "mychannel2")).NotTo(HaveOccurred()) channels, err := client.PubSubNumSub("mychannel", "mychannel2", "mychannel3").Result() @@ -113,8 +107,6 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) Expect(num).To(Equal(int64(0))) - pubsub := client.PubSub() - defer pubsub.Close() Expect(pubsub.PSubscribe("*")).NotTo(HaveOccurred()) num, err = client.PubSubNumPat().Result() @@ -123,11 +115,6 @@ var _ = Describe("PubSub", func() { }) It("should pub/sub", func() { - pubsub := client.PubSub() - defer func() { - Expect(pubsub.Close()).NotTo(HaveOccurred()) - }() - Expect(pubsub.Subscribe("mychannel", "mychannel2")).NotTo(HaveOccurred()) pub := client.Publish("mychannel", "hello") @@ -199,4 +186,36 @@ var _ = Describe("PubSub", func() { } }) + It("should ping/pong", func() { + err := pubsub.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + + _, err = pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + + err = pubsub.Ping("") + Expect(err).NotTo(HaveOccurred()) + + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + pong := msgi.(*redis.Pong) + Expect(pong.Payload).To(Equal("")) + }) + + It("should ping/pong with payload", func() { + err := pubsub.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + + _, err = pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + + err = pubsub.Ping("hello") + Expect(err).NotTo(HaveOccurred()) + + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + pong := msgi.(*redis.Pong) + Expect(pong.Payload).To(Equal("hello")) + }) + })