redis/pubsub.go

593 lines
13 KiB
Go
Raw Normal View History

2012-07-25 17:00:50 +04:00
package redis
import (
"context"
2018-10-09 10:52:30 +03:00
"errors"
2012-07-25 17:00:50 +04:00
"fmt"
"strings"
"sync"
2014-05-11 11:42:40 +04:00
"time"
2017-02-18 17:42:34 +03:00
"github.com/go-redis/redis/internal"
"github.com/go-redis/redis/internal/pool"
2018-08-15 11:53:15 +03:00
"github.com/go-redis/redis/internal/proto"
2012-07-25 17:00:50 +04:00
)
2019-07-01 17:21:32 +03:00
// pingTimeout bounds two waits: how long the health-check goroutine waits
// for any traffic before sending a PING, and how long a full Go channel may
// block delivery before the message is dropped.
const pingTimeout = 30 * time.Second

// errPingTimeout is the reason passed to reconnect when a PING was written
// successfully but still no traffic arrived in time.
var errPingTimeout = errors.New("redis: ping timeout")
2019-04-08 15:06:31 +03:00
// PubSub implements Pub/Sub commands as described in
2018-07-23 15:55:13 +03:00
// http://redis.io/topics/pubsub. Message receiving is NOT safe
// for concurrent use by multiple goroutines.
2017-05-11 17:02:26 +03:00
//
2018-07-23 15:55:13 +03:00
// PubSub automatically reconnects to Redis Server and resubscribes
// to the channels in case of network errors.
2014-05-11 11:42:40 +04:00
type PubSub struct {
2017-07-09 10:07:20 +03:00
opt *Options
newConn func([]string) (*pool.Conn, error)
closeConn func(*pool.Conn) error
mu sync.Mutex
cn *pool.Conn
channels map[string]struct{}
patterns map[string]struct{}
closed bool
exit chan struct{}
2017-04-24 12:43:15 +03:00
cmd *Cmd
2017-10-30 13:09:57 +03:00
chOnce sync.Once
2019-07-01 17:21:32 +03:00
msgCh chan *Message
allCh chan interface{}
2018-07-24 09:41:14 +03:00
ping chan struct{}
2016-09-29 15:07:04 +03:00
}
// String returns a debug representation that lists the subscribed channels
// followed by the subscribed patterns, e.g. "PubSub(news, sport*)".
//
// It takes c.mu because the background receive goroutines call it when
// logging dropped messages. Subscribe/Unsubscribe may be mutating the
// channel and pattern maps at the same moment, and an unguarded map
// iteration there is a data race.
func (c *PubSub) String() string {
	c.mu.Lock()
	defer c.mu.Unlock()

	channels := mapKeys(c.channels)
	channels = append(channels, mapKeys(c.patterns)...)
	return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
}
2018-07-23 15:55:13 +03:00
// init allocates the exit channel. Close closes it, and the goroutine
// started by initPing returns when it does.
func (c *PubSub) init() {
	exit := make(chan struct{})
	c.exit = exit
}
2019-06-17 12:32:40 +03:00
func (c *PubSub) connWithLock() (*pool.Conn, error) {
2017-04-24 12:43:15 +03:00
c.mu.Lock()
2019-06-17 12:32:40 +03:00
cn, err := c.conn(nil)
2017-06-29 17:05:08 +03:00
c.mu.Unlock()
return cn, err
}
2017-04-24 12:43:15 +03:00
2019-06-17 12:32:40 +03:00
func (c *PubSub) conn(newChannels []string) (*pool.Conn, error) {
2017-04-24 12:43:15 +03:00
if c.closed {
2017-06-29 17:05:08 +03:00
return nil, pool.ErrClosed
2017-04-24 12:43:15 +03:00
}
if c.cn != nil {
2017-06-29 17:05:08 +03:00
return c.cn, nil
2017-04-24 12:43:15 +03:00
}
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)
2017-07-09 10:07:20 +03:00
cn, err := c.newConn(channels)
2016-09-29 15:07:04 +03:00
if err != nil {
2017-06-29 17:05:08 +03:00
return nil, err
2016-09-29 15:07:04 +03:00
}
2017-04-24 12:43:15 +03:00
if err := c.resubscribe(cn); err != nil {
2017-07-09 10:07:20 +03:00
_ = c.closeConn(cn)
2017-06-29 17:05:08 +03:00
return nil, err
2017-04-24 12:43:15 +03:00
}
c.cn = cn
2017-06-29 17:05:08 +03:00
return cn, nil
}
func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
2018-08-17 13:56:37 +03:00
return writeCmd(wr, cmd)
2018-08-15 11:53:15 +03:00
})
}
2017-04-24 12:43:15 +03:00
func (c *PubSub) resubscribe(cn *pool.Conn) error {
var firstErr error
2018-07-23 15:55:13 +03:00
2017-04-24 12:43:15 +03:00
if len(c.channels) > 0 {
2019-07-25 13:53:00 +03:00
firstErr = c._subscribe(cn, "subscribe", mapKeys(c.channels))
}
2018-07-23 15:55:13 +03:00
2017-04-24 12:43:15 +03:00
if len(c.patterns) > 0 {
err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns))
2018-07-23 15:55:13 +03:00
if err != nil && firstErr == nil {
firstErr = err
}
}
2018-07-23 15:55:13 +03:00
return firstErr
}
2018-07-23 15:55:13 +03:00
// mapKeys returns the keys of m in unspecified order. The result is never
// nil, even for a nil or empty map.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
// _subscribe writes a single redisCmd command (e.g. "subscribe") to cn,
// with the given channels as its arguments. It does not read the reply.
func (c *PubSub) _subscribe(
	cn *pool.Conn, redisCmd string, channels []string,
) error {
	args := make([]interface{}, 1+len(channels))
	args[0] = redisCmd
	for i, ch := range channels {
		args[i+1] = ch
	}
	return c.writeCmd(context.TODO(), cn, NewSliceCmd(args...))
}
2019-06-17 12:32:40 +03:00
func (c *PubSub) releaseConnWithLock(cn *pool.Conn, err error, allowTimeout bool) {
2017-04-24 12:43:15 +03:00
c.mu.Lock()
2019-06-17 12:32:40 +03:00
c.releaseConn(cn, err, allowTimeout)
2017-04-24 12:43:15 +03:00
c.mu.Unlock()
}
2019-06-17 12:32:40 +03:00
func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
2017-08-31 15:22:47 +03:00
if c.cn != cn {
return
}
if isBadConn(err, allowTimeout) {
2019-06-17 12:32:40 +03:00
c.reconnect(err)
2017-08-01 14:21:26 +03:00
}
}
2019-06-17 12:32:40 +03:00
func (c *PubSub) reconnect(reason error) {
_ = c.closeTheCn(reason)
_, _ = c.conn(nil)
2017-04-24 12:43:15 +03:00
}
2019-06-17 12:32:40 +03:00
func (c *PubSub) closeTheCn(reason error) error {
2018-08-07 10:33:07 +03:00
if c.cn == nil {
return nil
}
if !c.closed {
2019-06-17 12:32:40 +03:00
internal.Logger.Printf("redis: discarding bad PubSub connection: %s", reason)
2018-08-07 10:33:07 +03:00
}
err := c.closeConn(c.cn)
c.cn = nil
return err
2018-07-23 15:55:13 +03:00
}
2017-04-24 12:43:15 +03:00
func (c *PubSub) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
2017-04-24 12:43:15 +03:00
return pool.ErrClosed
}
2017-04-24 12:43:15 +03:00
c.closed = true
2018-07-23 15:55:13 +03:00
close(c.exit)
2019-06-17 12:32:40 +03:00
return c.closeTheCn(pool.ErrClosed)
2017-04-24 12:43:15 +03:00
}
2018-01-24 21:38:47 +03:00
// Subscribe the client to the specified channels. It returns
2017-05-11 17:02:26 +03:00
// empty subscription if there are no channels.
2015-09-06 13:50:16 +03:00
func (c *PubSub) Subscribe(channels ...string) error {
c.mu.Lock()
2018-07-23 15:55:13 +03:00
defer c.mu.Unlock()
2017-06-29 17:05:08 +03:00
err := c.subscribe("subscribe", channels...)
if c.channels == nil {
c.channels = make(map[string]struct{})
}
for _, s := range channels {
c.channels[s] = struct{}{}
}
2017-06-29 17:05:08 +03:00
return err
2015-09-06 13:50:16 +03:00
}
2018-01-24 21:38:47 +03:00
// PSubscribe the client to the given patterns. It returns
2017-05-11 17:02:26 +03:00
// empty subscription if there are no patterns.
2015-09-06 13:50:16 +03:00
func (c *PubSub) PSubscribe(patterns ...string) error {
c.mu.Lock()
2018-07-23 15:55:13 +03:00
defer c.mu.Unlock()
2017-06-29 17:05:08 +03:00
err := c.subscribe("psubscribe", patterns...)
if c.patterns == nil {
c.patterns = make(map[string]struct{})
}
for _, s := range patterns {
c.patterns[s] = struct{}{}
}
2017-06-29 17:05:08 +03:00
return err
2015-09-06 13:50:16 +03:00
}
2018-01-24 21:38:47 +03:00
// Unsubscribe the client from the given channels, or from all of
2015-09-06 13:50:16 +03:00
// them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error {
c.mu.Lock()
2018-07-23 15:55:13 +03:00
defer c.mu.Unlock()
for _, channel := range channels {
delete(c.channels, channel)
}
err := c.subscribe("unsubscribe", channels...)
2017-06-29 17:05:08 +03:00
return err
2015-09-06 13:50:16 +03:00
}
2018-01-24 21:38:47 +03:00
// PUnsubscribe the client from the given patterns, or from all of
2015-09-06 13:50:16 +03:00
// them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error {
c.mu.Lock()
2018-07-23 15:55:13 +03:00
defer c.mu.Unlock()
for _, pattern := range patterns {
delete(c.patterns, pattern)
}
err := c.subscribe("punsubscribe", patterns...)
2017-06-29 17:05:08 +03:00
return err
2015-09-06 13:50:16 +03:00
}
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
2019-06-17 12:32:40 +03:00
cn, err := c.conn(channels)
if err != nil {
return err
}
err = c._subscribe(cn, redisCmd, channels)
2019-06-17 12:32:40 +03:00
c.releaseConn(cn, err, false)
return err
}
2017-02-23 16:29:38 +03:00
func (c *PubSub) Ping(payload ...string) error {
args := []interface{}{"ping"}
2017-02-23 16:29:38 +03:00
if len(payload) == 1 {
args = append(args, payload[0])
2015-07-11 13:12:47 +03:00
}
cmd := NewCmd(args...)
2019-06-17 12:32:40 +03:00
cn, err := c.connWithLock()
if err != nil {
return err
}
err = c.writeCmd(context.TODO(), cn, cmd)
2019-06-17 12:32:40 +03:00
c.releaseConnWithLock(cn, err, false)
return err
2015-07-11 13:12:47 +03:00
}
2018-01-24 21:38:47 +03:00
// Subscription is the server's confirmation of a (p)subscribe or
// (p)unsubscribe command.
type Subscription struct {
	// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel is the channel (or pattern) the confirmation refers to.
	Channel string
	// Count is the number of active subscriptions reported by the server.
	Count int
}

