diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f6f3f18..3728c9d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - New methods ProcessContext, DoContext, and ExecContext. - Client respects Context.Deadline when setting net.Conn deadline. - Client listens on Context.Done while waiting for a connection from the pool. +- Add PubSub.ChannelWithSubscriptions that sends `*Subscription` in addition to `*Message` to allow detecting reconnections. ## v6.15 diff --git a/pubsub.go b/pubsub.go index 7bb3872a..600a6949 100644 --- a/pubsub.go +++ b/pubsub.go @@ -13,6 +13,8 @@ import ( "github.com/go-redis/redis/internal/proto" ) +const pingTimeout = 30 * time.Second + var errPingTimeout = errors.New("redis: ping timeout") // PubSub implements Pub/Sub commands as described in @@ -38,7 +40,8 @@ type PubSub struct { cmd *Cmd chOnce sync.Once - ch chan *Message + msgCh chan *Message + allCh chan interface{} ping chan struct{} } @@ -394,95 +397,64 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { } // Channel returns a Go channel for concurrently receiving messages. -// It periodically sends Ping messages to test connection health. -// The channel is closed with PubSub. Receive* APIs can not be used -// after channel is created. +// The channel is closed together with the PubSub. If the Go channel +// is blocked full for 30 seconds the message is dropped. +// Receive* APIs can not be used after channel is created. // -// If the Go channel is full for 30 seconds the message is dropped. +// go-redis periodically sends ping messages to test connection health +// and re-subscribes if ping can not not received for 30 seconds. func (c *PubSub) Channel() <-chan *Message { - return c.channel(100) + return c.ChannelSize(100) } // ChannelSize is like Channel, but creates a Go channel // with specified buffer size. func (c *PubSub) ChannelSize(size int) <-chan *Message { - return c.channel(size) -} - -func (c *PubSub) channel(size int) <-chan *Message { c.chOnce.Do(func() { - c.initChannel(size) + c.initPing() + c.initMsgChan(size) }) - if cap(c.ch) != size { - err := fmt.Errorf("redis: PubSub.Channel is called with different buffer size") + if c.msgCh == nil { + err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions") panic(err) } - return c.ch + if cap(c.msgCh) != size { + err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") + panic(err) + } + return c.msgCh } -func (c *PubSub) initChannel(size int) { - const timeout = 30 * time.Second +// ChannelWithSubscriptions is like Channel, but message type can be either +// *Subscription or *Message. Subscription messages can be used to detect +// reconnections. +// +// ChannelWithSubscriptions can not be used together with Channel or ChannelSize. +func (c *PubSub) ChannelWithSubscriptions(size int) <-chan interface{} { + c.chOnce.Do(func() { + c.initPing() + c.initAllChan(size) + }) + if c.allCh == nil { + err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel") + panic(err) + } + if cap(c.allCh) != size { + err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") + panic(err) + } + return c.allCh +} - c.ch = make(chan *Message, size) +func (c *PubSub) initPing() { c.ping = make(chan struct{}, 1) - go func() { - timer := time.NewTimer(timeout) - timer.Stop() - - var errCount int - for { - msg, err := c.Receive() - if err != nil { - if err == pool.ErrClosed { - close(c.ch) - return - } - if errCount > 0 { - time.Sleep(c.retryBackoff(errCount)) - } - errCount++ - continue - } - - errCount = 0 - - // Any message is as good as a ping. - select { - case c.ping <- struct{}{}: - default: - } - - switch msg := msg.(type) { - case *Subscription: - // Ignore. - case *Pong: - // Ignore. - case *Message: - timer.Reset(timeout) - select { - case c.ch <- msg: - if !timer.Stop() { - <-timer.C - } - case <-timer.C: - internal.Logger.Printf( - "redis: %s channel is full for %s (message is dropped)", - c, timeout) - } - default: - internal.Logger.Printf("redis: unknown message type: %T", msg) - } - } - }() - - go func() { - timer := time.NewTimer(timeout) + timer := time.NewTimer(pingTimeout) timer.Stop() healthy := true for { - timer.Reset(timeout) + timer.Reset(pingTimeout) select { case <-c.ping: healthy = true @@ -508,6 +480,116 @@ func (c *PubSub) initChannel(size int) { }() } +// initMsgChan must be in sync with initAllChan. +func (c *PubSub) initMsgChan(size int) { + c.msgCh = make(chan *Message, size) + go func() { + timer := time.NewTimer(pingTimeout) + timer.Stop() + + var errCount int + for { + msg, err := c.Receive() + if err != nil { + if err == pool.ErrClosed { + close(c.msgCh) + return + } + if errCount > 0 { + time.Sleep(c.retryBackoff(errCount)) + } + errCount++ + continue + } + + errCount = 0 + + // Any message is as good as a ping. + select { + case c.ping <- struct{}{}: + default: + } + + switch msg := msg.(type) { + case *Subscription: + // Ignore. + case *Pong: + // Ignore. + case *Message: + timer.Reset(pingTimeout) + select { + case c.msgCh <- msg: + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + internal.Logger.Printf( + "redis: %s channel is full for %s (message is dropped)", c, pingTimeout) + } + default: + internal.Logger.Printf("redis: unknown message type: %T", msg) + } + } + }() +} + +// initAllChan must be in sync with initMsgChan. +func (c *PubSub) initAllChan(size int) { + c.allCh = make(chan interface{}, size) + go func() { + timer := time.NewTimer(pingTimeout) + timer.Stop() + + var errCount int + for { + msg, err := c.Receive() + if err != nil { + if err == pool.ErrClosed { + close(c.allCh) + return + } + if errCount > 0 { + time.Sleep(c.retryBackoff(errCount)) + } + errCount++ + continue + } + + errCount = 0 + + // Any message is as good as a ping. + select { + case c.ping <- struct{}{}: + default: + } + + switch msg := msg.(type) { + case *Subscription: + c.sendMessage(msg, timer) + case *Pong: + // Ignore. + case *Message: + c.sendMessage(msg, timer) + default: + internal.Logger.Printf("redis: unknown message type: %T", msg) + } + } + }() +} + +func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) { + timer.Reset(pingTimeout) + select { + case c.allCh <- msg: + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + internal.Logger.Printf( + "redis: %s channel is full for %s (message is dropped)", c, pingTimeout) + } +} + func (c *PubSub) retryBackoff(attempt int) time.Duration { return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) }