Merge pull request #281 from go-redis/fix/client-init

Fix connection initialization.
This commit is contained in:
Vladimir Mihailenco 2016-03-15 15:16:38 +02:00
commit 9d394cc7fb
18 changed files with 73 additions and 69 deletions

View File

@ -69,7 +69,7 @@ func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) {
continue continue
} }
cn, _, err := client.conn() cn, err := client.conn()
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
retErr = err retErr = err

View File

@ -254,9 +254,9 @@ func ExamplePubSub_Receive() {
} }
fmt.Println(n, "clients received message") fmt.Println(n, "clients received message")
for { for i := 0; i < 2; i++ {
// ReceiveTimeout is a low level API. Use ReceiveMessage instead. // ReceiveTimeout is a low level API. Use ReceiveMessage instead.
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(5 * time.Second)
if err != nil { if err != nil {
break break
} }

View File

@ -20,7 +20,7 @@ func benchmarkPoolGetPut(b *testing.B, poolSize int) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
conn, _, err := pool.Get() conn, err := pool.Get()
if err != nil { if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
} }
@ -56,7 +56,7 @@ func benchmarkPoolGetReplace(b *testing.B, poolSize int) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
conn, _, err := pool.Get() conn, err := pool.Get()
if err != nil { if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
} }

View File

