Ensure all new connections are initialised

This commit is contained in:
Dimitrij Denissenko 2015-04-22 08:32:54 +01:00
parent a5891da2f6
commit 750d14fe21
3 changed files with 56 additions and 58 deletions

66
pool.go
View File

@ -24,7 +24,7 @@ var (
type pool interface {
First() *conn
Get() (*conn, bool, error)
Get() (*conn, error)
Put(*conn) error
Remove(*conn) error
Len() int
@ -171,23 +171,59 @@ func (p *connPool) new() (*conn, error) {
)
return nil, err
}
cn, err := p.dial()
if err != nil {
p.lastDialErr = err
return nil, err
}
return cn, err
if err := p.init(cn); err != nil {
p.Remove(cn)
return nil, err
}
return cn, nil
}
// Initialize connection
func (p *connPool) init(cn *conn) error {
if p.opt.Password == "" && p.opt.DB == 0 {
return nil
}
// Use connection to connect to redis
pool := newSingleConnPool(p, false)
pool.SetConn(cn)
// Client is not closed because we want to reuse underlying connection.
client := newClient(p.opt, pool)
if p.opt.Password != "" {
if err := client.Auth(p.opt.Password).Err(); err != nil {
return err
}
}
if p.opt.DB > 0 {
if err := client.Select(p.opt.DB).Err(); err != nil {
return err
}
}
return nil
}
// Get returns existed connection from the pool or creates a new one
// if needed.
func (p *connPool) Get() (*conn, bool, error) {
func (p *connPool) Get() (*conn, error) {
if p.isClosed() {
return nil, false, errClosed
return nil, errClosed
}
// Fetch first non-idle connection, if available
if cn := p.First(); cn != nil {
return cn, false, nil
return cn, nil
}
// Try to create a new one
@ -195,18 +231,18 @@ func (p *connPool) Get() (*conn, bool, error) {
cn, err := p.new()
if err != nil {
atomic.AddInt32(&p.size, -1) // Undo ref increment
return nil, false, err
return nil, err
}
return cn, true, nil
return cn, nil
}
atomic.AddInt32(&p.size, -1)
// Otherwise, wait for the available connection
if cn := p.wait(p.opt.PoolTimeout); cn != nil {
return cn, false, nil
return cn, nil
}
return nil, false, errPoolTimeout
return nil, errPoolTimeout
}
func (p *connPool) Put(cn *conn) error {
@ -300,24 +336,24 @@ func (p *singleConnPool) First() *conn {
return p.cn
}
func (p *singleConnPool) Get() (*conn, bool, error) {
func (p *singleConnPool) Get() (*conn, error) {
defer p.cnMtx.Unlock()
p.cnMtx.Lock()
if p.closed {
return nil, false, errClosed
return nil, errClosed
}
if p.cn != nil {
return p.cn, false, nil
return p.cn, nil
}
cn, isNew, err := p.pool.Get()
cn, err := p.pool.Get()
if err != nil {
return nil, false, err
return nil, err
}
p.cn = cn
return p.cn, isNew, nil
return p.cn, nil
}
func (p *singleConnPool) Put(cn *conn) error {

View File

@ -106,7 +106,7 @@ var _ = Describe("Pool", func() {
})
It("should remove broken connections", func() {
cn, _, err := client.Pool().Get()
cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn.Close()).NotTo(HaveOccurred())
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
@ -140,12 +140,12 @@ var _ = Describe("Pool", func() {
pool := client.Pool()
// Reserve one connection.
cn, _, err := client.Pool().Get()
cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
// Reserve the rest of connections.
for i := 0; i < 9; i++ {
_, _, err := client.Pool().Get()
_, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
}
@ -190,7 +190,7 @@ func BenchmarkPool(b *testing.B) {
pool := client.Pool()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn, _, err := pool.Get()
conn, err := pool.Get()
if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
}

View File

@ -12,45 +12,7 @@ type baseClient struct {
}
func (c *baseClient) conn() (*conn, error) {
cn, isNew, err := c.connPool.Get()
if err != nil {
return nil, err
}
if isNew {
if err := c.initConn(cn); err != nil {
c.putConn(cn, err)
return nil, err
}
}
return cn, nil
}
func (c *baseClient) initConn(cn *conn) error {
if c.opt.Password == "" && c.opt.DB == 0 {
return nil
}
pool := newSingleConnPool(c.connPool, false)
pool.SetConn(cn)
// Client is not closed because we want to reuse underlying connection.
client := newClient(c.opt, pool)
if c.opt.Password != "" {
if err := client.Auth(c.opt.Password).Err(); err != nil {
return err
}
}
if c.opt.DB > 0 {
if err := client.Select(c.opt.DB).Err(); err != nil {
return err
}
}
return nil
return c.connPool.Get()
}
func (c *baseClient) putConn(cn *conn, ei error) {