From 191f839e81e7a2885c7f91b5799be1a03d9cab97 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Mon, 17 Apr 2017 16:05:01 +0300 Subject: [PATCH] Fix race between Subscribe and resubscribe --- options.go | 2 +- pubsub.go | 52 +++++++++++++++++++++++++++++++++++----------------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/options.go b/options.go index 8c30a67..d2aefb4 100644 --- a/options.go +++ b/options.go @@ -47,7 +47,7 @@ type Options struct { WriteTimeout time.Duration // Maximum number of socket connections. - // Default is 100 connections. + // Default is 10 connections. PoolSize int // Amount of time client waits for connection if all connections // are busy before returning an error. diff --git a/pubsub.go b/pubsub.go index 497b27c..a581556 100644 --- a/pubsub.go +++ b/pubsub.go @@ -22,6 +22,7 @@ type PubSub struct { cmd *Cmd + subMu sync.Mutex channels []string patterns []string } @@ -33,23 +34,32 @@ func (c *PubSub) conn() (*pool.Conn, error) { } if isNew { - c.resubscribe() + if err := c.resubscribe(); err != nil { + internal.Logf("resubscribe failed: %s", err) + } } return cn, nil } -func (c *PubSub) resubscribe() { - if len(c.channels) > 0 { - if err := c.subscribe("subscribe", c.channels...); err != nil { - internal.Logf("Subscribe failed: %s", err) +func (c *PubSub) resubscribe() error { + c.subMu.Lock() + channels := c.channels + patterns := c.patterns + c.subMu.Unlock() + + var firstErr error + if len(channels) > 0 { + if err := c.subscribe("subscribe", channels...); err != nil && firstErr == nil { + firstErr = err } } - if len(c.patterns) > 0 { - if err := c.subscribe("psubscribe", c.patterns...); err != nil { - internal.Logf("PSubscribe failed: %s", err) + if len(patterns) > 0 { + if err := c.subscribe("psubscribe", patterns...); err != nil && firstErr == nil { + firstErr = err } } + return firstErr } func (c *PubSub) _conn() (*pool.Conn, bool, error) { @@ -91,11 +101,15 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { } cmd := NewSliceCmd(args...) - cn, err := c.conn() + cn, isNew, err := c._conn() if err != nil { return err } + if isNew { + return c.resubscribe() + } + cn.SetWriteTimeout(c.base.opt.WriteTimeout) err = writeCmd(cn, cmd) c.putConn(cn, err) @@ -104,32 +118,36 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { // Subscribes the client to the specified channels. func (c *PubSub) Subscribe(channels ...string) error { - err := c.subscribe("subscribe", channels...) + c.subMu.Lock() c.channels = appendIfNotExists(c.channels, channels...) - return err + c.subMu.Unlock() + return c.subscribe("subscribe", channels...) } // Subscribes the client to the given patterns. func (c *PubSub) PSubscribe(patterns ...string) error { - err := c.subscribe("psubscribe", patterns...) + c.subMu.Lock() c.patterns = appendIfNotExists(c.patterns, patterns...) - return err + c.subMu.Unlock() + return c.subscribe("psubscribe", patterns...) } // Unsubscribes the client from the given channels, or from all of // them if none is given. func (c *PubSub) Unsubscribe(channels ...string) error { - err := c.subscribe("unsubscribe", channels...) + c.subMu.Lock() c.channels = remove(c.channels, channels...) - return err + c.subMu.Unlock() + return c.subscribe("unsubscribe", channels...) } // Unsubscribes the client from the given patterns, or from all of // them if none is given. func (c *PubSub) PUnsubscribe(patterns ...string) error { - err := c.subscribe("punsubscribe", patterns...) + c.subMu.Lock() c.patterns = remove(c.patterns, patterns...) - return err + c.subMu.Unlock() + return c.subscribe("punsubscribe", patterns...) } func (c *PubSub) Close() error {