Merge pull request #276 from go-redis/fix/optimize-pool-remove

Optimize pool.Remove.
This commit is contained in:
Vladimir Mihailenco 2016-03-12 12:56:13 +02:00
commit 7af992dd83
13 changed files with 126 additions and 128 deletions

View File

@ -278,8 +278,8 @@ func BenchmarkZAdd(b *testing.B) {
} }
func benchmarkPoolGetPut(b *testing.B, poolSize int) { func benchmarkPoolGetPut(b *testing.B, poolSize int) {
dial := func() (*pool.Conn, error) { dial := func() (net.Conn, error) {
return pool.NewConn(&net.TCPConn{}), nil return &net.TCPConn{}, nil
} }
pool := pool.NewConnPool(dial, poolSize, time.Second, 0) pool := pool.NewConnPool(dial, poolSize, time.Second, 0)
@ -311,8 +311,8 @@ func BenchmarkPoolGetPut1000Conns(b *testing.B) {
} }
func benchmarkPoolGetRemove(b *testing.B, poolSize int) { func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
dial := func() (*pool.Conn, error) { dial := func() (net.Conn, error) {
return pool.NewConn(&net.TCPConn{}), nil return &net.TCPConn{}, nil
} }
pool := pool.NewConnPool(dial, poolSize, time.Second, 0) pool := pool.NewConnPool(dial, poolSize, time.Second, 0)
removeReason := errors.New("benchmark") removeReason := errors.New("benchmark")
@ -325,7 +325,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
if err != nil { if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
} }
if err = pool.Remove(conn, removeReason); err != nil { if err = pool.Replace(conn, removeReason); err != nil {
b.Fatalf("no error expected on pool.Remove but received: %s", err.Error()) b.Fatalf("no error expected on pool.Remove but received: %s", err.Error())
} }
} }

View File

@ -8,10 +8,12 @@ import (
const defaultBufSize = 4096 const defaultBufSize = 4096
var noTimeout = time.Time{} var noDeadline = time.Time{}
type Conn struct { type Conn struct {
NetConn net.Conn idx int
netConn net.Conn
Rd *bufio.Reader Rd *bufio.Reader
Buf []byte Buf []byte
@ -22,7 +24,9 @@ type Conn struct {
func NewConn(netConn net.Conn) *Conn { func NewConn(netConn net.Conn) *Conn {
cn := &Conn{ cn := &Conn{
NetConn: netConn, idx: -1,
netConn: netConn,
Buf: make([]byte, defaultBufSize), Buf: make([]byte, defaultBufSize),
UsedAt: time.Now(), UsedAt: time.Now(),
@ -31,30 +35,35 @@ func NewConn(netConn net.Conn) *Conn {
return cn return cn
} }
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
cn.UsedAt = time.Now()
}
func (cn *Conn) Read(b []byte) (int, error) { func (cn *Conn) Read(b []byte) (int, error) {
cn.UsedAt = time.Now() cn.UsedAt = time.Now()
if cn.ReadTimeout != 0 { if cn.ReadTimeout != 0 {
cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) cn.netConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout))
} else { } else {
cn.NetConn.SetReadDeadline(noTimeout) cn.netConn.SetReadDeadline(noDeadline)
} }
return cn.NetConn.Read(b) return cn.netConn.Read(b)
} }
func (cn *Conn) Write(b []byte) (int, error) { func (cn *Conn) Write(b []byte) (int, error) {
cn.UsedAt = time.Now() cn.UsedAt = time.Now()
if cn.WriteTimeout != 0 { if cn.WriteTimeout != 0 {
cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) cn.netConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout))
} else { } else {
cn.NetConn.SetWriteDeadline(noTimeout) cn.netConn.SetWriteDeadline(noDeadline)
} }
return cn.NetConn.Write(b) return cn.netConn.Write(b)
} }
func (cn *Conn) RemoteAddr() net.Addr { func (cn *Conn) RemoteAddr() net.Addr {
return cn.NetConn.RemoteAddr() return cn.netConn.RemoteAddr()
} }
func (cn *Conn) Close() error { func (cn *Conn) Close() error {
return cn.NetConn.Close() return cn.netConn.Close()
} }

View File

