Merge pull request from go-redis/fix/optimize-pool-remove

Optimize pool.Remove.
This commit is contained in:
Vladimir Mihailenco 2016-03-12 12:56:13 +02:00
commit 7af992dd83
13 changed files with 126 additions and 128 deletions

View File

@ -278,8 +278,8 @@ func BenchmarkZAdd(b *testing.B) {
}
func benchmarkPoolGetPut(b *testing.B, poolSize int) {
dial := func() (*pool.Conn, error) {
return pool.NewConn(&net.TCPConn{}), nil
dial := func() (net.Conn, error) {
return &net.TCPConn{}, nil
}
pool := pool.NewConnPool(dial, poolSize, time.Second, 0)
@ -311,8 +311,8 @@ func BenchmarkPoolGetPut1000Conns(b *testing.B) {
}
func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
dial := func() (*pool.Conn, error) {
return pool.NewConn(&net.TCPConn{}), nil
dial := func() (net.Conn, error) {
return &net.TCPConn{}, nil
}
pool := pool.NewConnPool(dial, poolSize, time.Second, 0)
removeReason := errors.New("benchmark")
@ -325,7 +325,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
}
if err = pool.Remove(conn, removeReason); err != nil {
if err = pool.Replace(conn, removeReason); err != nil {
b.Fatalf("no error expected on pool.Remove but received: %s", err.Error())
}
}

View File

@ -8,10 +8,12 @@ import (
const defaultBufSize = 4096
var noTimeout = time.Time{}
var noDeadline = time.Time{}
type Conn struct {
NetConn net.Conn
idx int
netConn net.Conn
Rd *bufio.Reader
Buf []byte
@ -22,7 +24,9 @@ type Conn struct {
func NewConn(netConn net.Conn) *Conn {
cn := &Conn{
NetConn: netConn,
idx: -1,
netConn: netConn,
Buf: make([]byte, defaultBufSize),
UsedAt: time.Now(),
@ -31,30 +35,35 @@ func NewConn(netConn net.Conn) *Conn {
return cn
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
cn.UsedAt = time.Now()
}
func (cn *Conn) Read(b []byte) (int, error) {
cn.UsedAt = time.Now()
if cn.ReadTimeout != 0 {
cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout))
cn.netConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout))
} else {
cn.NetConn.SetReadDeadline(noTimeout)
cn.netConn.SetReadDeadline(noDeadline)
}
return cn.NetConn.Read(b)
return cn.netConn.Read(b)
}
func (cn *Conn) Write(b []byte) (int, error) {
cn.UsedAt = time.Now()
if cn.WriteTimeout != 0 {
cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout))
cn.netConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout))
} else {
cn.NetConn.SetWriteDeadline(noTimeout)
cn.netConn.SetWriteDeadline(noDeadline)
}
return cn.NetConn.Write(b)
return cn.netConn.Write(b)
}
func (cn *Conn) RemoteAddr() net.Addr {
return cn.NetConn.RemoteAddr()
return cn.netConn.RemoteAddr()
}
func (cn *Conn) Close() error {
return cn.NetConn.Close()
return cn.netConn.Close()
}

View File

