forked from mirror/redis
Merge pull request #164 from go-redis/fix/pubsub-receive-message
Add PubSub.ReceiveMessage.
This commit is contained in:
commit
02154c3b3a
|
@ -219,14 +219,31 @@ func ExamplePubSub() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 4; i++ {
|
msg, err := pubsub.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(msg.Channel, msg.Payload)
|
||||||
|
// Output: mychannel hello
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExamplePubSub_Receive() {
|
||||||
|
pubsub, err := client.Subscribe("mychannel")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer pubsub.Close()
|
||||||
|
|
||||||
|
err = client.Publish("mychannel", "hello").Err()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
msgi, err := pubsub.ReceiveTimeout(100 * time.Millisecond)
|
msgi, err := pubsub.ReceiveTimeout(100 * time.Millisecond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := pubsub.Ping("")
|
panic(err)
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch msg := msgi.(type) {
|
switch msg := msgi.(type) {
|
||||||
|
@ -234,8 +251,6 @@ func ExamplePubSub() {
|
||||||
fmt.Println(msg.Kind, msg.Channel)
|
fmt.Println(msg.Kind, msg.Channel)
|
||||||
case *redis.Message:
|
case *redis.Message:
|
||||||
fmt.Println(msg.Channel, msg.Payload)
|
fmt.Println(msg.Channel, msg.Payload)
|
||||||
case *redis.Pong:
|
|
||||||
fmt.Println(msg)
|
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unknown message: %#v", msgi))
|
panic(fmt.Sprintf("unknown message: %#v", msgi))
|
||||||
}
|
}
|
||||||
|
@ -243,7 +258,6 @@ func ExamplePubSub() {
|
||||||
|
|
||||||
// Output: subscribe mychannel
|
// Output: subscribe mychannel
|
||||||
// mychannel hello
|
// mychannel hello
|
||||||
// Pong
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleScript() {
|
func ExampleScript() {
|
||||||
|
|
34
main_test.go
34
main_test.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -231,20 +232,33 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) {
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
type badNetConn struct {
|
var errTimeout = syscall.ETIMEDOUT
|
||||||
|
|
||||||
|
type badConn struct {
|
||||||
net.TCPConn
|
net.TCPConn
|
||||||
|
|
||||||
|
readDelay, writeDelay time.Duration
|
||||||
|
readErr, writeErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ net.Conn = &badNetConn{}
|
var _ net.Conn = &badConn{}
|
||||||
|
|
||||||
func newBadNetConn() net.Conn {
|
func (cn *badConn) Read([]byte) (int, error) {
|
||||||
return &badNetConn{}
|
if cn.readDelay != 0 {
|
||||||
|
time.Sleep(cn.readDelay)
|
||||||
|
}
|
||||||
|
if cn.readErr != nil {
|
||||||
|
return 0, cn.readErr
|
||||||
|
}
|
||||||
|
return 0, net.UnknownNetworkError("badConn")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (badNetConn) Read([]byte) (int, error) {
|
func (cn *badConn) Write([]byte) (int, error) {
|
||||||
return 0, net.UnknownNetworkError("badNetConn")
|
if cn.writeDelay != 0 {
|
||||||
}
|
time.Sleep(cn.writeDelay)
|
||||||
|
}
|
||||||
func (badNetConn) Write([]byte) (int, error) {
|
if cn.writeErr != nil {
|
||||||
return 0, net.UnknownNetworkError("badNetConn")
|
return 0, cn.writeErr
|
||||||
|
}
|
||||||
|
return 0, net.UnknownNetworkError("badConn")
|
||||||
}
|
}
|
||||||
|
|
4
pool.go
4
pool.go
|
@ -396,8 +396,8 @@ func (p *singleConnPool) Remove(cn *conn) error {
|
||||||
if p.cn == nil {
|
if p.cn == nil {
|
||||||
panic("p.cn == nil")
|
panic("p.cn == nil")
|
||||||
}
|
}
|
||||||
if p.cn != cn {
|
if cn != nil && cn != p.cn {
|
||||||
panic("p.cn != cn")
|
panic("cn != p.cn")
|
||||||
}
|
}
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return errClosed
|
return errClosed
|
||||||
|
|
173
pubsub.go
173
pubsub.go
|
@ -2,6 +2,8 @@ package redis
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,6 +18,9 @@ func (c *Client) Publish(channel, message string) *IntCmd {
|
||||||
// http://redis.io/topics/pubsub.
|
// http://redis.io/topics/pubsub.
|
||||||
type PubSub struct {
|
type PubSub struct {
|
||||||
*baseClient
|
*baseClient
|
||||||
|
|
||||||
|
channels []string
|
||||||
|
patterns []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated. Use Subscribe/PSubscribe instead.
|
// Deprecated. Use Subscribe/PSubscribe instead.
|
||||||
|
@ -40,6 +45,71 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
|
||||||
return pubsub, pubsub.PSubscribe(channels...)
|
return pubsub, pubsub.PSubscribe(channels...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *PubSub) subscribe(cmd string, channels ...string) error {
|
||||||
|
cn, err := c.conn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
args := make([]interface{}, 1+len(channels))
|
||||||
|
args[0] = cmd
|
||||||
|
for i, channel := range channels {
|
||||||
|
args[1+i] = channel
|
||||||
|
}
|
||||||
|
req := NewSliceCmd(args...)
|
||||||
|
return cn.writeCmds(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribes the client to the specified channels.
|
||||||
|
func (c *PubSub) Subscribe(channels ...string) error {
|
||||||
|
err := c.subscribe("SUBSCRIBE", channels...)
|
||||||
|
if err == nil {
|
||||||
|
c.channels = append(c.channels, channels...)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribes the client to the given patterns.
|
||||||
|
func (c *PubSub) PSubscribe(patterns ...string) error {
|
||||||
|
err := c.subscribe("PSUBSCRIBE", patterns...)
|
||||||
|
if err == nil {
|
||||||
|
c.channels = append(c.channels, patterns...)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func remove(ss []string, es ...string) []string {
|
||||||
|
for _, e := range es {
|
||||||
|
for i, s := range ss {
|
||||||
|
if s == e {
|
||||||
|
ss = append(ss[:i], ss[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribes the client from the given channels, or from all of
|
||||||
|
// them if none is given.
|
||||||
|
func (c *PubSub) Unsubscribe(channels ...string) error {
|
||||||
|
err := c.subscribe("UNSUBSCRIBE", channels...)
|
||||||
|
if err == nil {
|
||||||
|
c.channels = remove(c.channels, channels...)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribes the client from the given patterns, or from all of
|
||||||
|
// them if none is given.
|
||||||
|
func (c *PubSub) PUnsubscribe(patterns ...string) error {
|
||||||
|
err := c.subscribe("PUNSUBSCRIBE", patterns...)
|
||||||
|
if err == nil {
|
||||||
|
c.patterns = remove(c.patterns, patterns...)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (c *PubSub) Ping(payload string) error {
|
func (c *PubSub) Ping(payload string) error {
|
||||||
cn, err := c.conn()
|
cn, err := c.conn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -71,6 +141,7 @@ func (m *Subscription) String() string {
|
||||||
// Message received as result of a PUBLISH command issued by another client.
|
// Message received as result of a PUBLISH command issued by another client.
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Channel string
|
Channel string
|
||||||
|
Pattern string
|
||||||
Payload string
|
Payload string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +149,8 @@ func (m *Message) String() string {
|
||||||
return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
|
return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: remove PMessage if favor of Message
|
||||||
|
|
||||||
// Message matching a pattern-matching subscription received as result
|
// Message matching a pattern-matching subscription received as result
|
||||||
// of a PUBLISH command issued by another client.
|
// of a PUBLISH command issued by another client.
|
||||||
type PMessage struct {
|
type PMessage struct {
|
||||||
|
@ -102,12 +175,6 @@ func (p *Pong) String() string {
|
||||||
return "Pong"
|
return "Pong"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a message as a Subscription, Message, PMessage, Pong or
|
|
||||||
// error. See PubSub example for details.
|
|
||||||
func (c *PubSub) Receive() (interface{}, error) {
|
|
||||||
return c.ReceiveTimeout(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMessage(reply []interface{}) (interface{}, error) {
|
func newMessage(reply []interface{}) (interface{}, error) {
|
||||||
switch kind := reply[0].(string); kind {
|
switch kind := reply[0].(string); kind {
|
||||||
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
|
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
|
||||||
|
@ -137,7 +204,8 @@ func newMessage(reply []interface{}) (interface{}, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReceiveTimeout acts like Receive but returns an error if message
|
// ReceiveTimeout acts like Receive but returns an error if message
|
||||||
// is not received in time.
|
// is not received in time. This is low-level API and most clients
|
||||||
|
// should use ReceiveMessage.
|
||||||
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
|
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
|
||||||
cn, err := c.conn()
|
cn, err := c.conn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -152,39 +220,74 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
|
||||||
return newMessage(cmd.Val())
|
return newMessage(cmd.Val())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *PubSub) subscribe(cmd string, channels ...string) error {
|
// Receive returns a message as a Subscription, Message, PMessage,
|
||||||
cn, err := c.conn()
|
// Pong or error. See PubSub example for details. This is low-level
|
||||||
if err != nil {
|
// API and most clients should use ReceiveMessage.
|
||||||
return err
|
func (c *PubSub) Receive() (interface{}, error) {
|
||||||
|
return c.ReceiveTimeout(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PubSub) reconnect() {
|
||||||
|
c.connPool.Remove(nil) // close current connection
|
||||||
|
if len(c.channels) > 0 {
|
||||||
|
if err := c.Subscribe(c.channels...); err != nil {
|
||||||
|
log.Printf("redis: Subscribe failed: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if len(c.patterns) > 0 {
|
||||||
args := make([]interface{}, 1+len(channels))
|
if err := c.PSubscribe(c.patterns...); err != nil {
|
||||||
args[0] = cmd
|
log.Printf("redis: Subscribe failed: %s", err)
|
||||||
for i, channel := range channels {
|
}
|
||||||
args[1+i] = channel
|
|
||||||
}
|
}
|
||||||
req := NewSliceCmd(args...)
|
|
||||||
return cn.writeCmds(req)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribes the client to the specified channels.
|
// ReceiveMessage returns a message or error. It automatically
|
||||||
func (c *PubSub) Subscribe(channels ...string) error {
|
// reconnects to Redis in case of network errors.
|
||||||
return c.subscribe("SUBSCRIBE", channels...)
|
func (c *PubSub) ReceiveMessage() (*Message, error) {
|
||||||
}
|
var badConn bool
|
||||||
|
for {
|
||||||
|
msgi, err := c.ReceiveTimeout(5 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
if badConn {
|
||||||
|
c.reconnect()
|
||||||
|
badConn = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Subscribes the client to the given patterns.
|
err := c.Ping("")
|
||||||
func (c *PubSub) PSubscribe(patterns ...string) error {
|
if err != nil {
|
||||||
return c.subscribe("PSUBSCRIBE", patterns...)
|
c.reconnect()
|
||||||
}
|
} else {
|
||||||
|
badConn = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Unsubscribes the client from the given channels, or from all of
|
if isNetworkError(err) {
|
||||||
// them if none is given.
|
c.reconnect()
|
||||||
func (c *PubSub) Unsubscribe(channels ...string) error {
|
continue
|
||||||
return c.subscribe("UNSUBSCRIBE", channels...)
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Unsubscribes the client from the given patterns, or from all of
|
return nil, err
|
||||||
// them if none is given.
|
}
|
||||||
func (c *PubSub) PUnsubscribe(patterns ...string) error {
|
|
||||||
return c.subscribe("PUNSUBSCRIBE", patterns...)
|
switch msg := msgi.(type) {
|
||||||
|
case *Subscription:
|
||||||
|
// Ignore.
|
||||||
|
case *Pong:
|
||||||
|
badConn = false
|
||||||
|
// Ignore.
|
||||||
|
case *Message:
|
||||||
|
return msg, nil
|
||||||
|
case *PMessage:
|
||||||
|
return &Message{
|
||||||
|
Channel: msg.Channel,
|
||||||
|
Pattern: msg.Pattern,
|
||||||
|
Payload: msg.Payload,
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("redis: unknown message: %T", msgi)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,10 +12,12 @@ import (
|
||||||
|
|
||||||
var _ = Describe("PubSub", func() {
|
var _ = Describe("PubSub", func() {
|
||||||
var client *redis.Client
|
var client *redis.Client
|
||||||
|
readTimeout := 3 * time.Second
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
client = redis.NewClient(&redis.Options{
|
client = redis.NewClient(&redis.Options{
|
||||||
Addr: redisAddr,
|
Addr: redisAddr,
|
||||||
|
ReadTimeout: readTimeout,
|
||||||
})
|
})
|
||||||
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
|
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
@ -227,4 +229,51 @@ var _ = Describe("PubSub", func() {
|
||||||
Expect(pong.Payload).To(Equal("hello"))
|
Expect(pong.Payload).To(Equal("hello"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("should ReceiveMessage", func() {
|
||||||
|
pubsub, err := client.Subscribe("mychannel")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
defer pubsub.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
|
||||||
|
time.Sleep(readTimeout + 100*time.Millisecond)
|
||||||
|
n, err := client.Publish("mychannel", "hello").Result()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(int64(1)))
|
||||||
|
}()
|
||||||
|
|
||||||
|
msg, err := pubsub.ReceiveMessage()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(msg.Channel).To(Equal("mychannel"))
|
||||||
|
Expect(msg.Payload).To(Equal("hello"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should reconnect on ReceiveMessage error", func() {
|
||||||
|
pubsub, err := client.Subscribe("mychannel")
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
defer pubsub.Close()
|
||||||
|
|
||||||
|
cn, err := pubsub.Pool().Get()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
cn.SetNetConn(&badConn{
|
||||||
|
readErr: errTimeout,
|
||||||
|
writeErr: errTimeout,
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
n, err := client.Publish("mychannel", "hello").Result()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(int64(2)))
|
||||||
|
}()
|
||||||
|
|
||||||
|
msg, err := pubsub.ReceiveMessage()
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(msg.Channel).To(Equal("mychannel"))
|
||||||
|
Expect(msg.Payload).To(Equal("hello"))
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
|
@ -159,7 +159,8 @@ var _ = Describe("Client", func() {
|
||||||
// Put bad connection in the pool.
|
// Put bad connection in the pool.
|
||||||
cn, err := client.Pool().Get()
|
cn, err := client.Pool().Get()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
cn.SetNetConn(newBadNetConn())
|
|
||||||
|
cn.SetNetConn(&badConn{})
|
||||||
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
|
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
|
||||||
|
|
||||||
err = client.Ping().Err()
|
err = client.Ping().Err()
|
||||||
|
|
Loading…
Reference in New Issue