@ -7,14 +7,14 @@ import (
type connList struct { type connList struct {
cns []*Conn cns []*Conn
mx sync.Mutex mu sync.Mutex
len int32 // atomic len int32 // atomic
size int32 size int32
} }
func newConnList(size int) *connList { func newConnList(size int) *connList {
return &connList{ return &connList{
cns: make([]*Conn, 0, size), cns: make([]*Conn, size),
size: int32(size), size: int32(size),
} }
} }
@ -23,8 +23,8 @@ func (l *connList) Len() int {
return int(atomic.LoadInt32(&l.len)) return int(atomic.LoadInt32(&l.len))
} }
// Reserve reserves place in the list and returns true on success. The // Reserve reserves place in the list and returns true on success.
// caller must add or remove connection if place was reserved. // The caller must add or remove connection if place was reserved.
func (l *connList) Reserve() bool { func (l *connList) Reserve() bool {
len := atomic.AddInt32(&l.len, 1) len := atomic.AddInt32(&l.len, 1)
reserved := len <= l.size reserved := len <= l.size
@ -36,65 +36,49 @@ func (l *connList) Reserve() bool {
// Add adds connection to the list. The caller must reserve place first. // Add adds connection to the list. The caller must reserve place first.
func (l *connList) Add(cn *Conn) { func (l *connList) Add(cn *Conn) {
l.mx.Lock() l.mu.Lock()
l.cns = append(l.cns, cn) for i, c := range l.cns {
l.mx.Unlock() if c == nil {
cn.idx = i
l.cns[i] = cn
l.mu.Unlock()
return
}
}
panic("not reached")
} }
// Remove closes connection and removes it from the list. // Remove closes connection and removes it from the list.
func (l *connList) Remove(cn *Conn) error { func (l *connList) Remove(cn *Conn) error {
defer l.mx.Unlock()
l.mx.Lock()
if cn == nil {
atomic.AddInt32(&l.len, -1) atomic.AddInt32(&l.len, -1)
if cn == nil { // free reserved place
return nil return nil
} }
for i, c := range l.cns { l.mu.Lock()
if c == cn { if l.cns != nil {
l.cns = append(l.cns[:i], l.cns[i+1:]...) l.cns[cn.idx] = nil
atomic.AddInt32(&l.len, -1) cn.idx = -1
return cn.Close()
}
} }
l.mu.Unlock()
if l.closed() {
return nil return nil
} }
panic("conn not found in the list")
}
func (l *connList) Replace(cn, newcn *Conn) error { func (l *connList) Close() error {
defer l.mx.Unlock() var retErr error
l.mx.Lock() l.mu.Lock()
for i, c := range l.cns {
if c == cn {
l.cns[i] = newcn
return cn.Close()
}
}
if l.closed() {
return newcn.Close()
}
panic("conn not found in the list")
}
func (l *connList) Close() (retErr error) {
l.mx.Lock()
for _, c := range l.cns { for _, c := range l.cns {
if err := c.Close(); err != nil { if c == nil {
continue
}
if err := c.Close(); err != nil && retErr == nil {
retErr = err retErr = err
} }
} }
l.cns = nil l.cns = nil
atomic.StoreInt32(&l.len, 0) atomic.StoreInt32(&l.len, 0)
l.mx.Unlock() l.mu.Unlock()
return retErr return retErr
} }
func (l *connList) closed() bool {
return l.cns == nil
}

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"net"
"sync/atomic" "sync/atomic"
"time" "time"
@ -32,17 +33,17 @@ type Pooler interface {
First() *Conn First() *Conn
Get() (*Conn, bool, error) Get() (*Conn, bool, error)
Put(*Conn) error Put(*Conn) error
Remove(*Conn, error) error Replace(*Conn, error) error
Len() int Len() int
FreeLen() int FreeLen() int
Close() error Close() error
Stats() *PoolStats Stats() *PoolStats
} }
type dialer func() (*Conn, error) type dialer func() (net.Conn, error)
type ConnPool struct { type ConnPool struct {
dial dialer _dial dialer
poolTimeout time.Duration poolTimeout time.Duration
idleTimeout time.Duration idleTimeout time.Duration
@ -59,7 +60,7 @@ type ConnPool struct {
func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool { func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool {
p := &ConnPool{ p := &ConnPool{
dial: dial, _dial: dial,
poolTimeout: poolTimeout, poolTimeout: poolTimeout,
idleTimeout: idleTimeout, idleTimeout: idleTimeout,
@ -126,8 +127,7 @@ func (p *ConnPool) wait() *Conn {
panic("not reached") panic("not reached")
} }
// Establish a new connection func (p *ConnPool) dial() (net.Conn, error) {
func (p *ConnPool) new() (*Conn, error) {
if p.rl.Limit() { if p.rl.Limit() {
err := fmt.Errorf( err := fmt.Errorf(
"redis: you open connections too fast (last_error=%q)", "redis: you open connections too fast (last_error=%q)",
@ -136,15 +136,22 @@ func (p *ConnPool) new() (*Conn, error) {
return nil, err return nil, err
} }
cn, err := p.dial() cn, err := p._dial()
if err != nil { if err != nil {
p.storeLastErr(err.Error()) p.storeLastErr(err.Error())
return nil, err return nil, err
} }
return cn, nil return cn, nil
} }
func (p *ConnPool) newConn() (*Conn, error) {
netConn, err := p.dial()
if err != nil {
return nil, err
}
return NewConn(netConn), nil
}
// Get returns existed connection from the pool or creates a new one. // Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) {
if p.closed() { if p.closed() {
@ -164,7 +171,7 @@ func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) {
if p.conns.Reserve() { if p.conns.Reserve() {
isNew = true isNew = true
cn, err = p.new() cn, err = p.newConn()
if err != nil { if err != nil {
p.conns.Remove(nil) p.conns.Remove(nil)
return return
@ -189,23 +196,26 @@ func (p *ConnPool) Put(cn *Conn) error {
b, _ := cn.Rd.Peek(cn.Rd.Buffered()) b, _ := cn.Rd.Peek(cn.Rd.Buffered())
err := fmt.Errorf("connection has unread data: %q", b) err := fmt.Errorf("connection has unread data: %q", b)
Logger.Print(err) Logger.Print(err)
return p.Remove(cn, err) return p.Replace(cn, err)
} }
p.freeConns <- cn p.freeConns <- cn
return nil return nil
} }
func (p *ConnPool) replace(cn *Conn) (*Conn, error) { func (p *ConnPool) replace(cn *Conn) (*Conn, error) {
newcn, err := p.new() _ = cn.Close()
netConn, err := p.dial()
if err != nil { if err != nil {
_ = p.conns.Remove(cn) _ = p.conns.Remove(cn)
return nil, err return nil, err
} }
_ = p.conns.Replace(cn, newcn) cn.SetNetConn(netConn)
return newcn, nil
return cn, nil
} }
func (p *ConnPool) Remove(cn *Conn, reason error) error { func (p *ConnPool) Replace(cn *Conn, reason error) error {
p.storeLastErr(reason.Error()) p.storeLastErr(reason.Error())
// Replace existing connection with new one and unblock waiter. // Replace existing connection with new one and unblock waiter.

View File

@ -25,7 +25,7 @@ func (p *SingleConnPool) Put(cn *Conn) error {
return nil return nil
} }
func (p *SingleConnPool) Remove(cn *Conn, _ error) error { func (p *SingleConnPool) Replace(cn *Conn, _ error) error {
if p.cn != cn { if p.cn != cn {
panic("p.cn != cn") panic("p.cn != cn")
} }

View File

@ -67,13 +67,13 @@ func (p *StickyConnPool) Put(cn *Conn) error {
return nil return nil
} }
func (p *StickyConnPool) remove(reason error) error { func (p *StickyConnPool) replace(reason error) error {
err := p.pool.Remove(p.cn, reason) err := p.pool.Replace(p.cn, reason)
p.cn = nil p.cn = nil
return err return err
} }
func (p *StickyConnPool) Remove(cn *Conn, reason error) error { func (p *StickyConnPool) Replace(cn *Conn, reason error) error {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
@ -85,7 +85,7 @@ func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
if cn != nil && p.cn != cn { if cn != nil && p.cn != cn {
panic("p.cn != cn") panic("p.cn != cn")
} }
return p.remove(reason) return p.replace(reason)
} }
func (p *StickyConnPool) Len() int { func (p *StickyConnPool) Len() int {
@ -121,7 +121,7 @@ func (p *StickyConnPool) Close() error {
err = p.put() err = p.put()
} else { } else {
reason := errors.New("redis: sticky not reusable connection") reason := errors.New("redis: sticky not reusable connection")
err = p.remove(reason) err = p.replace(reason)
} }
} }
return err return err

View File

@ -145,7 +145,7 @@ var _ = Describe("Multi", func() {
cn, _, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn) err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -172,7 +172,7 @@ var _ = Describe("Multi", func() {
cn, _, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn) err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())

View File

@ -67,18 +67,6 @@ func (opt *Options) getDialer() func() (net.Conn, error) {
return opt.Dialer return opt.Dialer
} }
func (opt *Options) getPoolDialer() func() (*pool.Conn, error) {
dial := opt.getDialer()
return func() (*pool.Conn, error) {
netcn, err := dial()
if err != nil {
return nil, err
}
cn := pool.NewConn(netcn)
return cn, opt.initConn(cn)
}
}
func (opt *Options) getPoolSize() int { func (opt *Options) getPoolSize() int {
if opt.PoolSize == 0 { if opt.PoolSize == 0 {
return 10 return 10
@ -104,32 +92,9 @@ func (opt *Options) getIdleTimeout() time.Duration {
return opt.IdleTimeout return opt.IdleTimeout
} }
func (opt *Options) initConn(cn *pool.Conn) error {
if opt.Password == "" && opt.DB == 0 {
return nil
}
// Temp client for Auth and Select.
client := newClient(opt, pool.NewSingleConnPool(cn))
if opt.Password != "" {
if err := client.Auth(opt.Password).Err(); err != nil {
return err
}
}
if opt.DB > 0 {
if err := client.Select(opt.DB).Err(); err != nil {
return err
}
}
return nil
}
func newConnPool(opt *Options) *pool.ConnPool { func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool( return pool.NewConnPool(
opt.getPoolDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout()) opt.getDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout())
} }
// PoolStats contains pool state information and accumulated stats. // PoolStats contains pool state information and accumulated stats.

View File

@ -179,7 +179,7 @@ var _ = Describe("pool", func() {
// ok // ok
} }
err = pool.Remove(cn, errors.New("test")) err = pool.Replace(cn, errors.New("test"))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Check that Ping is unblocked. // Check that Ping is unblocked.
@ -203,7 +203,7 @@ var _ = Describe("pool", func() {
break break
} }
_ = pool.Remove(cn, errors.New("test")) _ = pool.Replace(cn, errors.New("test"))
} }
Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`)) Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`))

View File

@ -291,10 +291,10 @@ var _ = Describe("PubSub", func() {
expectReceiveMessageOnError := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, _, err := pubsub.Pool().Get() cn1, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn1.NetConn = &badConn{ cn1.SetNetConn(&badConn{
readErr: io.EOF, readErr: io.EOF,
writeErr: io.EOF, writeErr: io.EOF,
} })
done := make(chan bool, 1) done := make(chan bool, 1)
go func() { go func() {

View File

@ -33,12 +33,19 @@ func (c *baseClient) String() string {
} }
func (c *baseClient) conn() (*pool.Conn, bool, error) { func (c *baseClient) conn() (*pool.Conn, bool, error) {
return c.connPool.Get() cn, isNew, err := c.connPool.Get()
if err == nil && isNew {
err = c.initConn(cn)
if err != nil {
c.putConn(cn, err, false)
}
}
return cn, isNew, err
} }
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
if isBadConn(err, allowTimeout) { if isBadConn(err, allowTimeout) {
err = c.connPool.Remove(cn, err) err = c.connPool.Replace(cn, err)
if err != nil { if err != nil {
Logger.Printf("pool.Remove failed: %s", err) Logger.Printf("pool.Remove failed: %s", err)
} }
@ -52,6 +59,29 @@ func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
return true return true
} }
func (c *baseClient) initConn(cn *pool.Conn) error {
if c.opt.Password == "" && c.opt.DB == 0 {
return nil
}
// Temp client for Auth and Select.
client := newClient(c.opt, pool.NewSingleConnPool(cn))
if c.opt.Password != "" {
if err := client.Auth(c.opt.Password).Err(); err != nil {
return err
}
}
if c.opt.DB > 0 {
if err := client.Select(c.opt.DB).Err(); err != nil {
return err
}
}
return nil
}
func (c *baseClient) process(cmd Cmder) { func (c *baseClient) process(cmd Cmder) {
for i := 0; i <= c.opt.MaxRetries; i++ { for i := 0; i <= c.opt.MaxRetries; i++ {
if i > 0 { if i > 0 {

View File

@ -160,7 +160,7 @@ var _ = Describe("Client", func() {
cn, _, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn) err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())

View File

@ -267,7 +267,7 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
cn.RemoteAddr(), cn.RemoteAddr(),
) )
Logger.Print(err) Logger.Print(err)
d.pool.Remove(cn, err) d.pool.Replace(cn, err)
} else { } else {
cnsToPut = append(cnsToPut, cn) cnsToPut = append(cnsToPut, cn)
} }