From 1d38942c5f07fa55f4e4025bf9026213a711249d Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Tue, 25 May 2021 14:38:40 +0300 Subject: [PATCH] Extract code to channel struct and tweak API --- pubsub.go | 192 ++++++++++++++++++++++++------------------------- pubsub_test.go | 2 +- 2 files changed, 95 insertions(+), 99 deletions(-) diff --git a/pubsub.go b/pubsub.go index a02b2a3..e746d23 100644 --- a/pubsub.go +++ b/pubsub.go @@ -39,57 +39,21 @@ type PubSub struct { cmd *Cmd - size int - chOnce sync.Once - msgCh chan *Message - allCh chan interface{} - ping chan struct{} - - pingTimeout time.Duration - healthTimeout time.Duration + msgCh *channel + allCh *channel } -type PubSubOption func(c *PubSub) - -// WithChannelSize go-chan size(default 100). -func WithChannelSize(size int) PubSubOption { - return func(c *PubSub) { - c.size = size - } +func (c *PubSub) init() { + c.exit = make(chan struct{}) } -// WithPingTimeout health(ping) check interval(default: 1s). -func WithPingTimeout(d time.Duration) PubSubOption { - return func(c *PubSub) { - c.pingTimeout = d - } -} - -// WithHealthTimeout health check timeout, -// the maximum time to wait for a response after the ping command(default: 5s). -func WithHealthTimeout(d time.Duration) PubSubOption { - return func(c *PubSub) { - c.healthTimeout = d - } -} - -// --------------------------------- - func (c *PubSub) String() string { channels := mapKeys(c.channels) channels = append(channels, mapKeys(c.patterns)...) return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) } -func (c *PubSub) init() { - c.exit = make(chan struct{}) - - c.size = 100 - c.pingTimeout = time.Second - c.healthTimeout = 5 * time.Second -} - func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { c.mu.Lock() cn, err := c.conn(ctx, nil) @@ -450,6 +414,15 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { } } +func (c *PubSub) getContext() context.Context { + if c.cmd != nil { + return c.cmd.ctx + } + return context.Background() +} + +//------------------------------------------------------------------------------ + // Channel returns a Go channel for concurrently receiving messages. // The channel is closed together with the PubSub. If the Go channel // is blocked full for 30 seconds the message is dropped. @@ -458,27 +431,25 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { // go-redis periodically sends ping messages to test connection health // and re-subscribes if ping can not not received for 30 seconds. // Deprecated: use ChannelMessage(), remove in v9. -func (c *PubSub) Channel() <-chan *Message { - return c.ChannelSize(100) -} - -// ChannelSize is like Channel, but creates a Go channel -// with specified buffer size. -// Deprecated: use ChannelMessage(), remove in v9. -func (c *PubSub) ChannelSize(size int) <-chan *Message { +func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message { c.chOnce.Do(func() { - c.initPing() - c.initMsgChan(size) + c.msgCh = newChannel(c, opts...) + c.msgCh.initPing() + c.msgCh.initMsgChan() }) if c.msgCh == nil { err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions") panic(err) } - if cap(c.msgCh) != size { - err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") - panic(err) - } - return c.msgCh + return c.msgCh.msgCh +} + +// ChannelSize is like Channel, but creates a Go channel +// with specified buffer size. +// +// Deprecated: use Channel(WithChannelSize(size)), remove in v9. +func (c *PubSub) ChannelSize(size int) <-chan *Message { + return c.Channel(WithChannelSize(size)) } // ChannelWithSubscriptions is like Channel, but message type can be either @@ -489,48 +460,71 @@ func (c *PubSub) ChannelSize(size int) <-chan *Message { // Deprecated: use ChannelSubscriptionMessage(), remove in v9. func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} { c.chOnce.Do(func() { - c.initPing() - c.initAllChan(size) + c.allCh = newChannel(c, WithChannelSize(size)) + c.allCh.initPing() + c.allCh.initAllChan() }) if c.allCh == nil { err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel") panic(err) } - if cap(c.allCh) != size { - err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") - panic(err) - } - return c.allCh + return c.allCh.allCh } -// ChannelMessage replacement function of Channel(ChannelSize) in the future, -// allowing custom options. -func (c *PubSub) ChannelMessage(opts ...PubSubOption) <-chan *Message { +type ChannelOption func(c *channel) + +// WithChannelSize go-chan size(default 100). +func WithChannelSize(size int) ChannelOption { + return func(c *channel) { + c.chanSize = size + } +} + +// WithPingTimeout health(ping) check interval(default: 1s). +func WithPingTimeout(d time.Duration) ChannelOption { + return func(c *channel) { + c.pingTimeout = d + } +} + +// WithHealthTimeout health check timeout, +// the maximum time to wait for a response after the ping command(default: 5s). +func WithHealthTimeout(d time.Duration) ChannelOption { + return func(c *channel) { + c.healthTimeout = d + } +} + +type channel struct { + pubSub *PubSub + + msgCh chan *Message + allCh chan interface{} + ping chan struct{} + + chanSize int + pingTimeout time.Duration + healthTimeout time.Duration +} + +func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel { + c := &channel{ + pubSub: pubSub, + + chanSize: 100, + pingTimeout: time.Second, + healthTimeout: 5 * time.Second, + } for _, opt := range opts { opt(c) } - return c.ChannelSize(c.size) + return c } -// ChannelSubscriptionMessage replacement function of ChannelWithSubscriptions in the future, -// allowing custom options. -func (c *PubSub) ChannelSubscriptionMessage(opts ...PubSubOption) <-chan interface{} { - for _, opt := range opts { - opt(c) - } - return c.ChannelWithSubscriptions(context.Background(), c.size) -} - -func (c *PubSub) getContext() context.Context { - if c.cmd != nil { - return c.cmd.ctx - } - return context.Background() -} - -func (c *PubSub) initPing() { +func (c *channel) initPing() { ctx := context.TODO() c.ping = make(chan struct{}, 1) + go func() { timer := time.NewTimer(time.Minute) timer.Stop() @@ -550,23 +544,23 @@ func (c *PubSub) initPing() { var healthyErr error if healthy { - healthyErr = c.Ping(ctx) + healthyErr = c.pubSub.Ping(ctx) healthy = false } else { healthyErr = errPingTimeout } if healthyErr != nil { - c.mu.Lock() - c.reconnect(ctx, healthyErr) - c.mu.Unlock() + c.pubSub.mu.Lock() + c.pubSub.reconnect(ctx, healthyErr) + c.pubSub.mu.Unlock() healthy = true timeout = c.pingTimeout } else { timeout = c.healthTimeout } - case <-c.exit: + case <-c.pubSub.exit: return } } @@ -574,16 +568,17 @@ func (c *PubSub) initPing() { } // initMsgChan must be in sync with initAllChan. -func (c *PubSub) initMsgChan(size int) { +func (c *channel) initMsgChan() { ctx := context.TODO() - c.msgCh = make(chan *Message, size) + c.msgCh = make(chan *Message, c.chanSize) + go func() { timer := time.NewTimer(time.Minute) timer.Stop() var errCount int for { - msg, err := c.Receive(ctx) + msg, err := c.pubSub.Receive(ctx) if err != nil { if err == pool.ErrClosed { close(c.msgCh) @@ -618,30 +613,31 @@ func (c *PubSub) initMsgChan(size int) { } case <-timer.C: internal.Logger.Printf( - c.getContext(), + ctx, "redis: %s channel is full for %s (message is dropped)", c, chanSendTimeout, ) } default: - internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) + internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) } } }() } // initAllChan must be in sync with initMsgChan. -func (c *PubSub) initAllChan(size int) { +func (c *channel) initAllChan() { ctx := context.TODO() - c.allCh = make(chan interface{}, size) + c.allCh = make(chan interface{}, c.chanSize) + go func() { timer := time.NewTimer(time.Minute) timer.Stop() var errCount int for { - msg, err := c.Receive(ctx) + msg, err := c.pubSub.Receive(ctx) if err != nil { if err == pool.ErrClosed { close(c.allCh) @@ -674,13 +670,13 @@ func (c *PubSub) initAllChan(size int) { } case <-timer.C: internal.Logger.Printf( - c.getContext(), + ctx, "redis: %s channel is full for %s (message is dropped)", c, chanSendTimeout) } default: - internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) + internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) } } }() diff --git a/pubsub_test.go b/pubsub_test.go index 02ad2cc..4ce3f30 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -478,7 +478,7 @@ var _ = Describe("PubSub", func() { pubsub := client.Subscribe(ctx, "mychannel") defer pubsub.Close() - ch := pubsub.ChannelMessage( + ch := pubsub.Channel( redis.WithChannelSize(10), redis.WithPingTimeout(time.Second), redis.WithHealthTimeout(time.Minute),