From 4e9cea88768f5a80aae8fbf1eeaaa6c27c724eb5 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 3 Aug 2019 17:21:12 +0300 Subject: [PATCH] Add proper SingleConnPool implementation --- cluster.go | 22 ++--- internal/error.go => error.go | 35 ++++--- example_test.go | 21 +++++ internal/pool/pool_single.go | 168 +++++++++++++++++++++++++++++----- pubsub.go | 2 +- redis.go | 29 +++--- ring.go | 4 +- 7 files changed, 212 insertions(+), 69 deletions(-) rename internal/error.go => error.go (65%) diff --git a/cluster.go b/cluster.go index 6d311f05..a784e28a 100644 --- a/cluster.go +++ b/cluster.go @@ -787,7 +787,7 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { } // If slave is loading - pick another node. - if c.opt.ReadOnly && internal.IsLoadingError(err) { + if c.opt.ReadOnly && isLoadingError(err) { node.MarkAsFailing() node = nil continue @@ -795,7 +795,7 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { var moved bool var addr string - moved, ask, addr = internal.IsMovedError(err) + moved, ask, addr = isMovedError(err) if moved || ask { node, err = c.nodes.Get(addr) if err != nil { @@ -804,12 +804,12 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { continue } - if err == pool.ErrClosed || internal.IsReadOnlyError(err) { + if err == pool.ErrClosed || isReadOnlyError(err) { node = nil continue } - if internal.IsRetryableError(err, true) { + if isRetryableError(err, true) { // First retry the same node. if attempt == 0 { continue @@ -1173,9 +1173,9 @@ func (c *ClusterClient) pipelineReadCmds( continue } - if c.opt.ReadOnly && internal.IsLoadingError(err) { + if c.opt.ReadOnly && isLoadingError(err) { node.MarkAsFailing() - } else if internal.IsRedisError(err) { + } else if isRedisError(err) { continue } @@ -1192,7 +1192,7 @@ func (c *ClusterClient) pipelineReadCmds( func (c *ClusterClient) checkMovedErr( cmd Cmder, err error, failedCmds *cmdsMap, ) bool { - moved, ask, addr := internal.IsMovedError(err) + moved, ask, addr := isMovedError(err) if moved { c.state.LazyReload() @@ -1346,7 +1346,7 @@ func (c *ClusterClient) txPipelineReadQueued( continue } - if c.checkMovedErr(cmd, err, failedCmds) || internal.IsRedisError(err) { + if c.checkMovedErr(cmd, err, failedCmds) || isRedisError(err) { continue } @@ -1418,7 +1418,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke c.state.LazyReload() } - moved, ask, addr := internal.IsMovedError(err) + moved, ask, addr := isMovedError(err) if moved || ask { node, err = c.nodes.Get(addr) if err != nil { @@ -1427,7 +1427,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke continue } - if err == pool.ErrClosed || internal.IsReadOnlyError(err) { + if err == pool.ErrClosed || isReadOnlyError(err) { node, err = c.slotMasterNode(slot) if err != nil { return err @@ -1435,7 +1435,7 @@ func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, ke continue } - if internal.IsRetryableError(err, true) { + if isRetryableError(err, true) { continue } diff --git a/internal/error.go b/error.go similarity index 65% rename from internal/error.go rename to error.go index a425574e..d9cfd760 100644 --- a/internal/error.go +++ b/error.go @@ -1,20 +1,18 @@ -package internal +package redis import ( "context" - "errors" "io" "net" "strings" + "github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/proto" ) -var ErrSingleConnPoolClosed = errors.New("redis: SingleConnPool is closed") - -func IsRetryableError(err error, retryTimeout bool) bool { +func isRetryableError(err error, retryTimeout bool) bool { switch err { - case nil, context.Canceled, context.DeadlineExceeded: + case nil, context.Canceled, context.DeadlineExceeded, pool.ErrBadConn: return false case io.EOF: return true @@ -25,9 +23,6 @@ func IsRetryableError(err error, retryTimeout bool) bool { } return true } - if err == ErrSingleConnPoolClosed { - return true - } s := err.Error() if s == "ERR max number of clients reached" { @@ -45,18 +40,20 @@ func IsRetryableError(err error, retryTimeout bool) bool { return false } -func IsRedisError(err error) bool { +func isRedisError(err error) bool { _, ok := err.(proto.RedisError) return ok } -func IsBadConn(err error, allowTimeout bool) bool { - if err == nil { +func isBadConn(err error, allowTimeout bool) bool { + switch err { + case nil: return false + case pool.ErrBadConn: + return true } - if IsRedisError(err) { - // #790 - return IsReadOnlyError(err) + if isRedisError(err) { + return isReadOnlyError(err) // #790 } if allowTimeout { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { @@ -66,8 +63,8 @@ func IsBadConn(err error, allowTimeout bool) bool { return true } -func IsMovedError(err error) (moved bool, ask bool, addr string) { - if !IsRedisError(err) { +func isMovedError(err error) (moved bool, ask bool, addr string) { + if !isRedisError(err) { return } @@ -89,10 +86,10 @@ func IsMovedError(err error) (moved bool, ask bool, addr string) { return } -func IsLoadingError(err error) bool { +func isLoadingError(err error) bool { return strings.HasPrefix(err.Error(), "LOADING ") } -func IsReadOnlyError(err error) bool { +func isReadOnlyError(err error) bool { return strings.HasPrefix(err.Error(), "READONLY ") } diff --git a/example_test.go b/example_test.go index 764232fc..d986a814 100644 --- a/example_test.go +++ b/example_test.go @@ -151,6 +151,27 @@ func ExampleClient() { // missing_key does not exist } +func ExampleConn() { + conn := redisdb.Conn() + + err := conn.ClientSetName("foobar").Err() + if err != nil { + panic(err) + } + + // Open other connections. + for i := 0; i < 10; i++ { + go redisdb.Ping() + } + + s, err := conn.ClientGetName().Result() + if err != nil { + panic(err) + } + fmt.Println(s) + // Output: foobar +} + func ExampleClient_Set() { // Last argument is expiration. Zero means the key has no // expiration time. diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 12542c6d..54d6c4e5 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -2,66 +2,184 @@ package pool import ( "context" - - "github.com/go-redis/redis/internal" + "fmt" + "sync/atomic" ) +const ( + stateDefault = 0 + stateInited = 1 + stateClosed = 2 +) + +var ErrBadConn = fmt.Errorf("pg: Conn is in a bad state") + type SingleConnPool struct { - cn *Conn - cnClosed bool + pool Pooler + + state uint32 // atomic + ch chan *Conn + + level int32 // atomic + _hasBadConn uint32 // atomic } var _ Pooler = (*SingleConnPool)(nil) -func NewSingleConnPool(cn *Conn) *SingleConnPool { - return &SingleConnPool{ - cn: cn, +func NewSingleConnPool(pool Pooler) *SingleConnPool { + p, ok := pool.(*SingleConnPool) + if !ok { + p = &SingleConnPool{ + pool: pool, + ch: make(chan *Conn, 1), + } + } + atomic.AddInt32(&p.level, 1) + return p +} + +func (p *SingleConnPool) SetConn(cn *Conn) { + if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { + p.ch <- cn + } else { + panic("not reached") } } -func (p *SingleConnPool) NewConn(context.Context) (*Conn, error) { - panic("not implemented") +func (p *SingleConnPool) NewConn(c context.Context) (*Conn, error) { + return p.pool.NewConn(c) } -func (p *SingleConnPool) CloseConn(*Conn) error { - panic("not implemented") +func (p *SingleConnPool) CloseConn(cn *Conn) error { + return p.pool.CloseConn(cn) } -func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { - if p.cnClosed { - return nil, internal.ErrSingleConnPoolClosed +func (p *SingleConnPool) Get(c context.Context) (*Conn, error) { + // In worst case this races with Close which is not a very common operation. + for i := 0; i < 1000; i++ { + switch atomic.LoadUint32(&p.state) { + case stateDefault: + cn, err := p.pool.Get(c) + if err != nil { + return nil, err + } + if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { + return cn, nil + } + p.pool.Remove(cn) + case stateInited: + if p.hasBadConn() { + return nil, ErrBadConn + } + cn, ok := <-p.ch + if !ok { + return nil, ErrClosed + } + return cn, nil + case stateClosed: + return nil, ErrClosed + default: + panic("not reached") + } } - return p.cn, nil + return nil, fmt.Errorf("pg: SingleConnPool.Get: infinite loop") } func (p *SingleConnPool) Put(cn *Conn) { - if p.cn != cn { - panic("p.cn != cn") + defer func() { + if recover() != nil { + p.freeConn(cn) + } + }() + p.ch <- cn +} + +func (p *SingleConnPool) freeConn(cn *Conn) { + if p.hasBadConn() { + p.pool.Remove(cn) + } else { + p.pool.Put(cn) } } func (p *SingleConnPool) Remove(cn *Conn) { - if p.cn != cn { - panic("p.cn != cn") - } - p.cnClosed = true + defer func() { + if recover() != nil { + p.pool.Remove(cn) + } + }() + atomic.StoreUint32(&p._hasBadConn, 1) + p.ch <- cn } func (p *SingleConnPool) Len() int { - if p.cnClosed { + switch atomic.LoadUint32(&p.state) { + case stateDefault: return 0 + case stateInited: + return 1 + case stateClosed: + return 0 + default: + panic("not reached") } - return 1 } func (p *SingleConnPool) IdleLen() int { - return 0 + return len(p.ch) } func (p *SingleConnPool) Stats() *Stats { - return nil + return &Stats{} } func (p *SingleConnPool) Close() error { + level := atomic.AddInt32(&p.level, -1) + if level > 0 { + return nil + } + + for i := 0; i < 1000; i++ { + state := atomic.LoadUint32(&p.state) + if state == stateClosed { + return ErrClosed + } + if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { + close(p.ch) + cn, ok := <-p.ch + if ok { + p.freeConn(cn) + } + return nil + } + } + + return fmt.Errorf("pg: SingleConnPool.Close: infinite loop") +} + +func (p *SingleConnPool) Reset() error { + if !atomic.CompareAndSwapUint32(&p._hasBadConn, 1, 0) { + return nil + } + + select { + case cn, ok := <-p.ch: + if !ok { + return ErrClosed + } + p.pool.Remove(cn) + default: + return fmt.Errorf("pg: SingleConnPool does not have a Conn") + } + + if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { + state := atomic.LoadUint32(&p.state) + return fmt.Errorf("pg: invalid SingleConnPool state: %d", state) + } + return nil } + +func (p *SingleConnPool) hasBadConn() bool { + return atomic.LoadUint32(&p._hasBadConn) == 1 +} diff --git a/pubsub.go b/pubsub.go index 139169f2..72226d35 100644 --- a/pubsub.go +++ b/pubsub.go @@ -142,7 +142,7 @@ func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { if c.cn != cn { return } - if internal.IsBadConn(err, allowTimeout) { + if isBadConn(err, allowTimeout) { c.reconnect(err) } } diff --git a/redis.go b/redis.go index e2bf5e91..279f0378 100644 --- a/redis.go +++ b/redis.go @@ -183,7 +183,7 @@ func (c *baseClient) releaseConn(cn *pool.Conn, err error) { c.limiter.ReportResult(err) } - if internal.IsBadConn(err, false) { + if isBadConn(err, false) { c.connPool.Remove(cn) } else { c.connPool.Put(cn) @@ -195,7 +195,7 @@ func (c *baseClient) releaseConnStrict(cn *pool.Conn, err error) { c.limiter.ReportResult(err) } - if err == nil || internal.IsRedisError(err) { + if err == nil || isRedisError(err) { c.connPool.Put(cn) } else { c.connPool.Remove(cn) @@ -215,7 +215,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return nil } - conn := newConn(ctx, c.opt, cn) + connPool := pool.NewSingleConnPool(nil) + connPool.SetConn(cn) + conn := newConn(ctx, c.opt, connPool) + _, err := conn.Pipelined(func(pipe Pipeliner) error { if c.opt.Password != "" { pipe.Auth(c.opt.Password) @@ -252,7 +255,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error { cn, err := c.getConn(ctx) if err != nil { cmd.setErr(err) - if internal.IsRetryableError(err, true) { + if isRetryableError(err, true) { continue } return err @@ -264,7 +267,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error { if err != nil { c.releaseConn(cn, err) cmd.setErr(err) - if internal.IsRetryableError(err, true) { + if isRetryableError(err, true) { continue } return err @@ -272,7 +275,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error { err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) c.releaseConn(cn, err) - if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) { + if err != nil && isRetryableError(err, cmd.readTimeout() == nil) { continue } @@ -347,7 +350,7 @@ func (c *baseClient) generalProcessPipeline( canRetry, err := p(ctx, cn, cmds) c.releaseConnStrict(cn, err) - if !canRetry || !internal.IsRetryableError(err, true) { + if !canRetry || !isRetryableError(err, true) { break } } @@ -374,7 +377,7 @@ func (c *baseClient) pipelineProcessCmds( func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { for _, cmd := range cmds { err := cmd.readReply(rd) - if err != nil && !internal.IsRedisError(err) { + if err != nil && !isRedisError(err) { return err } } @@ -421,7 +424,7 @@ func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { for range cmds { err = statusCmd.readReply(rd) - if err != nil && !internal.IsRedisError(err) { + if err != nil && !isRedisError(err) { return err } } @@ -500,6 +503,10 @@ func (c *Client) WithContext(ctx context.Context) *Client { return &clone } +func (c *Client) Conn() *Conn { + return newConn(c.ctx, c.opt, pool.NewSingleConnPool(c.connPool)) +} + // Do creates a Cmd from the args and processes the cmd. func (c *Client) Do(args ...interface{}) *Cmd { return c.DoContext(c.ctx, args...) @@ -643,12 +650,12 @@ type Conn struct { ctx context.Context } -func newConn(ctx context.Context, opt *Options, cn *pool.Conn) *Conn { +func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn { c := Conn{ conn: &conn{ baseClient: baseClient{ opt: opt, - connPool: pool.NewSingleConnPool(cn), + connPool: connPool, }, }, ctx: ctx, diff --git a/ring.go b/ring.go index 94ea77b6..30d87d24 100644 --- a/ring.go +++ b/ring.go @@ -568,7 +568,7 @@ func (c *Ring) process(ctx context.Context, cmd Cmder) error { if err == nil { return nil } - if !internal.IsRetryableError(err, cmd.readTimeout() == nil) { + if !isRetryableError(err, cmd.readTimeout() == nil) { return err } } @@ -662,7 +662,7 @@ func (c *Ring) generalProcessPipeline( } shard.Client.releaseConnStrict(cn, err) - if canRetry && internal.IsRetryableError(err, true) { + if canRetry && isRetryableError(err, true) { mu.Lock() if failedCmdsMap == nil { failedCmdsMap = make(map[string][]Cmder)