Improve pubsub (#1764)

* Improve pubsub

Signed-off-by: monkey92t <golang@88.com>

* Extract code to channel struct and tweak API

* Move chanSendTimeout to channel

* Cleanup health check

* Add WithChannelSendTimeout and tweak comments

* clear notes

Signed-off-by: monkey92t <golang@88.com>

Co-authored-by: Vladimir Mihailenco <vladimir.webdev@gmail.com>
This commit is contained in:
monkey92t 2021-05-26 11:25:18 +08:00 committed by GitHub
parent f83600d1a5
commit 8e8510431d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 166 additions and 114 deletions

253
pubsub.go
View File

@ -2,7 +2,6 @@ package redis
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@ -13,13 +12,6 @@ import (
"github.com/go-redis/redis/v8/internal/proto" "github.com/go-redis/redis/v8/internal/proto"
) )
const (
pingTimeout = time.Second
chanSendTimeout = time.Minute
)
var errPingTimeout = errors.New("redis: ping timeout")
// PubSub implements Pub/Sub commands as described in // PubSub implements Pub/Sub commands as described in
// http://redis.io/topics/pubsub. Message receiving is NOT safe // http://redis.io/topics/pubsub. Message receiving is NOT safe
// for concurrent use by multiple goroutines. // for concurrent use by multiple goroutines.
@ -43,9 +35,12 @@ type PubSub struct {
cmd *Cmd cmd *Cmd
chOnce sync.Once chOnce sync.Once
msgCh chan *Message msgCh *channel
allCh chan interface{} allCh *channel
ping chan struct{} }
func (c *PubSub) init() {
c.exit = make(chan struct{})
} }
func (c *PubSub) String() string { func (c *PubSub) String() string {
@ -54,10 +49,6 @@ func (c *PubSub) String() string {
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
} }
func (c *PubSub) init() {
c.exit = make(chan struct{})
}
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
c.mu.Lock() c.mu.Lock()
cn, err := c.conn(ctx, nil) cn, err := c.conn(ctx, nil)
@ -418,56 +409,6 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
} }
} }
// Channel returns a Go channel for concurrently receiving messages.
// The channel is closed together with the PubSub. If the Go channel
// is blocked full for 30 seconds the message is dropped.
// Receive* APIs can not be used after channel is created.
//
// go-redis periodically sends ping messages to test connection health
// and re-subscribes if ping can not not received for 30 seconds.
func (c *PubSub) Channel() <-chan *Message {
return c.ChannelSize(100)
}
// ChannelSize is like Channel, but creates a Go channel
// with specified buffer size.
func (c *PubSub) ChannelSize(size int) <-chan *Message {
c.chOnce.Do(func() {
c.initPing()
c.initMsgChan(size)
})
if c.msgCh == nil {
err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
panic(err)
}
if cap(c.msgCh) != size {
err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
panic(err)
}
return c.msgCh
}
// ChannelWithSubscriptions is like Channel, but message type can be either
// *Subscription or *Message. Subscription messages can be used to detect
// reconnections.
//
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} {
c.chOnce.Do(func() {
c.initPing()
c.initAllChan(size)
})
if c.allCh == nil {
err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
panic(err)
}
if cap(c.allCh) != size {
err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
panic(err)
}
return c.allCh
}
func (c *PubSub) getContext() context.Context { func (c *PubSub) getContext() context.Context {
if c.cmd != nil { if c.cmd != nil {
return c.cmd.ctx return c.cmd.ctx
@ -475,36 +416,135 @@ func (c *PubSub) getContext() context.Context {
return context.Background() return context.Background()
} }
func (c *PubSub) initPing() { //------------------------------------------------------------------------------
// Channel returns a Go channel for concurrently receiving messages.
// The channel is closed together with the PubSub. If the Go channel
// is blocked full for 30 seconds the message is dropped.
// Receive* APIs can not be used after channel is created.
//
// go-redis periodically sends ping messages to test connection health
// and re-subscribes if ping can not not received for 30 seconds.
func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
c.chOnce.Do(func() {
c.msgCh = newChannel(c, opts...)
c.msgCh.initMsgChan()
})
if c.msgCh == nil {
err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
panic(err)
}
return c.msgCh.msgCh
}
// ChannelSize is like Channel, but creates a Go channel
// with specified buffer size.
//
// Deprecated: use Channel(WithChannelSize(size)), remove in v9.
func (c *PubSub) ChannelSize(size int) <-chan *Message {
return c.Channel(WithChannelSize(size))
}
// ChannelWithSubscriptions is like Channel, but message type can be either
// *Subscription or *Message. Subscription messages can be used to detect
// reconnections.
//
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} {
c.chOnce.Do(func() {
c.allCh = newChannel(c, WithChannelSize(size))
c.allCh.initAllChan()
})
if c.allCh == nil {
err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
panic(err)
}
return c.allCh.allCh
}
type ChannelOption func(c *channel)
// WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
//
// The default is 100 messages.
func WithChannelSize(size int) ChannelOption {
return func(c *channel) {
c.chanSize = size
}
}
// WithChannelHealthCheckInterval specifies the health check interval.
// PubSub will ping Redis Server if it does not receive any messages within the interval.
// To disable health check, use zero interval.
//
// The default is 3 seconds.
func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
return func(c *channel) {
c.checkInterval = d
}
}
// WithChannelSendTimeout specifies that channel send timeout after which
// the message is dropped.
//
// The default is 60 seconds.
func WithChannelSendTimeout(d time.Duration) ChannelOption {
return func(c *channel) {
c.chanSendTimeout = d
}
}
type channel struct {
pubSub *PubSub
msgCh chan *Message
allCh chan interface{}
ping chan struct{}
chanSize int
chanSendTimeout time.Duration
checkInterval time.Duration
}
func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
c := &channel{
pubSub: pubSub,
chanSize: 100,
chanSendTimeout: time.Minute,
checkInterval: 3 * time.Second,
}
for _, opt := range opts {
opt(c)
}
if c.checkInterval > 0 {
c.initHealthCheck()
}
return c
}
func (c *channel) initHealthCheck() {
ctx := context.TODO() ctx := context.TODO()
c.ping = make(chan struct{}, 1) c.ping = make(chan struct{}, 1)
go func() { go func() {
timer := time.NewTimer(time.Minute) timer := time.NewTimer(time.Minute)
timer.Stop() timer.Stop()
healthy := true
for { for {
timer.Reset(pingTimeout) timer.Reset(c.checkInterval)
select { select {
case <-c.ping: case <-c.ping:
healthy = true
if !timer.Stop() { if !timer.Stop() {
<-timer.C <-timer.C
} }
case <-timer.C: case <-timer.C:
pingErr := c.Ping(ctx) if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
if healthy { c.pubSub.mu.Lock()
healthy = false c.pubSub.reconnect(ctx, pingErr)
} else { c.pubSub.mu.Unlock()
if pingErr == nil {
pingErr = errPingTimeout
} }
c.mu.Lock() case <-c.pubSub.exit:
c.reconnect(ctx, pingErr)
healthy = true
c.mu.Unlock()
}
case <-c.exit:
return return
} }
} }
@ -512,16 +552,17 @@ func (c *PubSub) initPing() {
} }
// initMsgChan must be in sync with initAllChan. // initMsgChan must be in sync with initAllChan.
func (c *PubSub) initMsgChan(size int) { func (c *channel) initMsgChan() {
ctx := context.TODO() ctx := context.TODO()
c.msgCh = make(chan *Message, size) c.msgCh = make(chan *Message, c.chanSize)
go func() { go func() {
timer := time.NewTimer(time.Minute) timer := time.NewTimer(time.Minute)
timer.Stop() timer.Stop()
var errCount int var errCount int
for { for {
msg, err := c.Receive(ctx) msg, err := c.pubSub.Receive(ctx)
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
close(c.msgCh) close(c.msgCh)
@ -548,7 +589,7 @@ func (c *PubSub) initMsgChan(size int) {
case *Pong: case *Pong:
// Ignore. // Ignore.
case *Message: case *Message:
timer.Reset(chanSendTimeout) timer.Reset(c.chanSendTimeout)
select { select {
case c.msgCh <- msg: case c.msgCh <- msg:
if !timer.Stop() { if !timer.Stop() {
@ -556,30 +597,28 @@ func (c *PubSub) initMsgChan(size int) {
} }
case <-timer.C: case <-timer.C:
internal.Logger.Printf( internal.Logger.Printf(
c.getContext(), ctx, "redis: %s channel is full for %s (message is dropped)",
"redis: %s channel is full for %s (message is dropped)", c, c.chanSendTimeout)
c,
chanSendTimeout,
)
} }
default: default:
internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
} }
} }
}() }()
} }
// initAllChan must be in sync with initMsgChan. // initAllChan must be in sync with initMsgChan.
func (c *PubSub) initAllChan(size int) { func (c *channel) initAllChan() {
ctx := context.TODO() ctx := context.TODO()
c.allCh = make(chan interface{}, size) c.allCh = make(chan interface{}, c.chanSize)
go func() { go func() {
timer := time.NewTimer(pingTimeout) timer := time.NewTimer(time.Minute)
timer.Stop() timer.Stop()
var errCount int var errCount int
for { for {
msg, err := c.Receive(ctx) msg, err := c.pubSub.Receive(ctx)
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
close(c.allCh) close(c.allCh)
@ -601,21 +640,10 @@ func (c *PubSub) initAllChan(size int) {
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *Subscription:
c.sendMessage(msg, timer)
case *Pong: case *Pong:
// Ignore. // Ignore.
case *Message: case *Subscription, *Message:
c.sendMessage(msg, timer) timer.Reset(c.chanSendTimeout)
default:
internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg)
}
}
}()
}
func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) {
timer.Reset(pingTimeout)
select { select {
case c.allCh <- msg: case c.allCh <- msg:
if !timer.Stop() { if !timer.Stop() {
@ -623,7 +651,12 @@ func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) {
} }
case <-timer.C: case <-timer.C:
internal.Logger.Printf( internal.Logger.Printf(
c.getContext(), ctx, "redis: %s channel is full for %s (message is dropped)",
"redis: %s channel is full for %s (message is dropped)", c, pingTimeout) c, c.chanSendTimeout)
}
default:
internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
} }
} }
}()
}

View File

@ -473,4 +473,23 @@ var _ = Describe("PubSub", func() {
Fail("timeout") Fail("timeout")
} }
}) })
It("should ChannelMessage", func() {
pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close()
ch := pubsub.Channel(
redis.WithChannelSize(10),
redis.WithChannelHealthCheckInterval(time.Second),
)
text := "test channel message"
err := client.Publish(ctx, "mychannel", text).Err()
Expect(err).NotTo(HaveOccurred())
var msg *redis.Message
Eventually(ch).Should(Receive(&msg))
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal(text))
})
}) })