@ -7,14 +7,14 @@ import (
type connList struct {
cns []*Conn
mx sync.Mutex
mu sync.Mutex
len int32 // atomic
size int32
}
func newConnList(size int) *connList {
return &connList{
cns: make([]*Conn, 0, size),
cns: make([]*Conn, size),
size: int32(size),
}
}
@ -23,8 +23,8 @@ func (l *connList) Len() int {
return int(atomic.LoadInt32(&l.len))
}
// Reserve reserves place in the list and returns true on success. The
// caller must add or remove connection if place was reserved.
// Reserve reserves place in the list and returns true on success.
// The caller must add or remove connection if place was reserved.
func (l *connList) Reserve() bool {
len := atomic.AddInt32(&l.len, 1)
reserved := len <= l.size
@ -36,65 +36,49 @@ func (l *connList) Reserve() bool {
// Add adds connection to the list. The caller must reserve place first.
func (l *connList) Add(cn *Conn) {
l.mx.Lock()
l.cns = append(l.cns, cn)
l.mx.Unlock()
l.mu.Lock()
for i, c := range l.cns {
if c == nil {
cn.idx = i
l.cns[i] = cn
l.mu.Unlock()
return
}
}
panic("not reached")
}
// Remove closes connection and removes it from the list.
func (l *connList) Remove(cn *Conn) error {
defer l.mx.Unlock()
l.mx.Lock()
if cn == nil {
atomic.AddInt32(&l.len, -1)
if cn == nil { // free reserved place
return nil
}
for i, c := range l.cns {
if c == cn {
l.cns = append(l.cns[:i], l.cns[i+1:]...)
atomic.AddInt32(&l.len, -1)
return cn.Close()
}
l.mu.Lock()
if l.cns != nil {
l.cns[cn.idx] = nil
cn.idx = -1
}
l.mu.Unlock()
if l.closed() {
return nil
}
panic("conn not found in the list")
}
func (l *connList) Replace(cn, newcn *Conn) error {
defer l.mx.Unlock()
l.mx.Lock()
for i, c := range l.cns {
if c == cn {
l.cns[i] = newcn
return cn.Close()
}
}
if l.closed() {
return newcn.Close()
}
panic("conn not found in the list")
}
func (l *connList) Close() (retErr error) {
l.mx.Lock()
func (l *connList) Close() error {
var retErr error
l.mu.Lock()
for _, c := range l.cns {
if err := c.Close(); err != nil {
if c == nil {
continue
}
if err := c.Close(); err != nil && retErr == nil {
retErr = err
}
}
l.cns = nil
atomic.StoreInt32(&l.len, 0)
l.mx.Unlock()
l.mu.Unlock()
return retErr
}
func (l *connList) closed() bool {
return l.cns == nil
}

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"log"
"net"
"sync/atomic"
"time"
@ -32,17 +33,17 @@ type Pooler interface {
First() *Conn
Get() (*Conn, bool, error)
Put(*Conn) error
Remove(*Conn, error) error
Replace(*Conn, error) error
Len() int
FreeLen() int
Close() error
Stats() *PoolStats
}
type dialer func() (*Conn, error)
type dialer func() (net.Conn, error)
type ConnPool struct {
dial dialer
_dial dialer
poolTimeout time.Duration
idleTimeout time.Duration
@ -59,7 +60,7 @@ type ConnPool struct {
func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool {
p := &ConnPool{
dial: dial,
_dial: dial,
poolTimeout: poolTimeout,
idleTimeout: idleTimeout,
@ -126,8 +127,7 @@ func (p *ConnPool) wait() *Conn {
panic("not reached")
}
// Establish a new connection
func (p *ConnPool) new() (*Conn, error) {
func (p *ConnPool) dial() (net.Conn, error) {
if p.rl.Limit() {
err := fmt.Errorf(
"redis: you open connections too fast (last_error=%q)",
@ -136,15 +136,22 @@ func (p *ConnPool) new() (*Conn, error) {
return nil, err
}
cn, err := p.dial()
cn, err := p._dial()
if err != nil {
p.storeLastErr(err.Error())
return nil, err
}
return cn, nil
}
func (p *ConnPool) newConn() (*Conn, error) {
netConn, err := p.dial()
if err != nil {
return nil, err
}
return NewConn(netConn), nil
}
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) {
if p.closed() {
@ -164,7 +171,7 @@ func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) {
if p.conns.Reserve() {
isNew = true
cn, err = p.new()
cn, err = p.newConn()
if err != nil {
p.conns.Remove(nil)
return
@ -189,23 +196,26 @@ func (p *ConnPool) Put(cn *Conn) error {
b, _ := cn.Rd.Peek(cn.Rd.Buffered())
err := fmt.Errorf("connection has unread data: %q", b)
Logger.Print(err)
return p.Remove(cn, err)
return p.Replace(cn, err)
}
p.freeConns <- cn
return nil
}
func (p *ConnPool) replace(cn *Conn) (*Conn, error) {
newcn, err := p.new()
_ = cn.Close()
netConn, err := p.dial()
if err != nil {
_ = p.conns.Remove(cn)
return nil, err
}
_ = p.conns.Replace(cn, newcn)
return newcn, nil
cn.SetNetConn(netConn)
return cn, nil
}
func (p *ConnPool) Remove(cn *Conn, reason error) error {
func (p *ConnPool) Replace(cn *Conn, reason error) error {
p.storeLastErr(reason.Error())
// Replace existing connection with new one and unblock waiter.

View File

@ -25,7 +25,7 @@ func (p *SingleConnPool) Put(cn *Conn) error {
return nil
}
func (p *SingleConnPool) Remove(cn *Conn, _ error) error {
func (p *SingleConnPool) Replace(cn *Conn, _ error) error {
if p.cn != cn {
panic("p.cn != cn")
}

View File

@ -67,13 +67,13 @@ func (p *StickyConnPool) Put(cn *Conn) error {
return nil
}
func (p *StickyConnPool) remove(reason error) error {
err := p.pool.Remove(p.cn, reason)
func (p *StickyConnPool) replace(reason error) error {
err := p.pool.Replace(p.cn, reason)
p.cn = nil
return err
}
func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
func (p *StickyConnPool) Replace(cn *Conn, reason error) error {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
@ -85,7 +85,7 @@ func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
if cn != nil && p.cn != cn {
panic("p.cn != cn")
}
return p.remove(reason)
return p.replace(reason)
}
func (p *StickyConnPool) Len() int {
@ -121,7 +121,7 @@ func (p *StickyConnPool) Close() error {
err = p.put()
} else {
reason := errors.New("redis: sticky not reusable connection")
err = p.remove(reason)
err = p.replace(reason)
}
}
return err

View File

@ -145,7 +145,7 @@ var _ = Describe("Multi", func() {
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())
@ -172,7 +172,7 @@ var _ = Describe("Multi", func() {
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())

View File

@ -67,18 +67,6 @@ func (opt *Options) getDialer() func() (net.Conn, error) {
return opt.Dialer
}
func (opt *Options) getPoolDialer() func() (*pool.Conn, error) {
dial := opt.getDialer()
return func() (*pool.Conn, error) {
netcn, err := dial()
if err != nil {
return nil, err
}
cn := pool.NewConn(netcn)
return cn, opt.initConn(cn)
}
}
func (opt *Options) getPoolSize() int {
if opt.PoolSize == 0 {
return 10
@ -104,32 +92,9 @@ func (opt *Options) getIdleTimeout() time.Duration {
return opt.IdleTimeout
}
func (opt *Options) initConn(cn *pool.Conn) error {
if opt.Password == "" && opt.DB == 0 {
return nil
}
// Temp client for Auth and Select.
client := newClient(opt, pool.NewSingleConnPool(cn))
if opt.Password != "" {
if err := client.Auth(opt.Password).Err(); err != nil {
return err
}
}
if opt.DB > 0 {
if err := client.Select(opt.DB).Err(); err != nil {
return err
}
}
return nil
}
func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool(
opt.getPoolDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout())
opt.getDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout())
}
// PoolStats contains pool state information and accumulated stats.

View File

@ -179,7 +179,7 @@ var _ = Describe("pool", func() {
// ok
}
err = pool.Remove(cn, errors.New("test"))
err = pool.Replace(cn, errors.New("test"))
Expect(err).NotTo(HaveOccurred())
// Check that Ping is unblocked.
@ -203,7 +203,7 @@ var _ = Describe("pool", func() {
break
}
_ = pool.Remove(cn, errors.New("test"))
_ = pool.Replace(cn, errors.New("test"))
}
Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`))

View File

@ -291,10 +291,10 @@ var _ = Describe("PubSub", func() {
expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn1.NetConn = &badConn{
cn1.SetNetConn(&badConn{
readErr: io.EOF,
writeErr: io.EOF,
}
})
done := make(chan bool, 1)
go func() {

View File

@ -33,12 +33,19 @@ func (c *baseClient) String() string {
}
func (c *baseClient) conn() (*pool.Conn, bool, error) {
return c.connPool.Get()
cn, isNew, err := c.connPool.Get()
if err == nil && isNew {
err = c.initConn(cn)
if err != nil {
c.putConn(cn, err, false)
}
}
return cn, isNew, err
}
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
if isBadConn(err, allowTimeout) {
err = c.connPool.Remove(cn, err)
err = c.connPool.Replace(cn, err)
if err != nil {
Logger.Printf("pool.Remove failed: %s", err)
}
@ -52,6 +59,29 @@ func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
return true
}
func (c *baseClient) initConn(cn *pool.Conn) error {
if c.opt.Password == "" && c.opt.DB == 0 {
return nil
}
// Temp client for Auth and Select.
client := newClient(c.opt, pool.NewSingleConnPool(cn))
if c.opt.Password != "" {
if err := client.Auth(c.opt.Password).Err(); err != nil {
return err
}
}
if c.opt.DB > 0 {
if err := client.Select(c.opt.DB).Err(); err != nil {
return err
}
}
return nil
}
func (c *baseClient) process(cmd Cmder) {
for i := 0; i <= c.opt.MaxRetries; i++ {
if i > 0 {

View File

@ -160,7 +160,7 @@ var _ = Describe("Client", func() {
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())

View File

@ -267,7 +267,7 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
cn.RemoteAddr(),
)
Logger.Print(err)
d.pool.Remove(cn, err)
d.pool.Replace(cn, err)
} else {
cnsToPut = append(cnsToPut, cn)
}