// String formats the subscription as "<kind>: <channel>".
func (m *Subscription) String() string {
	return fmt.Sprint(m.Kind, ": ", m.Channel)
}
2015-05-23 18:17:45 +03:00
// Message is a payload published to a channel by a PUBLISH command.
type Message struct {
	Channel string
	// Pattern is set only for messages delivered through a pattern
	// subscription ("pmessage").
	Pattern string
	Payload string
}

// String formats the message as "Message<channel: payload>".
func (m *Message) String() string {
	return fmt.Sprint("Message<", m.Channel, ": ", m.Payload, ">")
}
2015-07-11 13:12:47 +03:00
// Pong is the reply to a PING command sent on the PubSub connection.
type Pong struct {
	Payload string
}

// String returns "Pong", or "Pong<payload>" when the reply carried a payload.
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return fmt.Sprintf("Pong<%s>", p.Payload)
}
func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
switch reply := reply.(type) {
case string:
2015-07-11 13:12:47 +03:00
return &Pong{
Payload: reply,
2015-07-11 13:12:47 +03:00
}, nil
case []interface{}:
switch kind := reply[0].(string); kind {
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
return &Subscription{
Kind: kind,
Channel: reply[1].(string),
Count: int(reply[2].(int64)),
}, nil
case "message":
return &Message{
Channel: reply[1].(string),
Payload: reply[2].(string),
}, nil
case "pmessage":
return &Message{
Pattern: reply[1].(string),
Channel: reply[2].(string),
Payload: reply[3].(string),
}, nil
case "pong":
return &Pong{
Payload: reply[1].(string),
}, nil
default:
return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
}
2015-07-11 13:12:47 +03:00
default:
return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
2015-07-11 13:12:47 +03:00
}
}
2015-07-11 13:42:44 +03:00
// ReceiveTimeout acts like Receive but returns an error if message
2018-07-24 09:41:14 +03:00
// is not received in time. This is low-level API and in most cases
// Channel should be used instead.
2015-07-11 13:12:47 +03:00
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
if c.cmd == nil {
c.cmd = NewCmd()
}
2019-06-17 12:32:40 +03:00
cn, err := c.connWithLock()
2015-07-11 13:12:47 +03:00
if err != nil {
return nil, err
2014-05-11 11:42:40 +04:00
}
err = cn.WithReader(context.TODO(), timeout, func(rd *proto.Reader) error {
2018-08-15 11:53:15 +03:00
return c.cmd.readReply(rd)
})
2019-06-17 12:32:40 +03:00
c.releaseConnWithLock(cn, err, timeout > 0)
if err != nil {
2015-07-11 13:12:47 +03:00
return nil, err
}
return c.newMessage(c.cmd.Val())
2014-05-11 11:42:40 +04:00
}
2012-07-25 17:00:50 +04:00
2016-04-09 11:45:56 +03:00
// Receive returns a message as a Subscription, Message, Pong or error.
2018-07-24 09:41:14 +03:00
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
2015-09-06 13:50:16 +03:00
func (c *PubSub) Receive() (interface{}, error) {
return c.ReceiveTimeout(0)
}
2014-05-11 11:42:40 +04:00
2018-07-24 09:41:14 +03:00
// ReceiveMessage returns a Message or error ignoring Subscription and Pong
// messages. This is low-level API and in most cases Channel should be used
// instead.
2015-09-06 13:50:16 +03:00
func (c *PubSub) ReceiveMessage() (*Message, error) {
for {
2018-07-23 15:55:13 +03:00
msg, err := c.Receive()
2015-09-06 13:50:16 +03:00
if err != nil {
2018-07-23 15:55:13 +03:00
return nil, err
2015-09-06 13:50:16 +03:00
}
2018-07-23 15:55:13 +03:00
switch msg := msg.(type) {
2015-09-06 13:50:16 +03:00
case *Subscription:
// Ignore.
case *Pong:
// Ignore.
case *Message:
return msg, nil
default:
2018-07-23 15:55:13 +03:00
err := fmt.Errorf("redis: unknown message: %T", msg)
return nil, err
2015-09-06 13:50:16 +03:00
}
}
}
2017-10-30 13:09:57 +03:00
// Channel returns a Go channel for concurrently receiving messages.
2019-07-01 17:21:32 +03:00
// The channel is closed together with the PubSub. If the Go channel
// is blocked full for 30 seconds the message is dropped.
// Receive* APIs can not be used after channel is created.
//
2019-07-01 17:21:32 +03:00
// go-redis periodically sends ping messages to test connection health
// and re-subscribes if ping can not not received for 30 seconds.
2017-04-11 16:18:35 +03:00
func (c *PubSub) Channel() <-chan *Message {
2019-07-01 17:21:32 +03:00
return c.ChannelSize(100)
2019-03-12 13:48:32 +03:00
}
// ChannelSize is like Channel, but creates a Go channel
// with specified buffer size.
func (c *PubSub) ChannelSize(size int) <-chan *Message {
2019-07-01 17:21:32 +03:00
c.chOnce.Do(func() {
c.initPing()
c.initMsgChan(size)
})
if c.msgCh == nil {
err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
panic(err)
}
if cap(c.msgCh) != size {
err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
panic(err)
}
return c.msgCh
2019-03-12 13:48:32 +03:00
}
2019-07-01 17:21:32 +03:00
// ChannelWithSubscriptions is like Channel, but message type can be either
// *Subscription or *Message. Subscription messages can be used to detect
// reconnections.
//
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
func (c *PubSub) ChannelWithSubscriptions(size int) <-chan interface{} {
2019-03-12 13:48:32 +03:00
c.chOnce.Do(func() {
2019-07-01 17:21:32 +03:00
c.initPing()
c.initAllChan(size)
2019-03-12 13:48:32 +03:00
})
2019-07-01 17:21:32 +03:00
if c.allCh == nil {
err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
panic(err)
}
if cap(c.allCh) != size {
err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
2019-03-12 13:48:32 +03:00
panic(err)
}
2019-07-01 17:21:32 +03:00
return c.allCh
2018-07-23 15:55:13 +03:00
}
2019-07-01 17:21:32 +03:00
func (c *PubSub) initPing() {
c.ping = make(chan struct{}, 1)
2019-07-01 17:21:32 +03:00
go func() {
timer := time.NewTimer(pingTimeout)
timer.Stop()
2018-07-24 09:41:14 +03:00
2019-07-01 17:21:32 +03:00
healthy := true
for {
timer.Reset(pingTimeout)
select {
case <-c.ping:
healthy = true
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
pingErr := c.Ping()
if healthy {
healthy = false
} else {
if pingErr == nil {
pingErr = errPingTimeout
}
c.mu.Lock()
c.reconnect(pingErr)
c.mu.Unlock()
}
case <-c.exit:
return
}
}
}()
}
// initMsgChan must be in sync with initAllChan.
func (c *PubSub) initMsgChan(size int) {
c.msgCh = make(chan *Message, size)
2018-07-23 15:55:13 +03:00
go func() {
2019-07-01 17:21:32 +03:00
timer := time.NewTimer(pingTimeout)
timer.Stop()
2018-07-23 15:55:13 +03:00
var errCount int
for {
2018-07-24 09:41:14 +03:00
msg, err := c.Receive()
2018-07-23 15:55:13 +03:00
if err != nil {
if err == pool.ErrClosed {
2019-07-01 17:21:32 +03:00
close(c.msgCh)
2018-07-23 15:55:13 +03:00
return
}
if errCount > 0 {
time.Sleep(c.retryBackoff(errCount))
2017-04-11 16:18:35 +03:00
}
2018-07-23 15:55:13 +03:00
errCount++
continue
2017-04-11 16:18:35 +03:00
}
2018-07-23 15:55:13 +03:00
errCount = 0
2018-07-24 09:41:14 +03:00
// Any message is as good as a ping.
select {
case c.ping <- struct{}{}:
default:
}
switch msg := msg.(type) {
case *Subscription:
// Ignore.
case *Pong:
// Ignore.
case *Message:
2019-07-01 17:21:32 +03:00
timer.Reset(pingTimeout)
select {
2019-07-01 17:21:32 +03:00
case c.msgCh <- msg:
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
2019-06-17 12:32:40 +03:00
internal.Logger.Printf(
2019-07-01 17:21:32 +03:00
"redis: %s channel is full for %s (message is dropped)", c, pingTimeout)
}
2018-07-24 09:41:14 +03:00
default:
2019-06-17 12:32:40 +03:00
internal.Logger.Printf("redis: unknown message type: %T", msg)
2018-07-24 09:41:14 +03:00
}
2018-07-23 15:55:13 +03:00
}
}()
2019-07-01 17:21:32 +03:00
}
2018-07-23 15:55:13 +03:00
2019-07-01 17:21:32 +03:00
// initAllChan must be in sync with initMsgChan.
func (c *PubSub) initAllChan(size int) {
c.allCh = make(chan interface{}, size)
2018-07-23 15:55:13 +03:00
go func() {
2019-07-01 17:21:32 +03:00
timer := time.NewTimer(pingTimeout)
2018-07-23 15:55:13 +03:00
timer.Stop()
2019-07-01 17:21:32 +03:00
var errCount int
2018-07-23 15:55:13 +03:00
for {
2019-07-01 17:21:32 +03:00
msg, err := c.Receive()
if err != nil {
if err == pool.ErrClosed {
close(c.allCh)
return
2018-07-23 15:55:13 +03:00
}
2019-07-01 17:21:32 +03:00
if errCount > 0 {
time.Sleep(c.retryBackoff(errCount))
2018-07-23 15:55:13 +03:00
}
2019-07-01 17:21:32 +03:00
errCount++
continue
}
errCount = 0
// Any message is as good as a ping.
select {
case c.ping <- struct{}{}:
default:
}
switch msg := msg.(type) {
case *Subscription:
c.sendMessage(msg, timer)
case *Pong:
// Ignore.
case *Message:
c.sendMessage(msg, timer)
default:
internal.Logger.Printf("redis: unknown message type: %T", msg)
2018-07-23 15:55:13 +03:00
}
}
}()
}
2019-07-01 17:21:32 +03:00
// sendMessage delivers msg on c.allCh, waiting up to pingTimeout for buffer
// space. If the channel stays full, the message is dropped and a log line
// is written. timer is reused across calls and is expected to be stopped
// on entry.
func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) {
	timer.Reset(pingTimeout)
	select {
	case c.allCh <- msg:
		// Stop the timer, draining it if it fired concurrently, so the next
		// Reset starts from a clean state.
		if stopped := timer.Stop(); !stopped {
			<-timer.C
		}
	case <-timer.C:
		internal.Logger.Printf(
			"redis: %s channel is full for %s (message is dropped)", c, pingTimeout)
	}
}
2018-07-23 15:55:13 +03:00
func (c *PubSub) retryBackoff(attempt int) time.Duration {
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
2017-04-11 16:18:35 +03:00
}