diff --git a/pool.go b/pool.go index 764777b..b2e2d08 100644 --- a/pool.go +++ b/pool.go @@ -24,7 +24,7 @@ var ( type pool interface { First() *conn - Get() (*conn, bool, error) + Get() (*conn, error) Put(*conn) error Remove(*conn) error Len() int @@ -171,23 +171,59 @@ func (p *connPool) new() (*conn, error) { ) return nil, err } + cn, err := p.dial() if err != nil { p.lastDialErr = err + return nil, err } - return cn, err + + if err := p.init(cn); err != nil { + p.Remove(cn) + return nil, err + } + + return cn, nil +} + +// Initialize connection +func (p *connPool) init(cn *conn) error { + if p.opt.Password == "" && p.opt.DB == 0 { + return nil + } + + // Use connection to connect to redis + pool := newSingleConnPool(p, false) + pool.SetConn(cn) + + // Client is not closed because we want to reuse underlying connection. + client := newClient(p.opt, pool) + + if p.opt.Password != "" { + if err := client.Auth(p.opt.Password).Err(); err != nil { + return err + } + } + + if p.opt.DB > 0 { + if err := client.Select(p.opt.DB).Err(); err != nil { + return err + } + } + + return nil } // Get returns existed connection from the pool or creates a new one // if needed. -func (p *connPool) Get() (*conn, bool, error) { +func (p *connPool) Get() (*conn, error) { if p.isClosed() { - return nil, false, errClosed + return nil, errClosed } // Fetch first non-idle connection, if available if cn := p.First(); cn != nil { - return cn, false, nil + return cn, nil } // Try to create a new one @@ -195,18 +231,18 @@ func (p *connPool) Get() (*conn, bool, error) { cn, err := p.new() if err != nil { atomic.AddInt32(&p.size, -1) // Undo ref increment - return nil, false, err + return nil, err } - return cn, true, nil + return cn, nil } atomic.AddInt32(&p.size, -1) // Otherwise, wait for the available connection if cn := p.wait(p.opt.PoolTimeout); cn != nil { - return cn, false, nil + return cn, nil } - return nil, false, errPoolTimeout + return nil, errPoolTimeout } func (p *connPool) Put(cn *conn) error { @@ -300,24 +336,24 @@ func (p *singleConnPool) First() *conn { return p.cn } -func (p *singleConnPool) Get() (*conn, bool, error) { +func (p *singleConnPool) Get() (*conn, error) { defer p.cnMtx.Unlock() p.cnMtx.Lock() if p.closed { - return nil, false, errClosed + return nil, errClosed } if p.cn != nil { - return p.cn, false, nil + return p.cn, nil } - cn, isNew, err := p.pool.Get() + cn, err := p.pool.Get() if err != nil { - return nil, false, err + return nil, err } p.cn = cn - return p.cn, isNew, nil + return p.cn, nil } func (p *singleConnPool) Put(cn *conn) error { diff --git a/pool_test.go b/pool_test.go index 310730b..4b1bc94 100644 --- a/pool_test.go +++ b/pool_test.go @@ -106,7 +106,7 @@ var _ = Describe("Pool", func() { }) It("should remove broken connections", func() { - cn, _, err := client.Pool().Get() + cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) Expect(cn.Close()).NotTo(HaveOccurred()) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) @@ -140,12 +140,12 @@ var _ = Describe("Pool", func() { pool := client.Pool() // Reserve one connection. - cn, _, err := client.Pool().Get() + cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) // Reserve the rest of connections. for i := 0; i < 9; i++ { - _, _, err := client.Pool().Get() + _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) } @@ -190,7 +190,7 @@ func BenchmarkPool(b *testing.B) { pool := client.Pool() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - conn, _, err := pool.Get() + conn, err := pool.Get() if err != nil { b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) } diff --git a/redis.go b/redis.go index 3f44693..d03425c 100644 --- a/redis.go +++ b/redis.go @@ -12,45 +12,7 @@ type baseClient struct { } func (c *baseClient) conn() (*conn, error) { - cn, isNew, err := c.connPool.Get() - if err != nil { - return nil, err - } - - if isNew { - if err := c.initConn(cn); err != nil { - c.putConn(cn, err) - return nil, err - } - } - - return cn, nil -} - -func (c *baseClient) initConn(cn *conn) error { - if c.opt.Password == "" && c.opt.DB == 0 { - return nil - } - - pool := newSingleConnPool(c.connPool, false) - pool.SetConn(cn) - - // Client is not closed because we want to reuse underlying connection. - client := newClient(c.opt, pool) - - if c.opt.Password != "" { - if err := client.Auth(c.opt.Password).Err(); err != nil { - return err - } - } - - if c.opt.DB > 0 { - if err := client.Select(c.opt.DB).Err(); err != nil { - return err - } - } - - return nil + return c.connPool.Get() } func (c *baseClient) putConn(cn *conn, ei error) {