From 336824d98193f2c5fcbe12fb4e1780e8484f1a7f Mon Sep 17 00:00:00 2001 From: monkey92t Date: Tue, 25 May 2021 17:35:32 +0800 Subject: [PATCH] Improve pubsub Signed-off-by: monkey92t --- pubsub.go | 120 ++++++++++++++++++++++++++++++++++++------------- pubsub_test.go | 20 +++++++++ 2 files changed, 109 insertions(+), 31 deletions(-) diff --git a/pubsub.go b/pubsub.go index c56270b..a02b2a3 100644 --- a/pubsub.go +++ b/pubsub.go @@ -13,10 +13,7 @@ import ( "github.com/go-redis/redis/v8/internal/proto" ) -const ( - pingTimeout = time.Second - chanSendTimeout = time.Minute -) +const chanSendTimeout = time.Minute var errPingTimeout = errors.New("redis: ping timeout") @@ -42,12 +39,43 @@ type PubSub struct { cmd *Cmd + size int + chOnce sync.Once msgCh chan *Message allCh chan interface{} ping chan struct{} + + pingTimeout time.Duration + healthTimeout time.Duration } +type PubSubOption func(c *PubSub) + +// WithChannelSize go-chan size(default 100). +func WithChannelSize(size int) PubSubOption { + return func(c *PubSub) { + c.size = size + } +} + +// WithPingTimeout health(ping) check interval(default: 1s). +func WithPingTimeout(d time.Duration) PubSubOption { + return func(c *PubSub) { + c.pingTimeout = d + } +} + +// WithHealthTimeout health check timeout, +// the maximum time to wait for a response after the ping command(default: 5s). +func WithHealthTimeout(d time.Duration) PubSubOption { + return func(c *PubSub) { + c.healthTimeout = d + } +} + +// --------------------------------- + func (c *PubSub) String() string { channels := mapKeys(c.channels) channels = append(channels, mapKeys(c.patterns)...) @@ -56,6 +84,10 @@ func (c *PubSub) String() string { func (c *PubSub) init() { c.exit = make(chan struct{}) + + c.size = 100 + c.pingTimeout = time.Second + c.healthTimeout = 5 * time.Second } func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { @@ -425,12 +457,14 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { // // go-redis periodically sends ping messages to test connection health // and re-subscribes if ping can not not received for 30 seconds. +// Deprecated: use ChannelMessage(), remove in v9. func (c *PubSub) Channel() <-chan *Message { return c.ChannelSize(100) } // ChannelSize is like Channel, but creates a Go channel // with specified buffer size. +// Deprecated: use ChannelMessage(), remove in v9. func (c *PubSub) ChannelSize(size int) <-chan *Message { c.chOnce.Do(func() { c.initPing() @@ -452,7 +486,8 @@ func (c *PubSub) ChannelSize(size int) <-chan *Message { // reconnections. // // ChannelWithSubscriptions can not be used together with Channel or ChannelSize. -func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} { +// Deprecated: use ChannelSubscriptionMessage(), remove in v9. +func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} { c.chOnce.Do(func() { c.initPing() c.initAllChan(size) @@ -468,6 +503,24 @@ func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan return c.allCh } +// ChannelMessage replacement function of Channel(ChannelSize) in the future, +// allowing custom options. +func (c *PubSub) ChannelMessage(opts ...PubSubOption) <-chan *Message { + for _, opt := range opts { + opt(c) + } + return c.ChannelSize(c.size) +} + +// ChannelSubscriptionMessage replacement function of ChannelWithSubscriptions in the future, +// allowing custom options. +func (c *PubSub) ChannelSubscriptionMessage(opts ...PubSubOption) <-chan interface{} { + for _, opt := range opts { + opt(c) + } + return c.ChannelWithSubscriptions(context.Background(), c.size) +} + func (c *PubSub) getContext() context.Context { if c.cmd != nil { return c.cmd.ctx @@ -483,26 +536,35 @@ func (c *PubSub) initPing() { timer.Stop() healthy := true + timeout := c.pingTimeout for { - timer.Reset(pingTimeout) + timer.Reset(timeout) select { case <-c.ping: healthy = true + timeout = c.pingTimeout if !timer.Stop() { <-timer.C } case <-timer.C: - pingErr := c.Ping(ctx) + var healthyErr error + if healthy { + healthyErr = c.Ping(ctx) healthy = false } else { - if pingErr == nil { - pingErr = errPingTimeout - } + healthyErr = errPingTimeout + } + + if healthyErr != nil { c.mu.Lock() - c.reconnect(ctx, pingErr) - healthy = true + c.reconnect(ctx, healthyErr) c.mu.Unlock() + + healthy = true + timeout = c.pingTimeout + } else { + timeout = c.healthTimeout } case <-c.exit: return @@ -574,7 +636,7 @@ func (c *PubSub) initAllChan(size int) { ctx := context.TODO() c.allCh = make(chan interface{}, size) go func() { - timer := time.NewTimer(pingTimeout) + timer := time.NewTimer(time.Minute) timer.Stop() var errCount int @@ -601,29 +663,25 @@ func (c *PubSub) initAllChan(size int) { } switch msg := msg.(type) { - case *Subscription: - c.sendMessage(msg, timer) case *Pong: // Ignore. - case *Message: - c.sendMessage(msg, timer) + case *Subscription, *Message: + timer.Reset(chanSendTimeout) + select { + case c.allCh <- msg: + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + internal.Logger.Printf( + c.getContext(), + "redis: %s channel is full for %s (message is dropped)", + c, + chanSendTimeout) + } default: internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) } } }() } - -func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) { - timer.Reset(pingTimeout) - select { - case c.allCh <- msg: - if !timer.Stop() { - <-timer.C - } - case <-timer.C: - internal.Logger.Printf( - c.getContext(), - "redis: %s channel is full for %s (message is dropped)", c, pingTimeout) - } -} diff --git a/pubsub_test.go b/pubsub_test.go index d32d5e0..02ad2cc 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -473,4 +473,24 @@ var _ = Describe("PubSub", func() { Fail("timeout") } }) + + It("should ChannelMessage", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + ch := pubsub.ChannelMessage( + redis.WithChannelSize(10), + redis.WithPingTimeout(time.Second), + redis.WithHealthTimeout(time.Minute), + ) + + text := "test channel message" + err := client.Publish(ctx, "mychannel", text).Err() + Expect(err).NotTo(HaveOccurred()) + + var msg *redis.Message + Eventually(ch).Should(Receive(&msg)) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal(text)) + }) })