Rework PubSub conn management

This commit is contained in:
Vladimir Mihailenco 2017-04-24 12:43:15 +03:00
parent 2528e7a712
commit 9ebd89772a
4 changed files with 80 additions and 88 deletions

View File

@ -704,7 +704,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
} }
err = c.pipelineProcessCmds(cn, cmds, failedCmds) err = c.pipelineProcessCmds(cn, cmds, failedCmds)
node.Client.putConn(cn, err, false) node.Client.putConn(cn, err)
} }
if len(failedCmds) == 0 { if len(failedCmds) == 0 {
@ -840,7 +840,7 @@ func (c *ClusterClient) txPipelineExec(cmds []Cmder) error {
} }
err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds)
node.Client.putConn(cn, err, false) node.Client.putConn(cn, err)
} }
if len(failedCmds) == 0 { if len(failedCmds) == 0 {

152
pubsub.go
View File

@ -20,49 +20,14 @@ type PubSub struct {
cn *pool.Conn cn *pool.Conn
closed bool closed bool
cmd *Cmd
subMu sync.Mutex subMu sync.Mutex
channels []string channels []string
patterns []string patterns []string
cmd *Cmd
} }
func (c *PubSub) conn() (*pool.Conn, error) { func (c *PubSub) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c._conn()
if err != nil {
return nil, err
}
if isNew {
if err := c.resubscribe(); err != nil {
internal.Logf("resubscribe failed: %s", err)
}
}
return cn, nil
}
func (c *PubSub) resubscribe() error {
c.subMu.Lock()
channels := c.channels
patterns := c.patterns
c.subMu.Unlock()
var firstErr error
if len(channels) > 0 {
if err := c.subscribe("subscribe", channels...); err != nil && firstErr == nil {
firstErr = err
}
}
if len(patterns) > 0 {
if err := c.subscribe("psubscribe", patterns...); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
func (c *PubSub) _conn() (*pool.Conn, bool, error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -86,21 +51,81 @@ func (c *PubSub) _conn() (*pool.Conn, bool, error) {
} }
} }
if err := c.resubscribe(cn); err != nil {
return nil, false, err
}
c.cn = cn c.cn = cn
return cn, true, nil return cn, true, nil
} }
func (c *PubSub) putConn(cn *pool.Conn, err error) { func (c *PubSub) resubscribe(cn *pool.Conn) error {
if internal.IsBadConn(err, true) { c.subMu.Lock()
c.mu.Lock() defer c.subMu.Unlock()
if c.cn == cn {
_ = c.closeConn() var firstErr error
if len(c.channels) > 0 {
if err := c._subscribe(cn, "subscribe", c.channels...); err != nil && firstErr == nil {
firstErr = err
} }
c.mu.Unlock()
} }
if len(c.patterns) > 0 {
if err := c._subscribe(cn, "psubscribe", c.patterns...); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
func (c *PubSub) putConn(cn *pool.Conn, err error) {
if !internal.IsBadConn(err, true) {
return
}
c.mu.Lock()
if c.cn == cn {
_ = c.closeConn()
}
c.mu.Unlock()
}
func (c *PubSub) closeConn() error {
err := c.base.connPool.CloseConn(c.cn)
c.cn = nil
return err
}
func (c *PubSub) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return pool.ErrClosed
}
c.closed = true
if c.cn != nil {
return c.closeConn()
}
return nil
} }
func (c *PubSub) subscribe(redisCmd string, channels ...string) error { func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, isNew, err := c.conn()
if err != nil {
return err
}
if isNew {
return nil
}
err = c._subscribe(cn, redisCmd, channels...)
c.putConn(cn, err)
return err
}
func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error {
args := make([]interface{}, 1+len(channels)) args := make([]interface{}, 1+len(channels))
args[0] = redisCmd args[0] = redisCmd
for i, channel := range channels { for i, channel := range channels {
@ -108,19 +133,8 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
} }
cmd := NewSliceCmd(args...) cmd := NewSliceCmd(args...)
cn, isNew, err := c._conn()
if err != nil {
return err
}
if isNew {
return c.resubscribe()
}
cn.SetWriteTimeout(c.base.opt.WriteTimeout) cn.SetWriteTimeout(c.base.opt.WriteTimeout)
err = writeCmd(cn, cmd) return writeCmd(cn, cmd)
c.putConn(cn, err)
return err
} }
// Subscribes the client to the specified channels. // Subscribes the client to the specified channels.
@ -157,28 +171,6 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error {
return c.subscribe("punsubscribe", patterns...) return c.subscribe("punsubscribe", patterns...)
} }
func (c *PubSub) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return pool.ErrClosed
}
c.closed = true
if c.cn != nil {
_ = c.closeConn()
}
return nil
}
func (c *PubSub) closeConn() error {
err := c.base.connPool.CloseConn(c.cn)
c.cn = nil
return err
}
func (c *PubSub) Ping(payload ...string) error { func (c *PubSub) Ping(payload ...string) error {
args := []interface{}{"ping"} args := []interface{}{"ping"}
if len(payload) == 1 { if len(payload) == 1 {
@ -186,7 +178,7 @@ func (c *PubSub) Ping(payload ...string) error {
} }
cmd := NewCmd(args...) cmd := NewCmd(args...)
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
return err return err
} }
@ -279,7 +271,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
c.cmd = NewCmd() c.cmd = NewCmd()
} }
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -42,8 +42,8 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) {
return cn, isNew, nil return cn, isNew, nil
} }
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { func (c *baseClient) putConn(cn *pool.Conn, err error) bool {
if internal.IsBadConn(err, allowTimeout) { if internal.IsBadConn(err, false) {
_ = c.connPool.Remove(cn) _ = c.connPool.Remove(cn)
return false return false
} }
@ -104,7 +104,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
cn.SetWriteTimeout(c.opt.WriteTimeout) cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmd); err != nil { if err := writeCmd(cn, cmd); err != nil {
c.putConn(cn, err, false) c.putConn(cn, err)
cmd.setErr(err) cmd.setErr(err)
if err != nil && internal.IsRetryableError(err) { if err != nil && internal.IsRetryableError(err) {
continue continue
@ -114,7 +114,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error {
cn.SetReadTimeout(c.cmdTimeout(cmd)) cn.SetReadTimeout(c.cmdTimeout(cmd))
err = cmd.readReply(cn) err = cmd.readReply(cn)
c.putConn(cn, err, false) c.putConn(cn, err)
if err != nil && internal.IsRetryableError(err) { if err != nil && internal.IsRetryableError(err) {
continue continue
} }
@ -167,7 +167,7 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer {
} }
canRetry, err := p(cn, cmds) canRetry, err := p(cn, cmds)
c.putConn(cn, err, false) c.putConn(cn, err)
if err == nil { if err == nil {
return nil return nil
} }

View File

@ -428,7 +428,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
} }
canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds) canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds)
shard.Client.putConn(cn, err, false) shard.Client.putConn(cn, err)
if err == nil { if err == nil {
continue continue
} }