Merge pull request #285 from go-redis/feature/new-pool

Faster and simpler pool.
Vladimir Mihailenco 2016-03-19 13:56:26 +02:00
commit fb6ea09ce3
22 changed files with 418 additions and 492 deletions
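
For orientation, here is a minimal, self-contained sketch of the queue-based design this PR moves to: a buffered channel holds one token per allowed connection, so Get blocks when the pool is exhausted and Put/Remove hand the token back. The names are illustrative, and the channel of free connections is a simplification; the real pool keeps free connections in a mutex-guarded slice and adds timeouts, stats, and a reaper.

package pooldemo

import "net"

// SimplePool caps the number of live connections with a token channel
// (queue) and keeps idle ones in a second channel (free).
type SimplePool struct {
	queue chan struct{}
	free  chan net.Conn
	dial  func() (net.Conn, error)
}

func NewSimplePool(dial func() (net.Conn, error), size int) *SimplePool {
	p := &SimplePool{
		queue: make(chan struct{}, size),
		free:  make(chan net.Conn, size),
		dial:  dial,
	}
	for i := 0; i < size; i++ {
		p.queue <- struct{}{} // one token per allowed connection
	}
	return p
}

// Get blocks until a token is available, then reuses an idle connection or
// dials a new one.
func (p *SimplePool) Get() (net.Conn, error) {
	<-p.queue
	select {
	case cn := <-p.free:
		return cn, nil
	default:
	}
	cn, err := p.dial()
	if err != nil {
		p.queue <- struct{}{} // give the slot back on dial failure
		return nil, err
	}
	return cn, nil
}

// Put parks the connection for reuse and releases its token.
func (p *SimplePool) Put(cn net.Conn) {
	p.free <- cn
	p.queue <- struct{}{}
}

// Remove discards a broken connection and releases its token.
func (p *SimplePool) Remove(cn net.Conn) {
	_ = cn.Close()
	p.queue <- struct{}{}
}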


@ -16,15 +16,16 @@ import (
type ClusterClient struct { type ClusterClient struct {
commandable commandable
opt *ClusterOptions
slotsMx sync.RWMutex // protects slots and addrs
addrs []string addrs []string
slots [][]string slots [][]string
slotsMx sync.RWMutex // Protects slots and addrs.
clientsMx sync.RWMutex // protects clients and closed
clients map[string]*Client clients map[string]*Client
closed bool
clientsMx sync.RWMutex // Protects clients and closed.
opt *ClusterOptions _closed int32 // atomic
// Reports whether slots reloading is in progress. // Reports whether slots reloading is in progress.
reloading uint32 reloading uint32
@ -34,17 +35,29 @@ type ClusterClient struct {
// http://redis.io/topics/cluster-spec. // http://redis.io/topics/cluster-spec.
func NewClusterClient(opt *ClusterOptions) *ClusterClient { func NewClusterClient(opt *ClusterOptions) *ClusterClient {
client := &ClusterClient{ client := &ClusterClient{
opt: opt,
addrs: opt.Addrs, addrs: opt.Addrs,
slots: make([][]string, hashtag.SlotNumber), slots: make([][]string, hashtag.SlotNumber),
clients: make(map[string]*Client), clients: make(map[string]*Client),
opt: opt,
} }
client.commandable.process = client.process client.commandable.process = client.process
client.reloadSlots() client.reloadSlots()
go client.reaper()
return client return client
} }
// getClients returns a snapshot of clients for cluster nodes
// this ClusterClient has been working with recently.
// Note that snapshot can contain closed clients.
func (c *ClusterClient) getClients() map[string]*Client {
c.clientsMx.RLock()
clients := make(map[string]*Client, len(c.clients))
for addr, client := range c.clients {
clients[addr] = client
}
c.clientsMx.RUnlock()
return clients
}
// Watch creates new transaction and marks the keys to be watched // Watch creates new transaction and marks the keys to be watched
// for conditional execution of a transaction. // for conditional execution of a transaction.
func (c *ClusterClient) Watch(keys ...string) (*Multi, error) { func (c *ClusterClient) Watch(keys ...string) (*Multi, error) {
@ -59,56 +72,56 @@ func (c *ClusterClient) Watch(keys ...string) (*Multi, error) {
// PoolStats returns accumulated connection pool stats. // PoolStats returns accumulated connection pool stats.
func (c *ClusterClient) PoolStats() *PoolStats { func (c *ClusterClient) PoolStats() *PoolStats {
acc := PoolStats{} acc := PoolStats{}
c.clientsMx.RLock() for _, client := range c.getClients() {
for _, client := range c.clients { s := client.connPool.Stats()
m := client.PoolStats() acc.Requests += s.Requests
acc.Requests += m.Requests acc.Hits += s.Hits
acc.Waits += m.Waits acc.Waits += s.Waits
acc.Timeouts += m.Timeouts acc.Timeouts += s.Timeouts
acc.TotalConns += m.TotalConns acc.TotalConns += s.TotalConns
acc.FreeConns += m.FreeConns acc.FreeConns += s.FreeConns
} }
c.clientsMx.RUnlock()
return &acc return &acc
} }
func (c *ClusterClient) closed() bool {
return atomic.LoadInt32(&c._closed) == 1
}
// Close closes the cluster client, releasing any open resources. // Close closes the cluster client, releasing any open resources.
// //
// It is rare to Close a ClusterClient, as the ClusterClient is meant // It is rare to Close a ClusterClient, as the ClusterClient is meant
// to be long-lived and shared between many goroutines. // to be long-lived and shared between many goroutines.
func (c *ClusterClient) Close() error { func (c *ClusterClient) Close() error {
defer c.clientsMx.Unlock() if !atomic.CompareAndSwapInt32(&c._closed, 0, 1) {
c.clientsMx.Lock()
if c.closed {
return pool.ErrClosed return pool.ErrClosed
} }
c.closed = true
c.clientsMx.Lock()
c.resetClients() c.resetClients()
c.clientsMx.Unlock()
c.setSlots(nil) c.setSlots(nil)
return nil return nil
} }
// getClient returns a Client for a given address. // getClient returns a Client for a given address.
func (c *ClusterClient) getClient(addr string) (*Client, error) { func (c *ClusterClient) getClient(addr string) (*Client, error) {
if c.closed() {
return nil, pool.ErrClosed
}
if addr == "" { if addr == "" {
return c.randomClient() return c.randomClient()
} }
c.clientsMx.RLock() c.clientsMx.RLock()
client, ok := c.clients[addr] client, ok := c.clients[addr]
c.clientsMx.RUnlock()
if ok { if ok {
c.clientsMx.RUnlock()
return client, nil return client, nil
} }
c.clientsMx.RUnlock()
c.clientsMx.Lock() c.clientsMx.Lock()
if c.closed {
c.clientsMx.Unlock()
return nil, pool.ErrClosed
}
client, ok = c.clients[addr] client, ok = c.clients[addr]
if !ok { if !ok {
opt := c.opt.clientOptions() opt := c.opt.clientOptions()
@ -276,28 +289,30 @@ func (c *ClusterClient) lazyReloadSlots() {
} }
// reaper closes idle connections to the cluster. // reaper closes idle connections to the cluster.
func (c *ClusterClient) reaper() { func (c *ClusterClient) reaper(frequency time.Duration) {
ticker := time.NewTicker(time.Minute) ticker := time.NewTicker(frequency)
defer ticker.Stop() defer ticker.Stop()
for _ = range ticker.C {
c.clientsMx.RLock()
if c.closed { for _ = range ticker.C {
c.clientsMx.RUnlock() if c.closed() {
break break
} }
for _, client := range c.clients { var n int
pool := client.connPool for _, client := range c.getClients() {
// pool.First removes idle connections from the pool and nn, err := client.connPool.(*pool.ConnPool).ReapStaleConns()
// returns first non-idle connection. So just put returned if err != nil {
// connection back. Logger.Printf("ReapStaleConns failed: %s", err)
if cn := pool.First(); cn != nil { } else {
pool.Put(cn) n += nn
} }
} }
c.clientsMx.RUnlock() s := c.PoolStats()
Logger.Printf(
"reaper: removed %d stale conns (TotalConns=%d FreeConns=%d Requests=%d Hits=%d Timeouts=%d)",
n, s.TotalConns, s.FreeConns, s.Requests, s.Hits, s.Timeouts,
)
} }
} }
@ -309,8 +324,7 @@ type ClusterOptions struct {
// A seed list of host:port addresses of cluster nodes. // A seed list of host:port addresses of cluster nodes.
Addrs []string Addrs []string
// The maximum number of MOVED/ASK redirects to follow before // The maximum number of MOVED/ASK redirects to follow before giving up.
// giving up.
// Default is 16 // Default is 16
MaxRedirects int MaxRedirects int
@ -323,9 +337,10 @@ type ClusterOptions struct {
WriteTimeout time.Duration WriteTimeout time.Duration
// PoolSize applies per cluster node and not for the whole cluster. // PoolSize applies per cluster node and not for the whole cluster.
PoolSize int PoolSize int
PoolTimeout time.Duration PoolTimeout time.Duration
IdleTimeout time.Duration IdleTimeout time.Duration
IdleCheckFrequency time.Duration
} }
func (opt *ClusterOptions) getMaxRedirects() int { func (opt *ClusterOptions) getMaxRedirects() int {
@ -349,5 +364,6 @@ func (opt *ClusterOptions) clientOptions() *Options {
PoolSize: opt.PoolSize, PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout, PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout, IdleTimeout: opt.IdleTimeout,
// IdleCheckFrequency is not copied to disable reaper
} }
} }
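
As a usage note for the new ClusterOptions fields in this hunk, configuration could look roughly like the sketch below. The addresses are placeholders; per the comments above, PoolSize applies per cluster node, and IdleCheckFrequency is deliberately not copied into each node's Options so that only the cluster-level reaper runs.

package main

import (
	"time"

	"gopkg.in/redis.v3"
)

func main() {
	client := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs:              []string{"127.0.0.1:7000", "127.0.0.1:7001"}, // placeholder seed nodes
		PoolSize:           10, // per cluster node, not per cluster
		PoolTimeout:        30 * time.Second,
		IdleTimeout:        5 * time.Minute,
		IdleCheckFrequency: time.Minute, // how often idle connections are reaped
	})
	defer client.Close()

	if err := client.Ping().Err(); err != nil {
		// handle connection error
	}
}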


@ -52,7 +52,15 @@ func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.Cluste
addrs[i] = net.JoinHostPort("127.0.0.1", port) addrs[i] = net.JoinHostPort("127.0.0.1", port)
} }
if opt == nil { if opt == nil {
opt = &redis.ClusterOptions{} opt = &redis.ClusterOptions{
DialTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
PoolSize: 10,
PoolTimeout: 30 * time.Second,
IdleTimeout: time.Second,
IdleCheckFrequency: time.Second,
}
} }
opt.Addrs = addrs opt.Addrs = addrs
return redis.NewClusterClient(opt) return redis.NewClusterClient(opt)


@ -1301,12 +1301,16 @@ var _ = Describe("Commands", func() {
}) })
It("should BLPop timeout", func() { It("should BLPop timeout", func() {
bLPop := client.BLPop(time.Second, "list1") val, err := client.BLPop(time.Second, "list1").Result()
Expect(bLPop.Val()).To(BeNil()) Expect(err).To(Equal(redis.Nil))
Expect(bLPop.Err()).To(Equal(redis.Nil)) Expect(val).To(BeNil())
stats := client.Pool().Stats() Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(1)))
stats := client.PoolStats()
Expect(stats.Requests).To(Equal(uint32(3)))
Expect(stats.Hits).To(Equal(uint32(2)))
Expect(stats.Timeouts).To(Equal(uint32(0)))
}) })
It("should BRPop", func() { It("should BRPop", func() {


@ -1,6 +1,10 @@
package redis package redis
import "gopkg.in/redis.v3/internal/pool" import (
"time"
"gopkg.in/redis.v3/internal/pool"
)
func (c *baseClient) Pool() pool.Pooler { func (c *baseClient) Pool() pool.Pooler {
return c.connPool return c.connPool
@ -9,3 +13,7 @@ func (c *baseClient) Pool() pool.Pooler {
func (c *PubSub) Pool() pool.Pooler { func (c *PubSub) Pool() pool.Pooler {
return c.base.connPool return c.base.connPool
} }
func SetReceiveMessageTimeout(d time.Duration) {
receiveMessageTimeout = d
}


@ -2,7 +2,6 @@ package pool_test
import ( import (
"errors" "errors"
"net"
"testing" "testing"
"time" "time"
@ -10,22 +9,19 @@ import (
) )
func benchmarkPoolGetPut(b *testing.B, poolSize int) { func benchmarkPoolGetPut(b *testing.B, poolSize int) {
dial := func() (net.Conn, error) { connPool := pool.NewConnPool(dummyDialer, poolSize, time.Second, time.Hour, time.Hour)
return &net.TCPConn{}, nil connPool.DialLimiter = nil
}
pool := pool.NewConnPool(dial, poolSize, time.Second, 0)
pool.DialLimiter = nil
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
conn, err := pool.Get() cn, err := connPool.Get()
if err != nil { if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) b.Fatal(err)
} }
if err = pool.Put(conn); err != nil { if err = connPool.Put(cn); err != nil {
b.Fatalf("no error expected on pool.Put but received: %s", err.Error()) b.Fatal(err)
} }
} }
}) })
@ -43,38 +39,34 @@ func BenchmarkPoolGetPut1000Conns(b *testing.B) {
benchmarkPoolGetPut(b, 1000) benchmarkPoolGetPut(b, 1000)
} }
func benchmarkPoolGetReplace(b *testing.B, poolSize int) { func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
dial := func() (net.Conn, error) { connPool := pool.NewConnPool(dummyDialer, poolSize, time.Second, time.Hour, time.Hour)
return &net.TCPConn{}, nil connPool.DialLimiter = nil
}
pool := pool.NewConnPool(dial, poolSize, time.Second, 0)
pool.DialLimiter = nil
removeReason := errors.New("benchmark") removeReason := errors.New("benchmark")
b.ResetTimer() b.ResetTimer()
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
conn, err := pool.Get() cn, err := connPool.Get()
if err != nil { if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) b.Fatal(err)
} }
if err = pool.Replace(conn, removeReason); err != nil { if err := connPool.Remove(cn, removeReason); err != nil {
b.Fatalf("no error expected on pool.Remove but received: %s", err.Error()) b.Fatal(err)
} }
} }
}) })
} }
func BenchmarkPoolGetReplace10Conns(b *testing.B) { func BenchmarkPoolGetRemove10Conns(b *testing.B) {
benchmarkPoolGetReplace(b, 10) benchmarkPoolGetRemove(b, 10)
} }
func BenchmarkPoolGetReplace100Conns(b *testing.B) { func BenchmarkPoolGetRemove100Conns(b *testing.B) {
benchmarkPoolGetReplace(b, 100) benchmarkPoolGetRemove(b, 100)
} }
func BenchmarkPoolGetReplace1000Conns(b *testing.B) { func BenchmarkPoolGetRemove1000Conns(b *testing.B) {
benchmarkPoolGetReplace(b, 1000) benchmarkPoolGetRemove(b, 1000)
} }


@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"io" "io"
"net" "net"
"sync/atomic"
"time" "time"
) )
@ -13,8 +12,6 @@ const defaultBufSize = 4096
var noDeadline = time.Time{} var noDeadline = time.Time{}
type Conn struct { type Conn struct {
idx int32
NetConn net.Conn NetConn net.Conn
Rd *bufio.Reader Rd *bufio.Reader
Buf []byte Buf []byte
@ -28,8 +25,6 @@ type Conn struct {
func NewConn(netConn net.Conn) *Conn { func NewConn(netConn net.Conn) *Conn {
cn := &Conn{ cn := &Conn{
idx: -1,
NetConn: netConn, NetConn: netConn,
Buf: make([]byte, defaultBufSize), Buf: make([]byte, defaultBufSize),
@ -39,18 +34,6 @@ func NewConn(netConn net.Conn) *Conn {
return cn return cn
} }
func (cn *Conn) Index() int {
return int(atomic.LoadInt32(&cn.idx))
}
func (cn *Conn) SetIndex(newIdx int) int {
oldIdx := cn.Index()
if !atomic.CompareAndSwapInt32(&cn.idx, int32(oldIdx), int32(newIdx)) {
return -1
}
return oldIdx
}
func (cn *Conn) IsStale(timeout time.Duration) bool { func (cn *Conn) IsStale(timeout time.Duration) bool {
return timeout > 0 && time.Since(cn.UsedAt) > timeout return timeout > 0 && time.Since(cn.UsedAt) > timeout
} }


@ -1,89 +0,0 @@
package pool
import (
"sync"
"sync/atomic"
)
type connList struct {
cns []*Conn
mu sync.Mutex
len int32 // atomic
size int32
}
func newConnList(size int) *connList {
return &connList{
cns: make([]*Conn, size),
size: int32(size),
}
}
func (l *connList) Len() int {
return int(atomic.LoadInt32(&l.len))
}
// Reserve reserves place in the list and returns true on success.
// The caller must add connection or cancel reservation if it was reserved.
func (l *connList) Reserve() bool {
len := atomic.AddInt32(&l.len, 1)
reserved := len <= l.size
if !reserved {
atomic.AddInt32(&l.len, -1)
}
return reserved
}
func (l *connList) CancelReservation() {
atomic.AddInt32(&l.len, -1)
}
// Add adds connection to the list. The caller must reserve place first.
func (l *connList) Add(cn *Conn) {
l.mu.Lock()
for i, c := range l.cns {
if c == nil {
cn.SetIndex(i)
l.cns[i] = cn
l.mu.Unlock()
return
}
}
panic("not reached")
}
func (l *connList) Replace(cn *Conn) {
l.mu.Lock()
if l.cns != nil {
l.cns[cn.idx] = cn
}
l.mu.Unlock()
}
// Remove closes connection and removes it from the list.
func (l *connList) Remove(idx int) {
l.mu.Lock()
if l.cns != nil {
l.cns[idx] = nil
atomic.AddInt32(&l.len, -1)
}
l.mu.Unlock()
}
func (l *connList) Reset() []*Conn {
l.mu.Lock()
for _, cn := range l.cns {
if cn == nil {
continue
}
cn.SetIndex(-1)
}
cns := l.cns
l.cns = nil
l.len = 0
l.mu.Unlock()
return cns
}


@ -1,74 +0,0 @@
package pool
import (
"sync"
"time"
)
// connStack is used as a LIFO to maintain free connections
type connStack struct {
cns []*Conn
free chan struct{}
mu sync.Mutex
}
func newConnStack(max int) *connStack {
return &connStack{
cns: make([]*Conn, 0, max),
free: make(chan struct{}, max),
}
}
func (s *connStack) Len() int { return len(s.free) }
func (s *connStack) Push(cn *Conn) {
s.mu.Lock()
s.cns = append(s.cns, cn)
s.mu.Unlock()
s.free <- struct{}{}
}
func (s *connStack) ShiftStale(idleTimeout time.Duration) *Conn {
select {
case <-s.free:
s.mu.Lock()
if cn := s.cns[0]; cn.IsStale(idleTimeout) {
copy(s.cns, s.cns[1:])
s.cns = s.cns[:len(s.cns)-1]
s.mu.Unlock()
return cn
}
s.mu.Unlock()
s.free <- struct{}{}
return nil
default:
return nil
}
}
func (s *connStack) Pop() *Conn {
select {
case <-s.free:
return s.pop()
default:
return nil
}
}
func (s *connStack) PopWithTimeout(d time.Duration) *Conn {
select {
case <-s.free:
return s.pop()
case <-time.After(d):
return nil
}
}
func (s *connStack) pop() (cn *Conn) {
s.mu.Lock()
ci := len(s.cns) - 1
cn, s.cns = s.cns[ci], s.cns[:ci]
s.mu.Unlock()
return
}


@ -3,24 +3,34 @@ package pool
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net" "net"
"os" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"gopkg.in/bsm/ratelimit.v1" "gopkg.in/bsm/ratelimit.v1"
) )
var Logger = log.New(os.Stderr, "redis: ", log.LstdFlags) var Logger = log.New(ioutil.Discard, "redis: ", log.LstdFlags)
var ( var (
ErrClosed = errors.New("redis: client is closed") ErrClosed = errors.New("redis: client is closed")
errConnClosed = errors.New("redis: connection is closed")
ErrPoolTimeout = errors.New("redis: connection pool timeout") ErrPoolTimeout = errors.New("redis: connection pool timeout")
errConnClosed = errors.New("connection is closed")
errConnStale = errors.New("connection is stale")
) )
var timers = sync.Pool{
New: func() interface{} {
return time.NewTimer(0)
},
}
// PoolStats contains pool state information and accumulated stats. // PoolStats contains pool state information and accumulated stats.
// TODO: remove Waits
type PoolStats struct { type PoolStats struct {
Requests uint32 // number of times a connection was requested by the pool Requests uint32 // number of times a connection was requested by the pool
Hits uint32 // number of times free connection was found in the pool Hits uint32 // number of times free connection was found in the pool
@ -32,10 +42,9 @@ type PoolStats struct {
} }
type Pooler interface { type Pooler interface {
First() *Conn
Get() (*Conn, error) Get() (*Conn, error)
Put(*Conn) error Put(*Conn) error
Replace(*Conn, error) error Remove(*Conn, error) error
Len() int Len() int
FreeLen() int FreeLen() int
Stats() *PoolStats Stats() *PoolStats
@ -53,18 +62,23 @@ type ConnPool struct {
poolTimeout time.Duration poolTimeout time.Duration
idleTimeout time.Duration idleTimeout time.Duration
conns *connList queue chan struct{}
freeConns *connStack
stats PoolStats
_closed int32 connsMu sync.Mutex
conns []*Conn
freeConnsMu sync.Mutex
freeConns []*Conn
stats PoolStats
_closed int32 // atomic
lastErr atomic.Value lastErr atomic.Value
} }
var _ Pooler = (*ConnPool)(nil) var _ Pooler = (*ConnPool)(nil)
func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool { func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout, idleCheckFrequency time.Duration) *ConnPool {
p := &ConnPool{ p := &ConnPool{
_dial: dial, _dial: dial,
DialLimiter: ratelimit.New(3*poolSize, time.Second), DialLimiter: ratelimit.New(3*poolSize, time.Second),
@ -72,57 +86,19 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Durati
poolTimeout: poolTimeout, poolTimeout: poolTimeout,
idleTimeout: idleTimeout, idleTimeout: idleTimeout,
conns: newConnList(poolSize), queue: make(chan struct{}, poolSize),
freeConns: newConnStack(poolSize), conns: make([]*Conn, 0, poolSize),
freeConns: make([]*Conn, 0, poolSize),
} }
if idleTimeout > 0 { for i := 0; i < poolSize; i++ {
go p.reaper(getIdleCheckFrequency()) p.queue <- struct{}{}
}
if idleTimeout > 0 && idleCheckFrequency > 0 {
go p.reaper(idleCheckFrequency)
} }
return p return p
} }
func (p *ConnPool) Add(cn *Conn) bool {
if !p.conns.Reserve() {
return false
}
p.conns.Add(cn)
p.Put(cn)
return true
}
// First returns first non-idle connection from the pool or nil if
// there are no connections.
func (p *ConnPool) First() *Conn {
for {
cn := p.freeConns.Pop()
if cn != nil && cn.IsStale(p.idleTimeout) {
var err error
cn, err = p.replace(cn)
if err != nil {
Logger.Printf("pool.replace failed: %s", err)
continue
}
}
return cn
}
}
// wait waits for free non-idle connection. It returns nil on timeout.
func (p *ConnPool) wait(timeout time.Duration) *Conn {
for {
cn := p.freeConns.PopWithTimeout(timeout)
if cn != nil && cn.IsStale(p.idleTimeout) {
var err error
cn, err = p.replace(cn)
if err != nil {
Logger.Printf("pool.replace failed: %s", err)
continue
}
}
return cn
}
}
func (p *ConnPool) dial() (net.Conn, error) { func (p *ConnPool) dial() (net.Conn, error) {
if p.DialLimiter != nil && p.DialLimiter.Limit() { if p.DialLimiter != nil && p.DialLimiter.Limit() {
err := fmt.Errorf( err := fmt.Errorf(
@ -148,6 +124,42 @@ func (p *ConnPool) NewConn() (*Conn, error) {
return NewConn(netConn), nil return NewConn(netConn), nil
} }
func (p *ConnPool) PopFree() *Conn {
timer := timers.Get().(*time.Timer)
if !timer.Reset(p.poolTimeout) {
<-timer.C
}
select {
case <-p.queue:
timers.Put(timer)
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return nil
}
p.freeConnsMu.Lock()
cn := p.popFree()
p.freeConnsMu.Unlock()
if cn == nil {
p.queue <- struct{}{}
}
return cn
}
func (p *ConnPool) popFree() *Conn {
if len(p.freeConns) == 0 {
return nil
}
idx := len(p.freeConns) - 1
cn := p.freeConns[idx]
p.freeConns = p.freeConns[:idx]
return cn
}
// Get returns an existing connection from the pool or creates a new one. // Get returns an existing connection from the pool or creates a new one.
func (p *ConnPool) Get() (*Conn, error) { func (p *ConnPool) Get() (*Conn, error) {
if p.Closed() { if p.Closed() {
@ -156,31 +168,46 @@ func (p *ConnPool) Get() (*Conn, error) {
atomic.AddUint32(&p.stats.Requests, 1) atomic.AddUint32(&p.stats.Requests, 1)
// Fetch first non-idle connection, if available. timer := timers.Get().(*time.Timer)
if cn := p.First(); cn != nil { if !timer.Reset(p.poolTimeout) {
<-timer.C
}
select {
case <-p.queue:
timers.Put(timer)
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return nil, ErrPoolTimeout
}
p.freeConnsMu.Lock()
cn := p.popFree()
p.freeConnsMu.Unlock()
if cn != nil {
atomic.AddUint32(&p.stats.Hits, 1) atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil if !cn.IsStale(p.idleTimeout) {
} return cn, nil
// Try to create a new one.
if p.conns.Reserve() {
cn, err := p.NewConn()
if err != nil {
p.conns.CancelReservation()
return nil, err
} }
p.conns.Add(cn) _ = cn.Close()
return cn, nil
} }
// Otherwise, wait for the available connection. newcn, err := p.NewConn()
atomic.AddUint32(&p.stats.Waits, 1) if err != nil {
if cn := p.wait(p.poolTimeout); cn != nil { p.queue <- struct{}{}
return cn, nil return nil, err
} }
atomic.AddUint32(&p.stats.Timeouts, 1) p.connsMu.Lock()
return nil, ErrPoolTimeout if cn != nil {
p.remove(cn, errConnStale)
}
p.conns = append(p.conns, newcn)
p.connsMu.Unlock()
return newcn, nil
} }
func (p *ConnPool) Put(cn *Conn) error { func (p *ConnPool) Put(cn *Conn) error {
@ -188,71 +215,54 @@ func (p *ConnPool) Put(cn *Conn) error {
b, _ := cn.Rd.Peek(cn.Rd.Buffered()) b, _ := cn.Rd.Peek(cn.Rd.Buffered())
err := fmt.Errorf("connection has unread data: %q", b) err := fmt.Errorf("connection has unread data: %q", b)
Logger.Print(err) Logger.Print(err)
return p.Replace(cn, err) return p.Remove(cn, err)
} }
p.freeConns.Push(cn) p.freeConnsMu.Lock()
return nil p.freeConns = append(p.freeConns, cn)
} p.freeConnsMu.Unlock()
p.queue <- struct{}{}
func (p *ConnPool) replace(cn *Conn) (*Conn, error) {
_ = cn.Close()
idx := cn.SetIndex(-1)
if idx == -1 {
return nil, errConnClosed
}
netConn, err := p.dial()
if err != nil {
p.conns.Remove(idx)
return nil, err
}
cn = NewConn(netConn)
cn.SetIndex(idx)
p.conns.Replace(cn)
return cn, nil
}
func (p *ConnPool) Replace(cn *Conn, reason error) error {
p.storeLastErr(reason.Error())
// Replace existing connection with new one and unblock waiter.
newcn, err := p.replace(cn)
if err != nil {
return err
}
p.freeConns.Push(newcn)
return nil return nil
} }
func (p *ConnPool) Remove(cn *Conn, reason error) error { func (p *ConnPool) Remove(cn *Conn, reason error) error {
_ = cn.Close() _ = cn.Close()
p.connsMu.Lock()
idx := cn.SetIndex(-1) p.remove(cn, reason)
if idx == -1 { p.connsMu.Unlock()
return errConnClosed p.queue <- struct{}{}
}
p.storeLastErr(reason.Error())
p.conns.Remove(idx)
return nil return nil
} }
func (p *ConnPool) remove(cn *Conn, reason error) {
p.storeLastErr(reason.Error())
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
break
}
}
}
// Len returns total number of connections. // Len returns total number of connections.
func (p *ConnPool) Len() int { func (p *ConnPool) Len() int {
return p.conns.Len() p.connsMu.Lock()
l := len(p.conns)
p.connsMu.Unlock()
return l
} }
// FreeLen returns number of free connections. // FreeLen returns number of free connections.
func (p *ConnPool) FreeLen() int { func (p *ConnPool) FreeLen() int {
return p.freeConns.Len() p.freeConnsMu.Lock()
l := len(p.freeConns)
p.freeConnsMu.Unlock()
return l
} }
func (p *ConnPool) Stats() *PoolStats { func (p *ConnPool) Stats() *PoolStats {
stats := p.stats stats := PoolStats{}
stats.Requests = atomic.LoadUint32(&p.stats.Requests) stats.Requests = atomic.LoadUint32(&p.stats.Requests)
stats.Hits = atomic.LoadUint32(&p.stats.Hits)
stats.Waits = atomic.LoadUint32(&p.stats.Waits) stats.Waits = atomic.LoadUint32(&p.stats.Waits)
stats.Timeouts = atomic.LoadUint32(&p.stats.Timeouts) stats.Timeouts = atomic.LoadUint32(&p.stats.Timeouts)
stats.TotalConns = uint32(p.Len()) stats.TotalConns = uint32(p.Len())
@ -269,16 +279,10 @@ func (p *ConnPool) Close() (retErr error) {
return ErrClosed return ErrClosed
} }
// Wait for app to free connections, but don't close them immediately. p.connsMu.Lock()
for i := 0; i < p.Len()-p.FreeLen(); i++ {
if cn := p.wait(3 * time.Second); cn == nil {
break
}
}
// Close all connections. // Close all connections.
cns := p.conns.Reset() for _, cn := range p.conns {
for _, cn := range cns {
if cn == nil { if cn == nil {
continue continue
} }
@ -286,6 +290,12 @@ func (p *ConnPool) Close() (retErr error) {
retErr = err retErr = err
} }
} }
p.conns = nil
p.connsMu.Unlock()
p.freeConnsMu.Lock()
p.freeConns = nil
p.freeConnsMu.Unlock()
return retErr return retErr
} }
@ -298,16 +308,32 @@ func (p *ConnPool) closeConn(cn *Conn) error {
} }
func (p *ConnPool) ReapStaleConns() (n int, err error) { func (p *ConnPool) ReapStaleConns() (n int, err error) {
for { <-p.queue
cn := p.freeConns.ShiftStale(p.idleTimeout) p.freeConnsMu.Lock()
if cn == nil {
if len(p.freeConns) == 0 {
p.freeConnsMu.Unlock()
p.queue <- struct{}{}
return
}
var idx int
var cn *Conn
for idx, cn = range p.freeConns {
if !cn.IsStale(p.idleTimeout) {
break break
} }
if err = p.Remove(cn, errors.New("connection is stale")); err != nil { p.connsMu.Lock()
return p.remove(cn, errConnStale)
} p.connsMu.Unlock()
n++ n++
} }
if idx > 0 {
p.freeConns = append(p.freeConns[:0], p.freeConns[idx:]...)
}
p.freeConnsMu.Unlock()
p.queue <- struct{}{}
return return
} }
@ -322,9 +348,13 @@ func (p *ConnPool) reaper(frequency time.Duration) {
n, err := p.ReapStaleConns() n, err := p.ReapStaleConns()
if err != nil { if err != nil {
Logger.Printf("ReapStaleConns failed: %s", err) Logger.Printf("ReapStaleConns failed: %s", err)
} else if n > 0 { continue
Logger.Printf("removed %d stale connections", n)
} }
s := p.Stats()
Logger.Printf(
"reaper: removed %d stale conns (TotalConns=%d FreeConns=%d Requests=%d Hits=%d Timeouts=%d)",
n, s.TotalConns, s.FreeConns, s.Requests, s.Hits, s.Timeouts,
)
} }
} }
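
The subtlest part of the new Get/PopFree path is reusing timers for the pool-timeout wait. Below is a hedged, standalone sketch of that acquire step: a token channel plus a sync.Pool of timers. Names are illustrative, and unlike the diff this sketch keeps pooled timers stopped and drained before they go back to the pool, so Reset is always safe.

package pooldemo

import (
	"errors"
	"sync"
	"time"
)

// ErrPoolTimeout mirrors the pool's timeout error for this sketch.
var ErrPoolTimeout = errors.New("pool timeout")

// timers caches stopped *time.Timer values so an acquire under load does not
// allocate a new timer (and its channel) on every call.
var timers = sync.Pool{
	New: func() interface{} {
		t := time.NewTimer(time.Hour)
		t.Stop()
		return t
	},
}

// acquire waits up to timeout for a token from queue, the buffered channel
// that caps the number of live connections.
func acquire(queue chan struct{}, timeout time.Duration) error {
	t := timers.Get().(*time.Timer)
	t.Reset(timeout)

	select {
	case <-queue:
		// Got a slot; stop the timer and drain it if it already fired so the
		// next Reset starts from a clean state.
		if !t.Stop() {
			<-t.C
		}
		timers.Put(t)
		return nil
	case <-t.C:
		// Timed out waiting for a free slot.
		timers.Put(t)
		return ErrPoolTimeout
	}
}

// release returns the token, e.g. from Put or Remove.
func release(queue chan struct{}) {
	queue <- struct{}{}
}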


@ -27,7 +27,7 @@ func (p *SingleConnPool) Put(cn *Conn) error {
return nil return nil
} }
func (p *SingleConnPool) Replace(cn *Conn, _ error) error { func (p *SingleConnPool) Remove(cn *Conn, _ error) error {
if p.cn != cn { if p.cn != cn {
panic("p.cn != cn") panic("p.cn != cn")
} }


@ -67,13 +67,13 @@ func (p *StickyConnPool) Put(cn *Conn) error {
return nil return nil
} }
func (p *StickyConnPool) replace(reason error) error { func (p *StickyConnPool) remove(reason error) error {
err := p.pool.Replace(p.cn, reason) err := p.pool.Remove(p.cn, reason)
p.cn = nil p.cn = nil
return err return err
} }
func (p *StickyConnPool) Replace(cn *Conn, reason error) error { func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
@ -85,7 +85,7 @@ func (p *StickyConnPool) Replace(cn *Conn, reason error) error {
if cn != nil && p.cn != cn { if cn != nil && p.cn != cn {
panic("p.cn != cn") panic("p.cn != cn")
} }
return p.replace(reason) return p.remove(reason)
} }
func (p *StickyConnPool) Len() int { func (p *StickyConnPool) Len() int {
@ -121,7 +121,7 @@ func (p *StickyConnPool) Close() error {
err = p.put() err = p.put()
} else { } else {
reason := errors.New("redis: sticky not reusable connection") reason := errors.New("redis: sticky not reusable connection")
err = p.replace(reason) err = p.remove(reason)
} }
} }
return err return err


@ -16,8 +16,8 @@ var _ = Describe("ConnPool", func() {
var connPool *pool.ConnPool var connPool *pool.ConnPool
BeforeEach(func() { BeforeEach(func() {
pool.SetIdleCheckFrequency(time.Second) connPool = pool.NewConnPool(
connPool = pool.NewConnPool(dummyDialer, 10, time.Hour, time.Second) dummyDialer, 10, time.Hour, time.Millisecond, time.Millisecond)
}) })
AfterEach(func() { AfterEach(func() {
@ -33,7 +33,7 @@ var _ = Describe("ConnPool", func() {
break break
} }
_ = connPool.Replace(cn, errors.New("test")) _ = connPool.Remove(cn, errors.New("test"))
} }
Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`)) Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`))
@ -75,7 +75,7 @@ var _ = Describe("ConnPool", func() {
// ok // ok
} }
err = connPool.Replace(cn, errors.New("test")) err = connPool.Remove(cn, errors.New("test"))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Check that Ping is unblocked. // Check that Ping is unblocked.
@ -93,26 +93,33 @@ var _ = Describe("ConnPool", func() {
}) })
}) })
var _ = Describe("conns reapser", func() { var _ = Describe("conns reaper", func() {
var connPool *pool.ConnPool var connPool *pool.ConnPool
BeforeEach(func() { BeforeEach(func() {
pool.SetIdleCheckFrequency(time.Hour) connPool = pool.NewConnPool(
connPool = pool.NewConnPool(dummyDialer, 10, 0, time.Minute) dummyDialer, 10, time.Second, time.Millisecond, time.Hour)
var cns []*pool.Conn
// add stale connections // add stale connections
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn := pool.NewConn(&net.TCPConn{}) cn, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
cn.UsedAt = time.Now().Add(-2 * time.Minute) cn.UsedAt = time.Now().Add(-2 * time.Minute)
Expect(connPool.Add(cn)).To(BeTrue()) cns = append(cns, cn)
Expect(cn.Index()).To(Equal(i))
} }
// add fresh connections // add fresh connections
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn := pool.NewConn(&net.TCPConn{}) cn := pool.NewConn(&net.TCPConn{})
Expect(connPool.Add(cn)).To(BeTrue()) cn, err := connPool.Get()
Expect(cn.Index()).To(Equal(3 + i)) Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
for _, cn := range cns {
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
} }
Expect(connPool.Len()).To(Equal(6)) Expect(connPool.Len()).To(Equal(6))
@ -136,7 +143,8 @@ var _ = Describe("conns reapser", func() {
for j := 0; j < 3; j++ { for j := 0; j < 3; j++ {
var freeCns []*pool.Conn var freeCns []*pool.Conn
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn := connPool.First() cn, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn) freeCns = append(freeCns, cn)
} }
@ -144,9 +152,6 @@ var _ = Describe("conns reapser", func() {
Expect(connPool.Len()).To(Equal(3)) Expect(connPool.Len()).To(Equal(3))
Expect(connPool.FreeLen()).To(Equal(0)) Expect(connPool.FreeLen()).To(Equal(0))
cn := connPool.First()
Expect(cn).To(BeNil())
cn, err := connPool.Get() cn, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
@ -173,42 +178,60 @@ var _ = Describe("conns reapser", func() {
var _ = Describe("race", func() { var _ = Describe("race", func() {
var connPool *pool.ConnPool var connPool *pool.ConnPool
var C, N int
var C, N = 10, 1000
if testing.Short() {
C = 4
N = 100
}
BeforeEach(func() { BeforeEach(func() {
pool.SetIdleCheckFrequency(time.Second) C, N = 10, 1000
connPool = pool.NewConnPool(dummyDialer, 10, time.Second, time.Second) if testing.Short() {
C = 4
N = 100
}
}) })
AfterEach(func() { AfterEach(func() {
connPool.Close() connPool.Close()
}) })
It("does not happend", func() { It("does not happen on Get, Put, and Remove", func() {
connPool = pool.NewConnPool(
dummyDialer, 10, time.Minute, time.Millisecond, time.Millisecond)
connPool.DialLimiter = nil
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get() cn, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Put(cn) Expect(connPool.Put(cn)).NotTo(HaveOccurred())
} }
} }
}, func(id int) { }, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get() cn, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Replace(cn, errors.New("test")) Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred())
} }
} }
}, func(id int) { })
})
It("does not happen on Get and PopFree", func() {
connPool = pool.NewConnPool(
dummyDialer, 10, time.Minute, time.Second, time.Millisecond)
connPool.DialLimiter = nil
perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get() cn, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
connPool.Remove(cn, errors.New("test")) Expect(connPool.Put(cn)).NotTo(HaveOccurred())
}
cn = connPool.PopFree()
if cn != nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
} }
} }
}) })


@ -15,7 +15,6 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"gopkg.in/redis.v3" "gopkg.in/redis.v3"
"gopkg.in/redis.v3/internal/pool"
) )
const ( const (
@ -53,8 +52,6 @@ var cluster = &clusterScenario{
var _ = BeforeSuite(func() { var _ = BeforeSuite(func() {
var err error var err error
pool.SetIdleCheckFrequency(time.Second) // be aggressive in tests
redisMain, err = startRedis(redisPort) redisMain, err = startRedis(redisPort)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -104,27 +101,30 @@ func TestGinkgoSuite(t *testing.T) {
func redisOptions() *redis.Options { func redisOptions() *redis.Options {
return &redis.Options{ return &redis.Options{
Addr: redisAddr, Addr: redisAddr,
DB: 15, DB: 15,
DialTimeout: 10 * time.Second, DialTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second, ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second,
PoolSize: 10, PoolSize: 10,
PoolTimeout: 30 * time.Second, PoolTimeout: 30 * time.Second,
IdleTimeout: time.Second, // be aggressive in tests IdleTimeout: time.Second,
IdleCheckFrequency: time.Second,
} }
} }
func perform(n int, cb func(int)) { func perform(n int, cbs ...func(int)) {
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < n; i++ { for _, cb := range cbs {
wg.Add(1) for i := 0; i < n; i++ {
go func(i int) { wg.Add(1)
defer GinkgoRecover() go func(cb func(int), i int) {
defer wg.Done() defer GinkgoRecover()
defer wg.Done()
cb(i) cb(i)
}(i) }(cb, i)
}
} }
wg.Wait() wg.Wait()
} }
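
Stripped of the GinkgoRecover plumbing, the new variadic helper boils down to the standalone sketch below; the main function is just a hypothetical usage call showing two workloads being run concurrently, 4 goroutines each.

package main

import (
	"fmt"
	"sync"
)

// perform mirrors the updated test helper: every callback runs in n
// goroutines and all of them are awaited together.
func perform(n int, cbs ...func(int)) {
	var wg sync.WaitGroup
	for _, cb := range cbs {
		for i := 0; i < n; i++ {
			wg.Add(1)
			go func(cb func(int), i int) {
				defer wg.Done()
				cb(i)
			}(cb, i)
		}
	}
	wg.Wait()
}

func main() {
	perform(4,
		func(id int) { fmt.Println("reader", id) },
		func(id int) { fmt.Println("writer", id) },
	)
}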


@ -50,6 +50,9 @@ type Options struct {
// connections. Should be less than server's timeout. // connections. Should be less than server's timeout.
// Default is to not close idle connections. // Default is to not close idle connections.
IdleTimeout time.Duration IdleTimeout time.Duration
// The frequency of idle checks.
// Default is 1 minute.
IdleCheckFrequency time.Duration
} }
func (opt *Options) getNetwork() string { func (opt *Options) getNetwork() string {
@ -93,9 +96,21 @@ func (opt *Options) getIdleTimeout() time.Duration {
return opt.IdleTimeout return opt.IdleTimeout
} }
func (opt *Options) getIdleCheckFrequency() time.Duration {
if opt.IdleCheckFrequency == 0 {
return time.Minute
}
return opt.IdleCheckFrequency
}
func newConnPool(opt *Options) *pool.ConnPool { func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool( return pool.NewConnPool(
opt.getDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout()) opt.getDialer(),
opt.getPoolSize(),
opt.getPoolTimeout(),
opt.getIdleTimeout(),
opt.getIdleCheckFrequency(),
)
} }
// PoolStats contains pool state information and accumulated stats. // PoolStats contains pool state information and accumulated stats.
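
For the single-node client, the same pair of options controls the reaper; note from the NewConnPool hunk earlier that the background reaper only starts when both IdleTimeout and IdleCheckFrequency are positive. A small configuration example, with placeholder values:

package main

import (
	"time"

	"gopkg.in/redis.v3"
)

func main() {
	client := redis.NewClient(&redis.Options{
		Addr:               "localhost:6379",
		PoolSize:           10,
		PoolTimeout:        30 * time.Second,
		IdleTimeout:        5 * time.Minute, // close connections idle for longer than this
		IdleCheckFrequency: time.Minute,     // 0 falls back to the 1 minute default above
	})
	defer client.Close()

	_ = client.Ping().Err()
}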


@ -106,7 +106,7 @@ var _ = Describe("pool", func() {
stats := pool.Stats() stats := pool.Stats()
Expect(stats.Requests).To(Equal(uint32(4))) Expect(stats.Requests).To(Equal(uint32(4)))
Expect(stats.Hits).To(Equal(uint32(3))) Expect(stats.Hits).To(Equal(uint32(2)))
Expect(stats.Waits).To(Equal(uint32(0))) Expect(stats.Waits).To(Equal(uint32(0)))
Expect(stats.Timeouts).To(Equal(uint32(0))) Expect(stats.Timeouts).To(Equal(uint32(0)))
}) })


@ -8,6 +8,8 @@ import (
"gopkg.in/redis.v3/internal/pool" "gopkg.in/redis.v3/internal/pool"
) )
var receiveMessageTimeout = 5 * time.Second
// Posts a message to the given channel. // Posts a message to the given channel.
func (c *Client) Publish(channel, message string) *IntCmd { func (c *Client) Publish(channel, message string) *IntCmd {
req := NewIntCmd("PUBLISH", channel, message) req := NewIntCmd("PUBLISH", channel, message)
@ -255,7 +257,7 @@ func (c *PubSub) Receive() (interface{}, error) {
func (c *PubSub) ReceiveMessage() (*Message, error) { func (c *PubSub) ReceiveMessage() (*Message, error) {
var errNum uint var errNum uint
for { for {
msgi, err := c.ReceiveTimeout(5 * time.Second) msgi, err := c.ReceiveTimeout(receiveMessageTimeout)
if err != nil { if err != nil {
if !isNetworkError(err) { if !isNetworkError(err) {
return nil, err return nil, err


@ -68,7 +68,7 @@ var _ = Describe("PubSub", func() {
Expect(subscr.Count).To(Equal(0)) Expect(subscr.Count).To(Equal(0))
} }
stats := client.Pool().Stats() stats := client.PoolStats()
Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2))) Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2)))
}) })
@ -195,7 +195,7 @@ var _ = Describe("PubSub", func() {
Expect(subscr.Count).To(Equal(0)) Expect(subscr.Count).To(Equal(0))
} }
stats := client.Pool().Stats() stats := client.PoolStats()
Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2))) Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2)))
}) })
@ -256,6 +256,9 @@ var _ = Describe("PubSub", func() {
}) })
It("should ReceiveMessage after timeout", func() { It("should ReceiveMessage after timeout", func() {
timeout := time.Second
redis.SetReceiveMessageTimeout(timeout)
pubsub, err := client.Subscribe("mychannel") pubsub, err := client.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer pubsub.Close() defer pubsub.Close()
@ -267,7 +270,7 @@ var _ = Describe("PubSub", func() {
done <- true done <- true
}() }()
time.Sleep(5*time.Second + 100*time.Millisecond) time.Sleep(timeout + 100*time.Millisecond)
n, err := client.Publish("mychannel", "hello").Result() n, err := client.Publish("mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1))) Expect(n).To(Equal(int64(1)))
@ -280,8 +283,9 @@ var _ = Describe("PubSub", func() {
Eventually(done).Should(Receive()) Eventually(done).Should(Receive())
stats := client.Pool().Stats() stats := client.PoolStats()
Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2))) Expect(stats.Requests).To(Equal(uint32(3)))
Expect(stats.Hits).To(Equal(uint32(1)))
}) })
expectReceiveMessageOnError := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
@ -311,8 +315,9 @@ var _ = Describe("PubSub", func() {
Eventually(done).Should(Receive()) Eventually(done).Should(Receive())
stats := client.Pool().Stats() stats := client.PoolStats()
Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2))) Expect(stats.Requests).To(Equal(uint32(4)))
Expect(stats.Hits).To(Equal(uint32(1)))
} }
It("Subscribe should reconnect on ReceiveMessage error", func() { It("Subscribe should reconnect on ReceiveMessage error", func() {


@ -17,16 +17,17 @@ import (
var _ = Describe("races", func() { var _ = Describe("races", func() {
var client *redis.Client var client *redis.Client
var C, N int
var C, N = 10, 1000
if testing.Short() {
C = 4
N = 100
}
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
Expect(client.FlushDb().Err()).To(BeNil()) Expect(client.FlushDb().Err()).To(BeNil())
C, N = 10, 1000
if testing.Short() {
C = 4
N = 100
}
}) })
AfterEach(func() { AfterEach(func() {
@ -123,16 +124,13 @@ var _ = Describe("races", func() {
}) })
It("should handle big vals in Set", func() { It("should handle big vals in Set", func() {
C, N = 4, 100
bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
err := client.Set("key", bigVal, 0).Err() err := client.Set("key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
got, err := client.Get("key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal))
} }
}) })
}) })


@ -39,7 +39,7 @@ func (c *baseClient) conn() (*pool.Conn, error) {
} }
if !cn.Inited { if !cn.Inited {
if err := c.initConn(cn); err != nil { if err := c.initConn(cn); err != nil {
_ = c.connPool.Replace(cn, err) _ = c.connPool.Remove(cn, err)
return nil, err return nil, err
} }
} }
@ -48,7 +48,7 @@ func (c *baseClient) conn() (*pool.Conn, error) {
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
if isBadConn(err, allowTimeout) { if isBadConn(err, allowTimeout) {
_ = c.connPool.Replace(cn, err) _ = c.connPool.Remove(cn, err)
return false return false
} }


@ -166,7 +166,8 @@ var _ = Describe("Client", func() {
err = client.Ping().Err() err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn = client.Pool().First() cn, err = client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) Expect(cn.UsedAt.After(createdAt)).To(BeTrue())
}) })

ring.go

@ -32,9 +32,10 @@ type RingOptions struct {
ReadTimeout time.Duration ReadTimeout time.Duration
WriteTimeout time.Duration WriteTimeout time.Duration
PoolSize int PoolSize int
PoolTimeout time.Duration PoolTimeout time.Duration
IdleTimeout time.Duration IdleTimeout time.Duration
IdleCheckFrequency time.Duration
} }
func (opt *RingOptions) clientOptions() *Options { func (opt *RingOptions) clientOptions() *Options {
@ -46,9 +47,10 @@ func (opt *RingOptions) clientOptions() *Options {
ReadTimeout: opt.ReadTimeout, ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout, WriteTimeout: opt.WriteTimeout,
PoolSize: opt.PoolSize, PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout, PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout, IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: opt.IdleCheckFrequency,
} }
} }


@ -26,15 +26,16 @@ type FailoverOptions struct {
Password string Password string
DB int64 DB int64
MaxRetries int
DialTimeout time.Duration DialTimeout time.Duration
ReadTimeout time.Duration ReadTimeout time.Duration
WriteTimeout time.Duration WriteTimeout time.Duration
PoolSize int PoolSize int
PoolTimeout time.Duration PoolTimeout time.Duration
IdleTimeout time.Duration IdleTimeout time.Duration
IdleCheckFrequency time.Duration
MaxRetries int
} }
func (opt *FailoverOptions) options() *Options { func (opt *FailoverOptions) options() *Options {
@ -44,15 +45,16 @@ func (opt *FailoverOptions) options() *Options {
DB: opt.DB, DB: opt.DB,
Password: opt.Password, Password: opt.Password,
MaxRetries: opt.MaxRetries,
DialTimeout: opt.DialTimeout, DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout, ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout, WriteTimeout: opt.WriteTimeout,
PoolSize: opt.PoolSize, PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout, PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout, IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: opt.IdleCheckFrequency,
MaxRetries: opt.MaxRetries,
} }
} }
@ -257,7 +259,7 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
cnsToPut := make([]*pool.Conn, 0) cnsToPut := make([]*pool.Conn, 0)
for { for {
cn := d.pool.First() cn := d.pool.PopFree()
if cn == nil { if cn == nil {
break break
} }
@ -267,7 +269,7 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
cn.RemoteAddr(), cn.RemoteAddr(),
) )
Logger.Print(err) Logger.Print(err)
d.pool.Replace(cn, err) d.pool.Remove(cn, err)
} else { } else {
cnsToPut = append(cnsToPut, cn) cnsToPut = append(cnsToPut, cn)
} }