Extract code to channel struct and tweak API

This commit is contained in:
Vladimir Mihailenco 2021-05-25 14:38:40 +03:00
parent 336824d981
commit 1d38942c5f
2 changed files with 95 additions and 99 deletions

192
pubsub.go
View File

@ -39,57 +39,21 @@ type PubSub struct {
cmd *Cmd cmd *Cmd
size int
chOnce sync.Once chOnce sync.Once
msgCh chan *Message msgCh *channel
allCh chan interface{} allCh *channel
ping chan struct{}
pingTimeout time.Duration
healthTimeout time.Duration
} }
type PubSubOption func(c *PubSub) func (c *PubSub) init() {
c.exit = make(chan struct{})
// WithChannelSize go-chan size(default 100).
func WithChannelSize(size int) PubSubOption {
return func(c *PubSub) {
c.size = size
}
} }
// WithPingTimeout health(ping) check interval(default: 1s).
func WithPingTimeout(d time.Duration) PubSubOption {
return func(c *PubSub) {
c.pingTimeout = d
}
}
// WithHealthTimeout health check timeout,
// the maximum time to wait for a response after the ping command(default: 5s).
func WithHealthTimeout(d time.Duration) PubSubOption {
return func(c *PubSub) {
c.healthTimeout = d
}
}
// ---------------------------------
func (c *PubSub) String() string { func (c *PubSub) String() string {
channels := mapKeys(c.channels) channels := mapKeys(c.channels)
channels = append(channels, mapKeys(c.patterns)...) channels = append(channels, mapKeys(c.patterns)...)
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
} }
func (c *PubSub) init() {
c.exit = make(chan struct{})
c.size = 100
c.pingTimeout = time.Second
c.healthTimeout = 5 * time.Second
}
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
c.mu.Lock() c.mu.Lock()
cn, err := c.conn(ctx, nil) cn, err := c.conn(ctx, nil)
@ -450,6 +414,15 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
} }
} }
func (c *PubSub) getContext() context.Context {
if c.cmd != nil {
return c.cmd.ctx
}
return context.Background()
}
//------------------------------------------------------------------------------
// Channel returns a Go channel for concurrently receiving messages. // Channel returns a Go channel for concurrently receiving messages.
// The channel is closed together with the PubSub. If the Go channel // The channel is closed together with the PubSub. If the Go channel
// is blocked full for 30 seconds the message is dropped. // is blocked full for 30 seconds the message is dropped.
@ -458,27 +431,25 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
// go-redis periodically sends ping messages to test connection health // go-redis periodically sends ping messages to test connection health
// and re-subscribes if ping can not not received for 30 seconds. // and re-subscribes if ping can not not received for 30 seconds.
// Deprecated: use ChannelMessage(), remove in v9. // Deprecated: use ChannelMessage(), remove in v9.
func (c *PubSub) Channel() <-chan *Message { func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
return c.ChannelSize(100)
}
// ChannelSize is like Channel, but creates a Go channel
// with specified buffer size.
// Deprecated: use ChannelMessage(), remove in v9.
func (c *PubSub) ChannelSize(size int) <-chan *Message {
c.chOnce.Do(func() { c.chOnce.Do(func() {
c.initPing() c.msgCh = newChannel(c, opts...)
c.initMsgChan(size) c.msgCh.initPing()
c.msgCh.initMsgChan()
}) })
if c.msgCh == nil { if c.msgCh == nil {
err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions") err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
panic(err) panic(err)
} }
if cap(c.msgCh) != size { return c.msgCh.msgCh
err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") }
panic(err)
} // ChannelSize is like Channel, but creates a Go channel
return c.msgCh // with specified buffer size.
//
// Deprecated: use Channel(WithChannelSize(size)), remove in v9.
func (c *PubSub) ChannelSize(size int) <-chan *Message {
return c.Channel(WithChannelSize(size))
} }
// ChannelWithSubscriptions is like Channel, but message type can be either // ChannelWithSubscriptions is like Channel, but message type can be either
@ -489,48 +460,71 @@ func (c *PubSub) ChannelSize(size int) <-chan *Message {
// Deprecated: use ChannelSubscriptionMessage(), remove in v9. // Deprecated: use ChannelSubscriptionMessage(), remove in v9.
func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} { func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} {
c.chOnce.Do(func() { c.chOnce.Do(func() {
c.initPing() c.allCh = newChannel(c, WithChannelSize(size))
c.initAllChan(size) c.allCh.initPing()
c.allCh.initAllChan()
}) })
if c.allCh == nil { if c.allCh == nil {
err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel") err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
panic(err) panic(err)
} }
if cap(c.allCh) != size { return c.allCh.allCh
err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
panic(err)
}
return c.allCh
} }
// ChannelMessage replacement function of Channel(ChannelSize) in the future, type ChannelOption func(c *channel)
// allowing custom options.
func (c *PubSub) ChannelMessage(opts ...PubSubOption) <-chan *Message { // WithChannelSize go-chan size(default 100).
func WithChannelSize(size int) ChannelOption {
return func(c *channel) {
c.chanSize = size
}
}
// WithPingTimeout health(ping) check interval(default: 1s).
func WithPingTimeout(d time.Duration) ChannelOption {
return func(c *channel) {
c.pingTimeout = d
}
}
// WithHealthTimeout health check timeout,
// the maximum time to wait for a response after the ping command(default: 5s).
func WithHealthTimeout(d time.Duration) ChannelOption {
return func(c *channel) {
c.healthTimeout = d
}
}
type channel struct {
pubSub *PubSub
msgCh chan *Message
allCh chan interface{}
ping chan struct{}
chanSize int
pingTimeout time.Duration
healthTimeout time.Duration
}
func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
c := &channel{
pubSub: pubSub,
chanSize: 100,
pingTimeout: time.Second,
healthTimeout: 5 * time.Second,
}
for _, opt := range opts { for _, opt := range opts {
opt(c) opt(c)
} }
return c.ChannelSize(c.size) return c
} }
// ChannelSubscriptionMessage replacement function of ChannelWithSubscriptions in the future, func (c *channel) initPing() {
// allowing custom options.
func (c *PubSub) ChannelSubscriptionMessage(opts ...PubSubOption) <-chan interface{} {
for _, opt := range opts {
opt(c)
}
return c.ChannelWithSubscriptions(context.Background(), c.size)
}
func (c *PubSub) getContext() context.Context {
if c.cmd != nil {
return c.cmd.ctx
}
return context.Background()
}
func (c *PubSub) initPing() {
ctx := context.TODO() ctx := context.TODO()
c.ping = make(chan struct{}, 1) c.ping = make(chan struct{}, 1)
go func() { go func() {
timer := time.NewTimer(time.Minute) timer := time.NewTimer(time.Minute)
timer.Stop() timer.Stop()
@ -550,23 +544,23 @@ func (c *PubSub) initPing() {
var healthyErr error var healthyErr error
if healthy { if healthy {
healthyErr = c.Ping(ctx) healthyErr = c.pubSub.Ping(ctx)
healthy = false healthy = false
} else { } else {
healthyErr = errPingTimeout healthyErr = errPingTimeout
} }
if healthyErr != nil { if healthyErr != nil {
c.mu.Lock() c.pubSub.mu.Lock()
c.reconnect(ctx, healthyErr) c.pubSub.reconnect(ctx, healthyErr)
c.mu.Unlock() c.pubSub.mu.Unlock()
healthy = true healthy = true
timeout = c.pingTimeout timeout = c.pingTimeout
} else { } else {
timeout = c.healthTimeout timeout = c.healthTimeout
} }
case <-c.exit: case <-c.pubSub.exit:
return return
} }
} }
@ -574,16 +568,17 @@ func (c *PubSub) initPing() {
} }
// initMsgChan must be in sync with initAllChan. // initMsgChan must be in sync with initAllChan.
func (c *PubSub) initMsgChan(size int) { func (c *channel) initMsgChan() {
ctx := context.TODO() ctx := context.TODO()
c.msgCh = make(chan *Message, size) c.msgCh = make(chan *Message, c.chanSize)
go func() { go func() {
timer := time.NewTimer(time.Minute) timer := time.NewTimer(time.Minute)
timer.Stop() timer.Stop()
var errCount int var errCount int
for { for {
msg, err := c.Receive(ctx) msg, err := c.pubSub.Receive(ctx)
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
close(c.msgCh) close(c.msgCh)
@ -618,30 +613,31 @@ func (c *PubSub) initMsgChan(size int) {
} }
case <-timer.C: case <-timer.C:
internal.Logger.Printf( internal.Logger.Printf(
c.getContext(), ctx,
"redis: %s channel is full for %s (message is dropped)", "redis: %s channel is full for %s (message is dropped)",
c, c,
chanSendTimeout, chanSendTimeout,
) )
} }
default: default:
internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
} }
} }
}() }()
} }
// initAllChan must be in sync with initMsgChan. // initAllChan must be in sync with initMsgChan.
func (c *PubSub) initAllChan(size int) { func (c *channel) initAllChan() {
ctx := context.TODO() ctx := context.TODO()
c.allCh = make(chan interface{}, size) c.allCh = make(chan interface{}, c.chanSize)
go func() { go func() {
timer := time.NewTimer(time.Minute) timer := time.NewTimer(time.Minute)
timer.Stop() timer.Stop()
var errCount int var errCount int
for { for {
msg, err := c.Receive(ctx) msg, err := c.pubSub.Receive(ctx)
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
close(c.allCh) close(c.allCh)
@ -674,13 +670,13 @@ func (c *PubSub) initAllChan(size int) {
} }
case <-timer.C: case <-timer.C:
internal.Logger.Printf( internal.Logger.Printf(
c.getContext(), ctx,
"redis: %s channel is full for %s (message is dropped)", "redis: %s channel is full for %s (message is dropped)",
c, c,
chanSendTimeout) chanSendTimeout)
} }
default: default:
internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
} }
} }
}() }()

View File

@ -478,7 +478,7 @@ var _ = Describe("PubSub", func() {
pubsub := client.Subscribe(ctx, "mychannel") pubsub := client.Subscribe(ctx, "mychannel")
defer pubsub.Close() defer pubsub.Close()
ch := pubsub.ChannelMessage( ch := pubsub.Channel(
redis.WithChannelSize(10), redis.WithChannelSize(10),
redis.WithPingTimeout(time.Second), redis.WithPingTimeout(time.Second),
redis.WithHealthTimeout(time.Minute), redis.WithHealthTimeout(time.Minute),