@ -18,7 +18,9 @@ type Conn struct {
Rd *bufio.Reader Rd *bufio.Reader
Buf []byte Buf []byte
Inited bool
UsedAt time.Time UsedAt time.Time
ReadTimeout time.Duration ReadTimeout time.Duration
WriteTimeout time.Duration WriteTimeout time.Duration
} }
@ -40,8 +42,12 @@ func (cn *Conn) Index() int {
return int(atomic.LoadInt32(&cn.idx)) return int(atomic.LoadInt32(&cn.idx))
} }
func (cn *Conn) SetIndex(idx int) { func (cn *Conn) SetIndex(newIdx int) int {
atomic.StoreInt32(&cn.idx, int32(idx)) oldIdx := cn.Index()
if !atomic.CompareAndSwapInt32(&cn.idx, int32(oldIdx), int32(newIdx)) {
return -1
}
return oldIdx
} }
func (cn *Conn) IsStale(timeout time.Duration) bool { func (cn *Conn) IsStale(timeout time.Duration) bool {
@ -72,11 +78,6 @@ func (cn *Conn) RemoteAddr() net.Addr {
return cn.NetConn.RemoteAddr() return cn.NetConn.RemoteAddr()
} }
func (cn *Conn) Close() int { func (cn *Conn) Close() error {
idx := cn.Index() return cn.NetConn.Close()
if !atomic.CompareAndSwapInt32(&cn.idx, int32(idx), -1) {
return -1
}
_ = cn.NetConn.Close()
return idx
} }

View File

@ -43,7 +43,7 @@ func (l *connList) Add(cn *Conn) {
l.mu.Lock() l.mu.Lock()
for i, c := range l.cns { for i, c := range l.cns {
if c == nil { if c == nil {
cn.SetIndex(i) cn.idx = int32(i)
l.cns[i] = cn l.cns[i] = cn
l.mu.Unlock() l.mu.Unlock()
return return
@ -76,6 +76,7 @@ func (l *connList) Close() error {
if c == nil { if c == nil {
continue continue
} }
c.idx = -1
c.Close() c.Close()
} }
l.cns = nil l.cns = nil

View File

@ -32,7 +32,7 @@ type PoolStats struct {
type Pooler interface { type Pooler interface {
First() *Conn First() *Conn
Get() (*Conn, bool, error) Get() (*Conn, error)
Put(*Conn) error Put(*Conn) error
Replace(*Conn, error) error Replace(*Conn, error) error
Len() int Len() int
@ -146,7 +146,7 @@ func (p *ConnPool) dial() (net.Conn, error) {
return cn, nil return cn, nil
} }
func (p *ConnPool) newConn() (*Conn, error) { func (p *ConnPool) NewConn() (*Conn, error) {
netConn, err := p.dial() netConn, err := p.dial()
if err != nil { if err != nil {
return nil, err return nil, err
@ -155,42 +155,38 @@ func (p *ConnPool) newConn() (*Conn, error) {
} }
// Get returns existed connection from the pool or creates a new one. // Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { func (p *ConnPool) Get() (*Conn, error) {
if p.Closed() { if p.Closed() {
err = ErrClosed return nil, ErrClosed
return
} }
atomic.AddUint32(&p.stats.Requests, 1) atomic.AddUint32(&p.stats.Requests, 1)
// Fetch first non-idle connection, if available. // Fetch first non-idle connection, if available.
if cn = p.First(); cn != nil { if cn := p.First(); cn != nil {
atomic.AddUint32(&p.stats.Hits, 1) atomic.AddUint32(&p.stats.Hits, 1)
return return cn, nil
} }
// Try to create a new one. // Try to create a new one.
if p.conns.Reserve() { if p.conns.Reserve() {
isNew = true cn, err := p.NewConn()
cn, err = p.newConn()
if err != nil { if err != nil {
p.conns.CancelReservation() p.conns.CancelReservation()
return return nil, err
} }
p.conns.Add(cn) p.conns.Add(cn)
return return cn, nil
} }
// Otherwise, wait for the available connection. // Otherwise, wait for the available connection.
atomic.AddUint32(&p.stats.Waits, 1) atomic.AddUint32(&p.stats.Waits, 1)
if cn = p.wait(); cn != nil { if cn := p.wait(); cn != nil {
return return cn, nil
} }
atomic.AddUint32(&p.stats.Timeouts, 1) atomic.AddUint32(&p.stats.Timeouts, 1)
err = ErrPoolTimeout return nil, ErrPoolTimeout
return
} }
func (p *ConnPool) Put(cn *Conn) error { func (p *ConnPool) Put(cn *Conn) error {
@ -205,7 +201,9 @@ func (p *ConnPool) Put(cn *Conn) error {
} }
func (p *ConnPool) replace(cn *Conn) (*Conn, error) { func (p *ConnPool) replace(cn *Conn) (*Conn, error) {
idx := cn.Close() _ = cn.Close()
idx := cn.SetIndex(-1)
if idx == -1 { if idx == -1 {
return nil, errConnClosed return nil, errConnClosed
} }
@ -236,7 +234,9 @@ func (p *ConnPool) Replace(cn *Conn, reason error) error {
} }
func (p *ConnPool) Remove(cn *Conn, reason error) error { func (p *ConnPool) Remove(cn *Conn, reason error) error {
idx := cn.Close() _ = cn.Close()
idx := cn.SetIndex(-1)
if idx == -1 { if idx == -1 {
return errConnClosed return errConnClosed
} }

View File

@ -16,8 +16,8 @@ func (p *SingleConnPool) First() *Conn {
return p.cn return p.cn
} }
func (p *SingleConnPool) Get() (*Conn, bool, error) { func (p *SingleConnPool) Get() (*Conn, error) {
return p.cn, false, nil return p.cn, nil
} }
func (p *SingleConnPool) Put(cn *Conn) error { func (p *SingleConnPool) Put(cn *Conn) error {

View File

@ -30,25 +30,23 @@ func (p *StickyConnPool) First() *Conn {
return cn return cn
} }
func (p *StickyConnPool) Get() (cn *Conn, isNew bool, err error) { func (p *StickyConnPool) Get() (*Conn, error) {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
err = ErrClosed return nil, ErrClosed
return
} }
if p.cn != nil { if p.cn != nil {
cn = p.cn return p.cn, nil
return
} }
cn, isNew, err = p.pool.Get() cn, err := p.pool.Get()
if err != nil { if err != nil {
return return nil, err
} }
p.cn = cn p.cn = cn
return return cn, nil
} }
func (p *StickyConnPool) put() (err error) { func (p *StickyConnPool) put() (err error) {

View File

@ -69,9 +69,8 @@ var _ = Describe("conns reapser", func() {
cn := connPool.First() cn := connPool.First()
Expect(cn).To(BeNil()) Expect(cn).To(BeNil())
cn, isNew, err := connPool.Get() cn, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(isNew).To(BeTrue())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
Expect(connPool.Len()).To(Equal(4)) Expect(connPool.Len()).To(Equal(4))

View File

@ -128,7 +128,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) {
// Strip MULTI and EXEC commands. // Strip MULTI and EXEC commands.
retCmds := cmds[1 : len(cmds)-1] retCmds := cmds[1 : len(cmds)-1]
cn, _, err := c.base.conn() cn, err := c.base.conn()
if err != nil { if err != nil {
setCmdsErr(retCmds, err) setCmdsErr(retCmds, err)
return retCmds, err return retCmds, err

View File

@ -142,7 +142,7 @@ var _ = Describe("Multi", func() {
It("should recover from bad connection", func() { It("should recover from bad connection", func() {
// Put bad connection in the pool. // Put bad connection in the pool.
cn, _, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.NetConn = &badConn{}
@ -169,7 +169,7 @@ var _ = Describe("Multi", func() {
It("should recover from bad connection when there are no commands", func() { It("should recover from bad connection when there are no commands", func() {
// Put bad connection in the pool. // Put bad connection in the pool.
cn, _, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.NetConn = &badConn{}

View File

@ -90,7 +90,7 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) {
failedCmds := cmds failedCmds := cmds
for i := 0; i <= pipe.client.opt.MaxRetries; i++ { for i := 0; i <= pipe.client.opt.MaxRetries; i++ {
cn, _, err := pipe.client.conn() cn, err := pipe.client.conn()
if err != nil { if err != nil {
setCmdsErr(failedCmds, err) setCmdsErr(failedCmds, err)
return cmds, err return cmds, err

View File

@ -91,7 +91,7 @@ var _ = Describe("pool", func() {
}) })
It("should remove broken connections", func() { It("should remove broken connections", func() {
cn, _, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.NetConn = &badConn{}
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
@ -136,12 +136,12 @@ var _ = Describe("pool", func() {
pool := client.Pool() pool := client.Pool()
// Reserve one connection. // Reserve one connection.
cn, _, err := pool.Get() cn, err := pool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reserve the rest of connections. // Reserve the rest of connections.
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
_, _, err := pool.Get() _, err := pool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
@ -181,7 +181,7 @@ var _ = Describe("pool", func() {
var rateErr error var rateErr error
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
cn, _, err := pool.Get() cn, err := pool.Get()
if err != nil { if err != nil {
rateErr = err rateErr = err
break break

View File

@ -50,7 +50,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
} }
func (c *PubSub) subscribe(redisCmd string, channels ...string) error { func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, _, err := c.base.conn() cn, err := c.base.conn()
if err != nil { if err != nil {
return err return err
} }
@ -126,7 +126,7 @@ func (c *PubSub) Close() error {
} }
func (c *PubSub) Ping(payload string) error { func (c *PubSub) Ping(payload string) error {
cn, _, err := c.base.conn() cn, err := c.base.conn()
if err != nil { if err != nil {
return err return err
} }
@ -226,7 +226,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
c.resubscribe() c.resubscribe()
} }
cn, _, err := c.base.conn() cn, err := c.base.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -289,7 +289,7 @@ var _ = Describe("PubSub", func() {
}) })
expectReceiveMessageOnError := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, _, err := pubsub.Pool().Get() cn1, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn1.NetConn = &badConn{ cn1.NetConn = &badConn{
readErr: io.EOF, readErr: io.EOF,

View File

@ -32,15 +32,18 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB) return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB)
} }
func (c *baseClient) conn() (*pool.Conn, bool, error) { func (c *baseClient) conn() (*pool.Conn, error) {
cn, isNew, err := c.connPool.Get() cn, err := c.connPool.Get()
if err == nil && isNew {
err = c.initConn(cn)
if err != nil { if err != nil {
c.putConn(cn, err, false) return nil, err
}
if !cn.Inited {
if err := c.initConn(cn); err != nil {
_ = c.connPool.Replace(cn, err)
return nil, err
} }
} }
return cn, isNew, err return cn, err
} }
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
@ -54,6 +57,8 @@ func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
} }
func (c *baseClient) initConn(cn *pool.Conn) error { func (c *baseClient) initConn(cn *pool.Conn) error {
cn.Inited = true
if c.opt.Password == "" && c.opt.DB == 0 { if c.opt.Password == "" && c.opt.DB == 0 {
return nil return nil
} }
@ -82,7 +87,7 @@ func (c *baseClient) process(cmd Cmder) {
cmd.reset() cmd.reset()
} }
cn, _, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
cmd.setErr(err) cmd.setErr(err)
return return

View File

@ -157,7 +157,7 @@ var _ = Describe("Client", func() {
}) })
// Put bad connection in the pool. // Put bad connection in the pool.
cn, _, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.NetConn = &badConn{}
@ -169,7 +169,7 @@ var _ = Describe("Client", func() {
}) })
It("should maintain conn.UsedAt", func() { It("should maintain conn.UsedAt", func() {
cn, _, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero()) Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt createdAt := cn.UsedAt

View File

@ -314,7 +314,7 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) {
for name, cmds := range cmdsMap { for name, cmds := range cmdsMap {
client := pipe.ring.shards[name].Client client := pipe.ring.shards[name].Client
cn, _, err := client.conn() cn, err := client.conn()
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
if retErr == nil { if retErr == nil {