From 9ebd89772add2db0c136b8a634739de553984841 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Mon, 24 Apr 2017 12:43:15 +0300 Subject: [PATCH] Rework PubSub conn management --- cluster.go | 4 +- pubsub.go | 152 +++++++++++++++++++++++++---------------------------- redis.go | 10 ++-- ring.go | 2 +- 4 files changed, 80 insertions(+), 88 deletions(-) diff --git a/cluster.go b/cluster.go index 223d252..43cbf64 100644 --- a/cluster.go +++ b/cluster.go @@ -704,7 +704,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { } err = c.pipelineProcessCmds(cn, cmds, failedCmds) - node.Client.putConn(cn, err, false) + node.Client.putConn(cn, err) } if len(failedCmds) == 0 { @@ -840,7 +840,7 @@ func (c *ClusterClient) txPipelineExec(cmds []Cmder) error { } err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) - node.Client.putConn(cn, err, false) + node.Client.putConn(cn, err) } if len(failedCmds) == 0 { diff --git a/pubsub.go b/pubsub.go index e47978c..3680323 100644 --- a/pubsub.go +++ b/pubsub.go @@ -20,49 +20,14 @@ type PubSub struct { cn *pool.Conn closed bool - cmd *Cmd - subMu sync.Mutex channels []string patterns []string + + cmd *Cmd } -func (c *PubSub) conn() (*pool.Conn, error) { - cn, isNew, err := c._conn() - if err != nil { - return nil, err - } - - if isNew { - if err := c.resubscribe(); err != nil { - internal.Logf("resubscribe failed: %s", err) - } - } - - return cn, nil -} - -func (c *PubSub) resubscribe() error { - c.subMu.Lock() - channels := c.channels - patterns := c.patterns - c.subMu.Unlock() - - var firstErr error - if len(channels) > 0 { - if err := c.subscribe("subscribe", channels...); err != nil && firstErr == nil { - firstErr = err - } - } - if len(patterns) > 0 { - if err := c.subscribe("psubscribe", patterns...); err != nil && firstErr == nil { - firstErr = err - } - } - return firstErr -} - -func (c *PubSub) _conn() (*pool.Conn, bool, error) { +func (c *PubSub) conn() (*pool.Conn, bool, error) { c.mu.Lock() defer c.mu.Unlock() @@ -86,21 +51,81 @@ func (c *PubSub) _conn() (*pool.Conn, bool, error) { } } + if err := c.resubscribe(cn); err != nil { + return nil, false, err + } + c.cn = cn return cn, true, nil } -func (c *PubSub) putConn(cn *pool.Conn, err error) { - if internal.IsBadConn(err, true) { - c.mu.Lock() - if c.cn == cn { - _ = c.closeConn() +func (c *PubSub) resubscribe(cn *pool.Conn) error { + c.subMu.Lock() + defer c.subMu.Unlock() + + var firstErr error + if len(c.channels) > 0 { + if err := c._subscribe(cn, "subscribe", c.channels...); err != nil && firstErr == nil { + firstErr = err } - c.mu.Unlock() } + if len(c.patterns) > 0 { + if err := c._subscribe(cn, "psubscribe", c.patterns...); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +func (c *PubSub) putConn(cn *pool.Conn, err error) { + if !internal.IsBadConn(err, true) { + return + } + + c.mu.Lock() + if c.cn == cn { + _ = c.closeConn() + } + c.mu.Unlock() +} + +func (c *PubSub) closeConn() error { + err := c.base.connPool.CloseConn(c.cn) + c.cn = nil + return err +} + +func (c *PubSub) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return pool.ErrClosed + } + c.closed = true + + if c.cn != nil { + return c.closeConn() + } + return nil } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { + cn, isNew, err := c.conn() + if err != nil { + return err + } + + if isNew { + return nil + } + + err = c._subscribe(cn, redisCmd, channels...) + c.putConn(cn, err) + return err +} + +func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { args := make([]interface{}, 1+len(channels)) args[0] = redisCmd for i, channel := range channels { @@ -108,19 +133,8 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { } cmd := NewSliceCmd(args...) - cn, isNew, err := c._conn() - if err != nil { - return err - } - - if isNew { - return c.resubscribe() - } - cn.SetWriteTimeout(c.base.opt.WriteTimeout) - err = writeCmd(cn, cmd) - c.putConn(cn, err) - return err + return writeCmd(cn, cmd) } // Subscribes the client to the specified channels. @@ -157,28 +171,6 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { return c.subscribe("punsubscribe", patterns...) } -func (c *PubSub) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.closed { - return pool.ErrClosed - } - c.closed = true - - if c.cn != nil { - _ = c.closeConn() - } - - return nil -} - -func (c *PubSub) closeConn() error { - err := c.base.connPool.CloseConn(c.cn) - c.cn = nil - return err -} - func (c *PubSub) Ping(payload ...string) error { args := []interface{}{"ping"} if len(payload) == 1 { @@ -186,7 +178,7 @@ func (c *PubSub) Ping(payload ...string) error { } cmd := NewCmd(args...) - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { return err } @@ -279,7 +271,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { c.cmd = NewCmd() } - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { return nil, err } diff --git a/redis.go b/redis.go index ecf1fc0..b71b9fc 100644 --- a/redis.go +++ b/redis.go @@ -42,8 +42,8 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) { return cn, isNew, nil } -func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { - if internal.IsBadConn(err, allowTimeout) { +func (c *baseClient) putConn(cn *pool.Conn, err error) bool { + if internal.IsBadConn(err, false) { _ = c.connPool.Remove(cn) return false } @@ -104,7 +104,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmd); err != nil { - c.putConn(cn, err, false) + c.putConn(cn, err) cmd.setErr(err) if err != nil && internal.IsRetryableError(err) { continue @@ -114,7 +114,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetReadTimeout(c.cmdTimeout(cmd)) err = cmd.readReply(cn) - c.putConn(cn, err, false) + c.putConn(cn, err) if err != nil && internal.IsRetryableError(err) { continue } @@ -167,7 +167,7 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer { } canRetry, err := p(cn, cmds) - c.putConn(cn, err, false) + c.putConn(cn, err) if err == nil { return nil } diff --git a/ring.go b/ring.go index 69715b6..d13a33b 100644 --- a/ring.go +++ b/ring.go @@ -428,7 +428,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { } canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds) - shard.Client.putConn(cn, err, false) + shard.Client.putConn(cn, err) if err == nil { continue }