diff --git a/multi.go b/multi.go index d76af2f..e75fbc8 100644 --- a/multi.go +++ b/multi.go @@ -17,7 +17,7 @@ func (c *Client) Multi() *Multi { Client: &Client{ baseClient: &baseClient{ opt: c.opt, - connPool: newSingleConnPool(c.connPool, nil, true), + connPool: newSingleConnPool(c.connPool, true), }, }, } diff --git a/parser.go b/parser.go index fdaea65..125f891 100644 --- a/parser.go +++ b/parser.go @@ -17,7 +17,7 @@ var ( //------------------------------------------------------------------------------ -func appendCmd(buf []byte, args []string) []byte { +func appendArgs(buf []byte, args []string) []byte { buf = append(buf, '*') buf = strconv.AppendUint(buf, uint64(len(args)), 10) buf = append(buf, '\r', '\n') diff --git a/pool.go b/pool.go index 4ecb228..6fe4538 100644 --- a/pool.go +++ b/pool.go @@ -162,7 +162,7 @@ func (p *connPool) Get() (*conn, bool, error) { } if p.conns.Len() < p.opt.PoolSize { - cn, err := p.dial() + cn, err := p.new() if err != nil { p.cond.L.Unlock() return nil, false, err @@ -277,60 +277,68 @@ func (p *connPool) Close() error { type singleConnPool struct { pool pool - l sync.RWMutex - cn *conn + cnMtx sync.Mutex + cn *conn + reusable bool closed bool } -func newSingleConnPool(pool pool, cn *conn, reusable bool) *singleConnPool { +func newSingleConnPool(pool pool, reusable bool) *singleConnPool { return &singleConnPool{ pool: pool, - cn: cn, reusable: reusable, } } +func (p *singleConnPool) SetConn(cn *conn) { + p.cnMtx.Lock() + p.cn = cn + p.cnMtx.Unlock() +} + func (p *singleConnPool) Get() (*conn, bool, error) { - p.l.RLock() + defer p.cnMtx.Unlock() + p.cnMtx.Lock() + if p.closed { - p.l.RUnlock() return nil, false, errClosed } if p.cn != nil { - p.l.RUnlock() return p.cn, false, nil } - p.l.RUnlock() - p.l.Lock() cn, isNew, err := p.pool.Get() if err != nil { - p.l.Unlock() return nil, false, err } p.cn = cn - p.l.Unlock() - return cn, isNew, nil + + return p.cn, isNew, nil } func (p *singleConnPool) Put(cn *conn) error { - p.l.Lock() + defer p.cnMtx.Unlock() + p.cnMtx.Lock() if p.cn != cn { panic("p.cn != cn") } if p.closed { - p.l.Unlock() return errClosed } - p.l.Unlock() return nil } +func (p *singleConnPool) put() error { + err := p.pool.Put(p.cn) + p.cn = nil + return err +} + func (p *singleConnPool) Remove(cn *conn) error { - defer p.l.Unlock() - p.l.Lock() + defer p.cnMtx.Unlock() + p.cnMtx.Lock() if p.cn == nil { panic("p.cn == nil") } @@ -350,8 +358,8 @@ func (p *singleConnPool) remove() error { } func (p *singleConnPool) Len() int { - defer p.l.Unlock() - p.l.Lock() + defer p.cnMtx.Unlock() + p.cnMtx.Lock() if p.cn == nil { return 0 } @@ -359,8 +367,8 @@ func (p *singleConnPool) Len() int { } func (p *singleConnPool) Size() int { - defer p.l.Unlock() - p.l.Lock() + defer p.cnMtx.Unlock() + p.cnMtx.Lock() if p.cn == nil { return 0 } @@ -368,18 +376,18 @@ func (p *singleConnPool) Size() int { } func (p *singleConnPool) Filter(f func(*conn) bool) { - p.l.Lock() + p.cnMtx.Lock() if p.cn != nil { if !f(p.cn) { p.remove() } } - p.l.Unlock() + p.cnMtx.Unlock() } func (p *singleConnPool) Close() error { - defer p.l.Unlock() - p.l.Lock() + defer p.cnMtx.Unlock() + p.cnMtx.Lock() if p.closed { return nil } @@ -387,11 +395,10 @@ func (p *singleConnPool) Close() error { var err error if p.cn != nil { if p.reusable { - err = p.pool.Put(p.cn) + err = p.put() } else { - err = p.pool.Remove(p.cn) + err = p.remove() } } - p.cn = nil return err } diff --git a/pubsub.go b/pubsub.go index bc69fb1..6ac130b 100644 --- a/pubsub.go +++ b/pubsub.go @@ -14,7 +14,7 @@ func (c *Client) PubSub() *PubSub { return &PubSub{ baseClient: &baseClient{ opt: c.opt, - connPool: newSingleConnPool(c.connPool, nil, false), + connPool: newSingleConnPool(c.connPool, false), }, } } diff --git a/redis.go b/redis.go index 8ebab47..c124bf7 100644 --- a/redis.go +++ b/redis.go @@ -14,9 +14,9 @@ type baseClient struct { } func (c *baseClient) writeCmd(cn *conn, cmds ...Cmder) error { - buf := make([]byte, 0, 1000) + buf := make([]byte, 0, 64) for _, cmd := range cmds { - buf = appendCmd(buf, cmd.args()) + buf = appendArgs(buf, cmd.args()) } _, err := cn.Write(buf) @@ -29,8 +29,8 @@ func (c *baseClient) conn() (*conn, error) { return nil, err } - if isNew && (c.opt.Password != "" || c.opt.DB > 0) { - if err = c.init(cn, c.opt.Password, c.opt.DB); err != nil { + if isNew { + if err := c.initConn(cn); err != nil { c.removeConn(cn) return nil, err } @@ -39,26 +39,31 @@ func (c *baseClient) conn() (*conn, error) { return cn, nil } -func (c *baseClient) init(cn *conn, password string, db int64) error { - // Client is not closed on purpose. +func (c *baseClient) initConn(cn *conn) error { + if c.opt.Password == "" || c.opt.DB == 0 { + return nil + } + + pool := newSingleConnPool(c.connPool, false) + pool.SetConn(cn) + + // Client is not closed because we want to reuse underlying connection. client := &Client{ baseClient: &baseClient{ opt: c.opt, - connPool: newSingleConnPool(c.connPool, cn, false), + connPool: pool, }, } - if password != "" { - auth := client.Auth(password) - if auth.Err() != nil { - return auth.Err() + if c.opt.Password != "" { + if err := client.Auth(c.opt.Password).Err(); err != nil { + return err } } - if db > 0 { - sel := client.Select(db) - if sel.Err() != nil { - return sel.Err() + if c.opt.DB > 0 { + if err := client.Select(c.opt.DB).Err(); err != nil { + return err } } @@ -102,14 +107,16 @@ func (c *baseClient) run(cmd Cmder) { return } - cn.writeTimeout = c.opt.WriteTimeout if timeout := cmd.writeTimeout(); timeout != nil { cn.writeTimeout = *timeout + } else { + cn.writeTimeout = c.opt.WriteTimeout } - cn.readTimeout = c.opt.ReadTimeout if timeout := cmd.readTimeout(); timeout != nil { cn.readTimeout = *timeout + } else { + cn.readTimeout = c.opt.ReadTimeout } if err := c.writeCmd(cn, cmd); err != nil { diff --git a/sentinel.go b/sentinel.go index ca0499f..d118da1 100644 --- a/sentinel.go +++ b/sentinel.go @@ -94,7 +94,7 @@ func (c *sentinelClient) PubSub() *PubSub { return &PubSub{ baseClient: &baseClient{ opt: c.opt, - connPool: newSingleConnPool(c.connPool, nil, false), + connPool: newSingleConnPool(c.connPool, false), }, } }