Ensure all new connections are initialised

This commit is contained in:
Dimitrij Denissenko 2015-04-22 08:32:54 +01:00
parent a5891da2f6
commit 750d14fe21
3 changed files with 56 additions and 58 deletions

66
pool.go
View File

@ -24,7 +24,7 @@ var (
type pool interface { type pool interface {
First() *conn First() *conn
Get() (*conn, bool, error) Get() (*conn, error)
Put(*conn) error Put(*conn) error
Remove(*conn) error Remove(*conn) error
Len() int Len() int
@ -171,23 +171,59 @@ func (p *connPool) new() (*conn, error) {
) )
return nil, err return nil, err
} }
cn, err := p.dial() cn, err := p.dial()
if err != nil { if err != nil {
p.lastDialErr = err p.lastDialErr = err
return nil, err
} }
return cn, err
if err := p.init(cn); err != nil {
p.Remove(cn)
return nil, err
}
return cn, nil
}
// Initialize connection
func (p *connPool) init(cn *conn) error {
if p.opt.Password == "" && p.opt.DB == 0 {
return nil
}
// Use connection to connect to redis
pool := newSingleConnPool(p, false)
pool.SetConn(cn)
// Client is not closed because we want to reuse underlying connection.
client := newClient(p.opt, pool)
if p.opt.Password != "" {
if err := client.Auth(p.opt.Password).Err(); err != nil {
return err
}
}
if p.opt.DB > 0 {
if err := client.Select(p.opt.DB).Err(); err != nil {
return err
}
}
return nil
} }
// Get returns existed connection from the pool or creates a new one // Get returns existed connection from the pool or creates a new one
// if needed. // if needed.
func (p *connPool) Get() (*conn, bool, error) { func (p *connPool) Get() (*conn, error) {
if p.isClosed() { if p.isClosed() {
return nil, false, errClosed return nil, errClosed
} }
// Fetch first non-idle connection, if available // Fetch first non-idle connection, if available
if cn := p.First(); cn != nil { if cn := p.First(); cn != nil {
return cn, false, nil return cn, nil
} }
// Try to create a new one // Try to create a new one
@ -195,18 +231,18 @@ func (p *connPool) Get() (*conn, bool, error) {
cn, err := p.new() cn, err := p.new()
if err != nil { if err != nil {
atomic.AddInt32(&p.size, -1) // Undo ref increment atomic.AddInt32(&p.size, -1) // Undo ref increment
return nil, false, err return nil, err
} }
return cn, true, nil return cn, nil
} }
atomic.AddInt32(&p.size, -1) atomic.AddInt32(&p.size, -1)
// Otherwise, wait for the available connection // Otherwise, wait for the available connection
if cn := p.wait(p.opt.PoolTimeout); cn != nil { if cn := p.wait(p.opt.PoolTimeout); cn != nil {
return cn, false, nil return cn, nil
} }
return nil, false, errPoolTimeout return nil, errPoolTimeout
} }
func (p *connPool) Put(cn *conn) error { func (p *connPool) Put(cn *conn) error {
@ -300,24 +336,24 @@ func (p *singleConnPool) First() *conn {
return p.cn return p.cn
} }
func (p *singleConnPool) Get() (*conn, bool, error) { func (p *singleConnPool) Get() (*conn, error) {
defer p.cnMtx.Unlock() defer p.cnMtx.Unlock()
p.cnMtx.Lock() p.cnMtx.Lock()
if p.closed { if p.closed {
return nil, false, errClosed return nil, errClosed
} }
if p.cn != nil { if p.cn != nil {
return p.cn, false, nil return p.cn, nil
} }
cn, isNew, err := p.pool.Get() cn, err := p.pool.Get()
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
p.cn = cn p.cn = cn
return p.cn, isNew, nil return p.cn, nil
} }
func (p *singleConnPool) Put(cn *conn) error { func (p *singleConnPool) Put(cn *conn) error {

View File

@ -106,7 +106,7 @@ var _ = Describe("Pool", func() {
}) })
It("should remove broken connections", func() { It("should remove broken connections", func() {
cn, _, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn.Close()).NotTo(HaveOccurred()) Expect(cn.Close()).NotTo(HaveOccurred())
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
@ -140,12 +140,12 @@ var _ = Describe("Pool", func() {
pool := client.Pool() pool := client.Pool()
// Reserve one connection. // Reserve one connection.
cn, _, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reserve the rest of connections. // Reserve the rest of connections.
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
_, _, err := client.Pool().Get() _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
@ -190,7 +190,7 @@ func BenchmarkPool(b *testing.B) {
pool := client.Pool() pool := client.Pool()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
conn, _, err := pool.Get() conn, err := pool.Get()
if err != nil { if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
} }

View File

@ -12,45 +12,7 @@ type baseClient struct {
} }
func (c *baseClient) conn() (*conn, error) { func (c *baseClient) conn() (*conn, error) {
cn, isNew, err := c.connPool.Get() return c.connPool.Get()
if err != nil {
return nil, err
}
if isNew {
if err := c.initConn(cn); err != nil {
c.putConn(cn, err)
return nil, err
}
}
return cn, nil
}
func (c *baseClient) initConn(cn *conn) error {
if c.opt.Password == "" && c.opt.DB == 0 {
return nil
}
pool := newSingleConnPool(c.connPool, false)
pool.SetConn(cn)
// Client is not closed because we want to reuse underlying connection.
client := newClient(c.opt, pool)
if c.opt.Password != "" {
if err := client.Auth(c.opt.Password).Err(); err != nil {
return err
}
}
if c.opt.DB > 0 {
if err := client.Select(c.opt.DB).Err(); err != nil {
return err
}
}
return nil
} }
func (c *baseClient) putConn(cn *conn, ei error) { func (c *baseClient) putConn(cn *conn, ei error) {