Merge pull request #164 from go-redis/fix/pubsub-receive-message

Add PubSub.ReceiveMessage.
This commit is contained in:
Vladimir Mihailenco 2015-09-10 13:07:36 +03:00
commit 02154c3b3a
6 changed files with 239 additions and 58 deletions

View File

@ -219,14 +219,31 @@ func ExamplePubSub() {
panic(err) panic(err)
} }
for i := 0; i < 4; i++ { msg, err := pubsub.ReceiveMessage()
if err != nil {
panic(err)
}
fmt.Println(msg.Channel, msg.Payload)
// Output: mychannel hello
}
func ExamplePubSub_Receive() {
pubsub, err := client.Subscribe("mychannel")
if err != nil {
panic(err)
}
defer pubsub.Close()
err = client.Publish("mychannel", "hello").Err()
if err != nil {
panic(err)
}
for i := 0; i < 2; i++ {
msgi, err := pubsub.ReceiveTimeout(100 * time.Millisecond) msgi, err := pubsub.ReceiveTimeout(100 * time.Millisecond)
if err != nil { if err != nil {
err := pubsub.Ping("") panic(err)
if err != nil {
panic(err)
}
continue
} }
switch msg := msgi.(type) { switch msg := msgi.(type) {
@ -234,8 +251,6 @@ func ExamplePubSub() {
fmt.Println(msg.Kind, msg.Channel) fmt.Println(msg.Kind, msg.Channel)
case *redis.Message: case *redis.Message:
fmt.Println(msg.Channel, msg.Payload) fmt.Println(msg.Channel, msg.Payload)
case *redis.Pong:
fmt.Println(msg)
default: default:
panic(fmt.Sprintf("unknown message: %#v", msgi)) panic(fmt.Sprintf("unknown message: %#v", msgi))
} }
@ -243,7 +258,6 @@ func ExamplePubSub() {
// Output: subscribe mychannel // Output: subscribe mychannel
// mychannel hello // mychannel hello
// Pong
} }
func ExampleScript() { func ExampleScript() {

View File

@ -8,6 +8,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"sync/atomic" "sync/atomic"
"syscall"
"testing" "testing"
"time" "time"
@ -231,20 +232,33 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type badNetConn struct { var errTimeout = syscall.ETIMEDOUT
type badConn struct {
net.TCPConn net.TCPConn
readDelay, writeDelay time.Duration
readErr, writeErr error
} }
var _ net.Conn = &badNetConn{} var _ net.Conn = &badConn{}
func newBadNetConn() net.Conn { func (cn *badConn) Read([]byte) (int, error) {
return &badNetConn{} if cn.readDelay != 0 {
time.Sleep(cn.readDelay)
}
if cn.readErr != nil {
return 0, cn.readErr
}
return 0, net.UnknownNetworkError("badConn")
} }
func (badNetConn) Read([]byte) (int, error) { func (cn *badConn) Write([]byte) (int, error) {
return 0, net.UnknownNetworkError("badNetConn") if cn.writeDelay != 0 {
} time.Sleep(cn.writeDelay)
}
func (badNetConn) Write([]byte) (int, error) { if cn.writeErr != nil {
return 0, net.UnknownNetworkError("badNetConn") return 0, cn.writeErr
}
return 0, net.UnknownNetworkError("badConn")
} }

View File

@ -396,8 +396,8 @@ func (p *singleConnPool) Remove(cn *conn) error {
if p.cn == nil { if p.cn == nil {
panic("p.cn == nil") panic("p.cn == nil")
} }
if p.cn != cn { if cn != nil && cn != p.cn {
panic("p.cn != cn") panic("cn != p.cn")
} }
if p.closed { if p.closed {
return errClosed return errClosed

173
pubsub.go
View File

@ -2,6 +2,8 @@ package redis
import ( import (
"fmt" "fmt"
"log"
"net"
"time" "time"
) )
@ -16,6 +18,9 @@ func (c *Client) Publish(channel, message string) *IntCmd {
// http://redis.io/topics/pubsub. // http://redis.io/topics/pubsub.
type PubSub struct { type PubSub struct {
*baseClient *baseClient
channels []string
patterns []string
} }
// Deprecated. Use Subscribe/PSubscribe instead. // Deprecated. Use Subscribe/PSubscribe instead.
@ -40,6 +45,71 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
return pubsub, pubsub.PSubscribe(channels...) return pubsub, pubsub.PSubscribe(channels...)
} }
func (c *PubSub) subscribe(cmd string, channels ...string) error {
cn, err := c.conn()
if err != nil {
return err
}
args := make([]interface{}, 1+len(channels))
args[0] = cmd
for i, channel := range channels {
args[1+i] = channel
}
req := NewSliceCmd(args...)
return cn.writeCmds(req)
}
// Subscribes the client to the specified channels.
func (c *PubSub) Subscribe(channels ...string) error {
err := c.subscribe("SUBSCRIBE", channels...)
if err == nil {
c.channels = append(c.channels, channels...)
}
return err
}
// Subscribes the client to the given patterns.
func (c *PubSub) PSubscribe(patterns ...string) error {
err := c.subscribe("PSUBSCRIBE", patterns...)
if err == nil {
c.channels = append(c.channels, patterns...)
}
return err
}
func remove(ss []string, es ...string) []string {
for _, e := range es {
for i, s := range ss {
if s == e {
ss = append(ss[:i], ss[i+1:]...)
break
}
}
}
return ss
}
// Unsubscribes the client from the given channels, or from all of
// them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error {
err := c.subscribe("UNSUBSCRIBE", channels...)
if err == nil {
c.channels = remove(c.channels, channels...)
}
return err
}
// Unsubscribes the client from the given patterns, or from all of
// them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error {
err := c.subscribe("PUNSUBSCRIBE", patterns...)
if err == nil {
c.patterns = remove(c.patterns, patterns...)
}
return err
}
func (c *PubSub) Ping(payload string) error { func (c *PubSub) Ping(payload string) error {
cn, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
@ -71,6 +141,7 @@ func (m *Subscription) String() string {
// Message received as result of a PUBLISH command issued by another client. // Message received as result of a PUBLISH command issued by another client.
type Message struct { type Message struct {
Channel string Channel string
Pattern string
Payload string Payload string
} }
@ -78,6 +149,8 @@ func (m *Message) String() string {
return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload) return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
} }
// TODO: remove PMessage if favor of Message
// Message matching a pattern-matching subscription received as result // Message matching a pattern-matching subscription received as result
// of a PUBLISH command issued by another client. // of a PUBLISH command issued by another client.
type PMessage struct { type PMessage struct {
@ -102,12 +175,6 @@ func (p *Pong) String() string {
return "Pong" return "Pong"
} }
// Returns a message as a Subscription, Message, PMessage, Pong or
// error. See PubSub example for details.
func (c *PubSub) Receive() (interface{}, error) {
return c.ReceiveTimeout(0)
}
func newMessage(reply []interface{}) (interface{}, error) { func newMessage(reply []interface{}) (interface{}, error) {
switch kind := reply[0].(string); kind { switch kind := reply[0].(string); kind {
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
@ -137,7 +204,8 @@ func newMessage(reply []interface{}) (interface{}, error) {
} }
// ReceiveTimeout acts like Receive but returns an error if message // ReceiveTimeout acts like Receive but returns an error if message
// is not received in time. // is not received in time. This is low-level API and most clients
// should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
cn, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
@ -152,39 +220,74 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
return newMessage(cmd.Val()) return newMessage(cmd.Val())
} }
func (c *PubSub) subscribe(cmd string, channels ...string) error { // Receive returns a message as a Subscription, Message, PMessage,
cn, err := c.conn() // Pong or error. See PubSub example for details. This is low-level
if err != nil { // API and most clients should use ReceiveMessage.
return err func (c *PubSub) Receive() (interface{}, error) {
return c.ReceiveTimeout(0)
}
func (c *PubSub) reconnect() {
c.connPool.Remove(nil) // close current connection
if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil {
log.Printf("redis: Subscribe failed: %s", err)
}
} }
if len(c.patterns) > 0 {
args := make([]interface{}, 1+len(channels)) if err := c.PSubscribe(c.patterns...); err != nil {
args[0] = cmd log.Printf("redis: Subscribe failed: %s", err)
for i, channel := range channels { }
args[1+i] = channel
} }
req := NewSliceCmd(args...)
return cn.writeCmds(req)
} }
// Subscribes the client to the specified channels. // ReceiveMessage returns a message or error. It automatically
func (c *PubSub) Subscribe(channels ...string) error { // reconnects to Redis in case of network errors.
return c.subscribe("SUBSCRIBE", channels...) func (c *PubSub) ReceiveMessage() (*Message, error) {
} var badConn bool
for {
msgi, err := c.ReceiveTimeout(5 * time.Second)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if badConn {
c.reconnect()
badConn = false
continue
}
// Subscribes the client to the given patterns. err := c.Ping("")
func (c *PubSub) PSubscribe(patterns ...string) error { if err != nil {
return c.subscribe("PSUBSCRIBE", patterns...) c.reconnect()
} } else {
badConn = true
}
continue
}
// Unsubscribes the client from the given channels, or from all of if isNetworkError(err) {
// them if none is given. c.reconnect()
func (c *PubSub) Unsubscribe(channels ...string) error { continue
return c.subscribe("UNSUBSCRIBE", channels...) }
}
// Unsubscribes the client from the given patterns, or from all of return nil, err
// them if none is given. }
func (c *PubSub) PUnsubscribe(patterns ...string) error {
return c.subscribe("PUNSUBSCRIBE", patterns...) switch msg := msgi.(type) {
case *Subscription:
// Ignore.
case *Pong:
badConn = false
// Ignore.
case *Message:
return msg, nil
case *PMessage:
return &Message{
Channel: msg.Channel,
Pattern: msg.Pattern,
Payload: msg.Payload,
}, nil
default:
return nil, fmt.Errorf("redis: unknown message: %T", msgi)
}
}
} }

View File

@ -12,10 +12,12 @@ import (
var _ = Describe("PubSub", func() { var _ = Describe("PubSub", func() {
var client *redis.Client var client *redis.Client
readTimeout := 3 * time.Second
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(&redis.Options{
Addr: redisAddr, Addr: redisAddr,
ReadTimeout: readTimeout,
}) })
Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
}) })
@ -227,4 +229,51 @@ var _ = Describe("PubSub", func() {
Expect(pong.Payload).To(Equal("hello")) Expect(pong.Payload).To(Equal("hello"))
}) })
It("should ReceiveMessage", func() {
pubsub, err := client.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred())
defer pubsub.Close()
go func() {
defer GinkgoRecover()
time.Sleep(readTimeout + 100*time.Millisecond)
n, err := client.Publish("mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
}()
msg, err := pubsub.ReceiveMessage()
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello"))
})
It("should reconnect on ReceiveMessage error", func() {
pubsub, err := client.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred())
defer pubsub.Close()
cn, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{
readErr: errTimeout,
writeErr: errTimeout,
})
go func() {
defer GinkgoRecover()
time.Sleep(100 * time.Millisecond)
n, err := client.Publish("mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(2)))
}()
msg, err := pubsub.ReceiveMessage()
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello"))
})
}) })

View File

@ -159,7 +159,8 @@ var _ = Describe("Client", func() {
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(newBadNetConn())
cn.SetNetConn(&badConn{})
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
err = client.Ping().Err() err = client.Ping().Err()