diff --git a/pubsub.go b/pubsub.go index 2cfcd15..f0fcb8a 100644 --- a/pubsub.go +++ b/pubsub.go @@ -113,17 +113,17 @@ func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) return writeCmd(cn, cmd) } -func (c *PubSub) releaseConn(cn *pool.Conn, err error) { +func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { c.mu.Lock() - c._releaseConn(cn, err) + c._releaseConn(cn, err, allowTimeout) c.mu.Unlock() } -func (c *PubSub) _releaseConn(cn *pool.Conn, err error) { +func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) { if c.cn != cn { return } - if internal.IsBadConn(err, true) { + if internal.IsBadConn(err, allowTimeout) { c._reconnect() } } @@ -137,12 +137,6 @@ func (c *PubSub) _closeTheCn() error { return err } -func (c *PubSub) reconnect() { - c.mu.Lock() - c._reconnect() - c.mu.Unlock() -} - func (c *PubSub) _reconnect() { _ = c._closeTheCn() _, _ = c._conn(nil) @@ -227,7 +221,7 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { } err = c._subscribe(cn, redisCmd, channels...) - c._releaseConn(cn, err) + c._releaseConn(cn, err, false) return err } @@ -245,7 +239,7 @@ func (c *PubSub) Ping(payload ...string) error { cn.SetWriteTimeout(c.opt.WriteTimeout) err = writeCmd(cn, cmd) - c.releaseConn(cn, err) + c.releaseConn(cn, err, false) return err } @@ -338,7 +332,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { cn.SetReadTimeout(timeout) err = c.cmd.readReply(cn) - c.releaseConn(cn, err) + c.releaseConn(cn, err, timeout > 0) if err != nil { return nil, err } @@ -446,7 +440,9 @@ func (c *PubSub) initChannel() { hasPing = false _ = c.Ping() } else { - c.reconnect() + c.mu.Lock() + c._reconnect() + c.mu.Unlock() } case <-c.exit: return