diff --git a/bench_test.go b/bench_test.go index dfd2bf3e..5d6fa37f 100644 --- a/bench_test.go +++ b/bench_test.go @@ -2,12 +2,15 @@ package redis_test import ( "bytes" + "errors" + "net" "testing" "time" redigo "github.com/garyburd/redigo/redis" "gopkg.in/redis.v3" + "gopkg.in/redis.v3/internal/pool" ) func benchmarkRedisClient(poolSize int) *redis.Client { @@ -274,11 +277,11 @@ func BenchmarkZAdd(b *testing.B) { }) } -func BenchmarkPool(b *testing.B) { - client := benchmarkRedisClient(10) - defer client.Close() - - pool := client.Pool() +func benchmarkPoolGetPut(b *testing.B, poolSize int) { + dial := func() (*pool.Conn, error) { + return pool.NewConn(&net.TCPConn{}), nil + } + pool := pool.NewConnPool(dial, poolSize, time.Second, 0) b.ResetTimer() @@ -294,3 +297,49 @@ func BenchmarkPool(b *testing.B) { } }) } + +func BenchmarkPoolGetPut10Conns(b *testing.B) { + benchmarkPoolGetPut(b, 10) +} + +func BenchmarkPoolGetPut100Conns(b *testing.B) { + benchmarkPoolGetPut(b, 100) +} + +func BenchmarkPoolGetPut1000Conns(b *testing.B) { + benchmarkPoolGetPut(b, 1000) +} + +func benchmarkPoolGetRemove(b *testing.B, poolSize int) { + dial := func() (*pool.Conn, error) { + return pool.NewConn(&net.TCPConn{}), nil + } + pool := pool.NewConnPool(dial, poolSize, time.Second, 0) + removeReason := errors.New("benchmark") + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, _, err := pool.Get() + if err != nil { + b.Fatalf("no error expected on pool.Get but received: %s", err.Error()) + } + if err = pool.Remove(conn, removeReason); err != nil { + b.Fatalf("no error expected on pool.Remove but received: %s", err.Error()) + } + } + }) +} + +func BenchmarkPoolGetRemove10Conns(b *testing.B) { + benchmarkPoolGetRemove(b, 10) +} + +func BenchmarkPoolGetRemove100Conns(b *testing.B) { + benchmarkPoolGetRemove(b, 100) +} + +func BenchmarkPoolGetRemove1000Conns(b *testing.B) { + benchmarkPoolGetRemove(b, 1000) +} diff --git a/cluster_pipeline.go b/cluster_pipeline.go index 5299b5f5..7fa721c3 100644 --- a/cluster_pipeline.go +++ b/cluster_pipeline.go @@ -2,6 +2,7 @@ package redis import ( "gopkg.in/redis.v3/internal/hashtag" + "gopkg.in/redis.v3/internal/pool" ) // ClusterPipeline is not thread-safe. @@ -96,9 +97,9 @@ func (pipe *ClusterPipeline) Close() error { } func (pipe *ClusterPipeline) execClusterCmds( - cn *conn, cmds []Cmder, failedCmds map[string][]Cmder, + cn *pool.Conn, cmds []Cmder, failedCmds map[string][]Cmder, ) (map[string][]Cmder, error) { - if err := cn.writeCmds(cmds...); err != nil { + if err := writeCmd(cn, cmds...); err != nil { setCmdsErr(cmds, err) return failedCmds, err } diff --git a/command.go b/command.go index 31516f60..6681ff53 100644 --- a/command.go +++ b/command.go @@ -6,6 +6,8 @@ import ( "strconv" "strings" "time" + + "gopkg.in/redis.v3/internal/pool" ) var ( @@ -28,7 +30,7 @@ var ( type Cmder interface { args() []interface{} - readReply(*conn) error + readReply(*pool.Conn) error setErr(error) reset() @@ -51,6 +53,20 @@ func resetCmds(cmds []Cmder) { } } +func writeCmd(cn *pool.Conn, cmds ...Cmder) error { + cn.Buf = cn.Buf[:0] + for _, cmd := range cmds { + var err error + cn.Buf, err = appendArgs(cn.Buf, cmd.args()) + if err != nil { + return err + } + } + + _, err := cn.Write(cn.Buf) + return err +} + func cmdString(cmd Cmder, val interface{}) string { var ss []string for _, arg := range cmd.args() { @@ -143,7 +159,7 @@ func (cmd *Cmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *Cmd) readReply(cn *conn) error { +func (cmd *Cmd) readReply(cn *pool.Conn) error { val, err := readReply(cn, sliceParser) if err != nil { cmd.err = err @@ -188,7 +204,7 @@ func (cmd *SliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *SliceCmd) readReply(cn *conn) error { +func (cmd *SliceCmd) readReply(cn *pool.Conn) error { v, err := readArrayReply(cn, sliceParser) if err != nil { cmd.err = err @@ -231,7 +247,7 @@ func (cmd *StatusCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StatusCmd) readReply(cn *conn) error { +func (cmd *StatusCmd) readReply(cn *pool.Conn) error { cmd.val, cmd.err = readStringReply(cn) return cmd.err } @@ -265,7 +281,7 @@ func (cmd *IntCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *IntCmd) readReply(cn *conn) error { +func (cmd *IntCmd) readReply(cn *pool.Conn) error { cmd.val, cmd.err = readIntReply(cn) return cmd.err } @@ -303,7 +319,7 @@ func (cmd *DurationCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *DurationCmd) readReply(cn *conn) error { +func (cmd *DurationCmd) readReply(cn *pool.Conn) error { n, err := readIntReply(cn) if err != nil { cmd.err = err @@ -344,7 +360,7 @@ func (cmd *BoolCmd) String() string { var ok = []byte("OK") -func (cmd *BoolCmd) readReply(cn *conn) error { +func (cmd *BoolCmd) readReply(cn *pool.Conn) error { v, err := readReply(cn, nil) // `SET key value NX` returns nil when key already exists. But // `SETNX key value` returns bool (0/1). So convert nil to bool. @@ -430,13 +446,17 @@ func (cmd *StringCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringCmd) readReply(cn *conn) error { +func (cmd *StringCmd) readReply(cn *pool.Conn) error { b, err := readBytesReply(cn) if err != nil { cmd.err = err return err } - cmd.val = cn.copyBuf(b) + + new := make([]byte, len(b)) + copy(new, b) + cmd.val = new + return nil } @@ -469,7 +489,7 @@ func (cmd *FloatCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *FloatCmd) readReply(cn *conn) error { +func (cmd *FloatCmd) readReply(cn *pool.Conn) error { cmd.val, cmd.err = readFloatReply(cn) return cmd.err } @@ -503,7 +523,7 @@ func (cmd *StringSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringSliceCmd) readReply(cn *conn) error { +func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error { v, err := readArrayReply(cn, stringSliceParser) if err != nil { cmd.err = err @@ -542,7 +562,7 @@ func (cmd *BoolSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *BoolSliceCmd) readReply(cn *conn) error { +func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error { v, err := readArrayReply(cn, boolSliceParser) if err != nil { cmd.err = err @@ -581,7 +601,7 @@ func (cmd *StringStringMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringStringMapCmd) readReply(cn *conn) error { +func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error { v, err := readArrayReply(cn, stringStringMapParser) if err != nil { cmd.err = err @@ -620,7 +640,7 @@ func (cmd *StringIntMapCmd) reset() { cmd.err = nil } -func (cmd *StringIntMapCmd) readReply(cn *conn) error { +func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error { v, err := readArrayReply(cn, stringIntMapParser) if err != nil { cmd.err = err @@ -659,7 +679,7 @@ func (cmd *ZSliceCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ZSliceCmd) readReply(cn *conn) error { +func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error { v, err := readArrayReply(cn, zSliceParser) if err != nil { cmd.err = err @@ -703,7 +723,7 @@ func (cmd *ScanCmd) String() string { return cmdString(cmd, cmd.keys) } -func (cmd *ScanCmd) readReply(cn *conn) error { +func (cmd *ScanCmd) readReply(cn *pool.Conn) error { keys, cursor, err := readScanReply(cn) if err != nil { cmd.err = err @@ -751,7 +771,7 @@ func (cmd *ClusterSlotCmd) reset() { cmd.err = nil } -func (cmd *ClusterSlotCmd) readReply(cn *conn) error { +func (cmd *ClusterSlotCmd) readReply(cn *pool.Conn) error { v, err := readArrayReply(cn, clusterSlotInfoSliceParser) if err != nil { cmd.err = err @@ -838,7 +858,7 @@ func (cmd *GeoLocationCmd) String() string { return cmdString(cmd, cmd.locations) } -func (cmd *GeoLocationCmd) readReply(cn *conn) error { +func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error { reply, err := readArrayReply(cn, newGeoLocationSliceParser(cmd.q)) if err != nil { cmd.err = err diff --git a/conn.go b/conn.go deleted file mode 100644 index e7306d94..00000000 --- a/conn.go +++ /dev/null @@ -1,120 +0,0 @@ -package redis - -import ( - "bufio" - "net" - "time" -) - -const defaultBufSize = 4096 - -var noTimeout = time.Time{} - -// Stubbed in tests. -var now = time.Now - -type conn struct { - netcn net.Conn - rd *bufio.Reader - buf []byte - - UsedAt time.Time - ReadTimeout time.Duration - WriteTimeout time.Duration -} - -func newConnDialer(opt *Options) func() (*conn, error) { - dialer := opt.getDialer() - return func() (*conn, error) { - netcn, err := dialer() - if err != nil { - return nil, err - } - cn := &conn{ - netcn: netcn, - buf: make([]byte, defaultBufSize), - - UsedAt: now(), - } - cn.rd = bufio.NewReader(cn) - return cn, cn.init(opt) - } -} - -func (cn *conn) init(opt *Options) error { - if opt.Password == "" && opt.DB == 0 { - return nil - } - - // Temp client for Auth and Select. - client := newClient(opt, newSingleConnPool(cn)) - - if opt.Password != "" { - if err := client.Auth(opt.Password).Err(); err != nil { - return err - } - } - - if opt.DB > 0 { - if err := client.Select(opt.DB).Err(); err != nil { - return err - } - } - - return nil -} - -func (cn *conn) writeCmds(cmds ...Cmder) error { - cn.buf = cn.buf[:0] - for _, cmd := range cmds { - var err error - cn.buf, err = appendArgs(cn.buf, cmd.args()) - if err != nil { - return err - } - } - - _, err := cn.Write(cn.buf) - return err -} - -func (cn *conn) Read(b []byte) (int, error) { - cn.UsedAt = now() - if cn.ReadTimeout != 0 { - cn.netcn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) - } else { - cn.netcn.SetReadDeadline(noTimeout) - } - return cn.netcn.Read(b) -} - -func (cn *conn) Write(b []byte) (int, error) { - cn.UsedAt = now() - if cn.WriteTimeout != 0 { - cn.netcn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) - } else { - cn.netcn.SetWriteDeadline(noTimeout) - } - return cn.netcn.Write(b) -} - -func (cn *conn) RemoteAddr() net.Addr { - return cn.netcn.RemoteAddr() -} - -func (cn *conn) Close() error { - return cn.netcn.Close() -} - -func isSameSlice(s1, s2 []byte) bool { - return len(s1) > 0 && len(s2) > 0 && &s1[0] == &s2[0] -} - -func (cn *conn) copyBuf(b []byte) []byte { - if isSameSlice(b, cn.buf) { - new := make([]byte, len(b)) - copy(new, b) - return new - } - return b -} diff --git a/conn_test.go b/conn_test.go deleted file mode 100644 index 4f1eef6d..00000000 --- a/conn_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package redis_test - -import ( - "net" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "gopkg.in/redis.v3" -) - -var _ = Describe("newConnDialer with bad connection", func() { - It("should return an error", func() { - dialer := redis.NewConnDialer(&redis.Options{ - Dialer: func() (net.Conn, error) { - return &badConn{}, nil - }, - MaxRetries: 3, - Password: "password", - DB: 1, - }) - _, err := dialer() - Expect(err).To(MatchError("bad connection")) - }) -}) diff --git a/error.go b/error.go index 3f2a560c..e2430b4a 100644 --- a/error.go +++ b/error.go @@ -1,12 +1,15 @@ package redis import ( + "errors" "fmt" "io" "net" "strings" ) +var errClosed = errors.New("redis: client is closed") + // Redis nil reply, .e.g. when key does not exist. var Nil = errorf("redis: nil") diff --git a/export_test.go b/export_test.go index 95715b5f..cce779bb 100644 --- a/export_test.go +++ b/export_test.go @@ -1,30 +1,11 @@ package redis -import ( - "net" - "time" -) +import "gopkg.in/redis.v3/internal/pool" -func (c *baseClient) Pool() pool { +func (c *baseClient) Pool() pool.Pooler { return c.connPool } -func (c *PubSub) Pool() pool { +func (c *PubSub) Pool() pool.Pooler { return c.base.connPool } - -var NewConnDialer = newConnDialer - -func (cn *conn) SetNetConn(netcn net.Conn) { - cn.netcn = netcn -} - -func SetTime(tm time.Time) { - now = func() time.Time { - return tm - } -} - -func RestoreTime() { - now = time.Now -} diff --git a/internal/pool/conn.go b/internal/pool/conn.go new file mode 100644 index 00000000..1e1e8611 --- /dev/null +++ b/internal/pool/conn.go @@ -0,0 +1,60 @@ +package pool + +import ( + "bufio" + "net" + "time" +) + +const defaultBufSize = 4096 + +var noTimeout = time.Time{} + +type Conn struct { + NetConn net.Conn + Rd *bufio.Reader + Buf []byte + + UsedAt time.Time + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func NewConn(netConn net.Conn) *Conn { + cn := &Conn{ + NetConn: netConn, + Buf: make([]byte, defaultBufSize), + + UsedAt: time.Now(), + } + cn.Rd = bufio.NewReader(cn) + return cn +} + +func (cn *Conn) Read(b []byte) (int, error) { + cn.UsedAt = time.Now() + if cn.ReadTimeout != 0 { + cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) + } else { + cn.NetConn.SetReadDeadline(noTimeout) + } + return cn.NetConn.Read(b) +} + +func (cn *Conn) Write(b []byte) (int, error) { + cn.UsedAt = time.Now() + if cn.WriteTimeout != 0 { + cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) + } else { + cn.NetConn.SetWriteDeadline(noTimeout) + } + return cn.NetConn.Write(b) +} + +func (cn *Conn) RemoteAddr() net.Addr { + return cn.NetConn.RemoteAddr() +} + +func (cn *Conn) Close() error { + return cn.NetConn.Close() +} diff --git a/internal/pool/conn_list.go b/internal/pool/conn_list.go new file mode 100644 index 00000000..f8b82ab2 --- /dev/null +++ b/internal/pool/conn_list.go @@ -0,0 +1,100 @@ +package pool + +import ( + "sync" + "sync/atomic" +) + +type connList struct { + cns []*Conn + mx sync.Mutex + len int32 // atomic + size int32 +} + +func newConnList(size int) *connList { + return &connList{ + cns: make([]*Conn, 0, size), + size: int32(size), + } +} + +func (l *connList) Len() int { + return int(atomic.LoadInt32(&l.len)) +} + +// Reserve reserves place in the list and returns true on success. The +// caller must add or remove connection if place was reserved. +func (l *connList) Reserve() bool { + len := atomic.AddInt32(&l.len, 1) + reserved := len <= l.size + if !reserved { + atomic.AddInt32(&l.len, -1) + } + return reserved +} + +// Add adds connection to the list. The caller must reserve place first. +func (l *connList) Add(cn *Conn) { + l.mx.Lock() + l.cns = append(l.cns, cn) + l.mx.Unlock() +} + +// Remove closes connection and removes it from the list. +func (l *connList) Remove(cn *Conn) error { + defer l.mx.Unlock() + l.mx.Lock() + + if cn == nil { + atomic.AddInt32(&l.len, -1) + return nil + } + + for i, c := range l.cns { + if c == cn { + l.cns = append(l.cns[:i], l.cns[i+1:]...) + atomic.AddInt32(&l.len, -1) + return cn.Close() + } + } + + if l.closed() { + return nil + } + panic("conn not found in the list") +} + +func (l *connList) Replace(cn, newcn *Conn) error { + defer l.mx.Unlock() + l.mx.Lock() + + for i, c := range l.cns { + if c == cn { + l.cns[i] = newcn + return cn.Close() + } + } + + if l.closed() { + return newcn.Close() + } + panic("conn not found in the list") +} + +func (l *connList) Close() (retErr error) { + l.mx.Lock() + for _, c := range l.cns { + if err := c.Close(); err != nil { + retErr = err + } + } + l.cns = nil + atomic.StoreInt32(&l.len, 0) + l.mx.Unlock() + return retErr +} + +func (l *connList) closed() bool { + return l.cns == nil +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go new file mode 100644 index 00000000..ab03195d --- /dev/null +++ b/internal/pool/pool.go @@ -0,0 +1,284 @@ +package pool + +import ( + "errors" + "fmt" + "log" + "sync/atomic" + "time" + + "gopkg.in/bsm/ratelimit.v1" +) + +var Logger *log.Logger + +var ( + errClosed = errors.New("redis: client is closed") + ErrPoolTimeout = errors.New("redis: connection pool timeout") +) + +// PoolStats contains pool state information and accumulated stats. +type PoolStats struct { + Requests uint32 // number of times a connection was requested by the pool + Hits uint32 // number of times free connection was found in the pool + Waits uint32 // number of times the pool had to wait for a connection + Timeouts uint32 // number of times a wait timeout occurred + + TotalConns uint32 // the number of total connections in the pool + FreeConns uint32 // the number of free connections in the pool +} + +type Pooler interface { + First() *Conn + Get() (*Conn, bool, error) + Put(*Conn) error + Remove(*Conn, error) error + Len() int + FreeLen() int + Close() error + Stats() *PoolStats +} + +type dialer func() (*Conn, error) + +type ConnPool struct { + dial dialer + + poolTimeout time.Duration + idleTimeout time.Duration + + rl *ratelimit.RateLimiter + conns *connList + freeConns chan *Conn + stats PoolStats + + _closed int32 + + lastErr atomic.Value +} + +func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool { + p := &ConnPool{ + dial: dial, + + poolTimeout: poolTimeout, + idleTimeout: idleTimeout, + + rl: ratelimit.New(3*poolSize, time.Second), + conns: newConnList(poolSize), + freeConns: make(chan *Conn, poolSize), + } + if idleTimeout > 0 { + go p.reaper() + } + return p +} + +func (p *ConnPool) closed() bool { + return atomic.LoadInt32(&p._closed) == 1 +} + +func (p *ConnPool) isIdle(cn *Conn) bool { + return p.idleTimeout > 0 && time.Since(cn.UsedAt) > p.idleTimeout +} + +// First returns first non-idle connection from the pool or nil if +// there are no connections. +func (p *ConnPool) First() *Conn { + for { + select { + case cn := <-p.freeConns: + if p.isIdle(cn) { + var err error + cn, err = p.replace(cn) + if err != nil { + Logger.Printf("pool.replace failed: %s", err) + continue + } + } + return cn + default: + return nil + } + } + panic("not reached") +} + +// wait waits for free non-idle connection. It returns nil on timeout. +func (p *ConnPool) wait() *Conn { + deadline := time.After(p.poolTimeout) + for { + select { + case cn := <-p.freeConns: + if p.isIdle(cn) { + var err error + cn, err = p.replace(cn) + if err != nil { + Logger.Printf("pool.replace failed: %s", err) + continue + } + } + return cn + case <-deadline: + return nil + } + } + panic("not reached") +} + +// Establish a new connection +func (p *ConnPool) new() (*Conn, error) { + if p.rl.Limit() { + err := fmt.Errorf( + "redis: you open connections too fast (last_error=%q)", + p.loadLastErr(), + ) + return nil, err + } + + cn, err := p.dial() + if err != nil { + p.storeLastErr(err.Error()) + return nil, err + } + + return cn, nil +} + +// Get returns existed connection from the pool or creates a new one. +func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { + if p.closed() { + err = errClosed + return + } + + atomic.AddUint32(&p.stats.Requests, 1) + + // Fetch first non-idle connection, if available. + if cn = p.First(); cn != nil { + atomic.AddUint32(&p.stats.Hits, 1) + return + } + + // Try to create a new one. + if p.conns.Reserve() { + isNew = true + + cn, err = p.new() + if err != nil { + p.conns.Remove(nil) + return + } + p.conns.Add(cn) + return + } + + // Otherwise, wait for the available connection. + atomic.AddUint32(&p.stats.Waits, 1) + if cn = p.wait(); cn != nil { + return + } + + atomic.AddUint32(&p.stats.Timeouts, 1) + err = ErrPoolTimeout + return +} + +func (p *ConnPool) Put(cn *Conn) error { + if cn.Rd.Buffered() != 0 { + b, _ := cn.Rd.Peek(cn.Rd.Buffered()) + err := fmt.Errorf("connection has unread data: %q", b) + Logger.Print(err) + return p.Remove(cn, err) + } + p.freeConns <- cn + return nil +} + +func (p *ConnPool) replace(cn *Conn) (*Conn, error) { + newcn, err := p.new() + if err != nil { + _ = p.conns.Remove(cn) + return nil, err + } + _ = p.conns.Replace(cn, newcn) + return newcn, nil +} + +func (p *ConnPool) Remove(cn *Conn, reason error) error { + p.storeLastErr(reason.Error()) + + // Replace existing connection with new one and unblock waiter. + newcn, err := p.replace(cn) + if err != nil { + return err + } + p.freeConns <- newcn + return nil +} + +// Len returns total number of connections. +func (p *ConnPool) Len() int { + return p.conns.Len() +} + +// FreeLen returns number of free connections. +func (p *ConnPool) FreeLen() int { + return len(p.freeConns) +} + +func (p *ConnPool) Stats() *PoolStats { + stats := p.stats + stats.Requests = atomic.LoadUint32(&p.stats.Requests) + stats.Waits = atomic.LoadUint32(&p.stats.Waits) + stats.Timeouts = atomic.LoadUint32(&p.stats.Timeouts) + stats.TotalConns = uint32(p.Len()) + stats.FreeConns = uint32(p.FreeLen()) + return &stats +} + +func (p *ConnPool) Close() (retErr error) { + if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { + return errClosed + } + // Wait for app to free connections, but don't close them immediately. + for i := 0; i < p.Len(); i++ { + if cn := p.wait(); cn == nil { + break + } + } + // Close all connections. + if err := p.conns.Close(); err != nil { + retErr = err + } + return retErr +} + +func (p *ConnPool) reaper() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for _ = range ticker.C { + if p.closed() { + break + } + + // pool.First removes idle connections from the pool and + // returns first non-idle connection. So just put returned + // connection back. + if cn := p.First(); cn != nil { + p.Put(cn) + } + } +} + +func (p *ConnPool) storeLastErr(err string) { + p.lastErr.Store(err) +} + +func (p *ConnPool) loadLastErr() string { + if v := p.lastErr.Load(); v != nil { + return v.(string) + } + return "" +} diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go new file mode 100644 index 00000000..f2d58cf7 --- /dev/null +++ b/internal/pool/pool_single.go @@ -0,0 +1,47 @@ +package pool + +type SingleConnPool struct { + cn *Conn +} + +func NewSingleConnPool(cn *Conn) *SingleConnPool { + return &SingleConnPool{ + cn: cn, + } +} + +func (p *SingleConnPool) First() *Conn { + return p.cn +} + +func (p *SingleConnPool) Get() (*Conn, bool, error) { + return p.cn, false, nil +} + +func (p *SingleConnPool) Put(cn *Conn) error { + if p.cn != cn { + panic("p.cn != cn") + } + return nil +} + +func (p *SingleConnPool) Remove(cn *Conn, _ error) error { + if p.cn != cn { + panic("p.cn != cn") + } + return nil +} + +func (p *SingleConnPool) Len() int { + return 1 +} + +func (p *SingleConnPool) FreeLen() int { + return 0 +} + +func (p *SingleConnPool) Stats() *PoolStats { return nil } + +func (p *SingleConnPool) Close() error { + return nil +} diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go new file mode 100644 index 00000000..c611c4b4 --- /dev/null +++ b/internal/pool/pool_sticky.go @@ -0,0 +1,128 @@ +package pool + +import ( + "errors" + "sync" +) + +type StickyConnPool struct { + pool *ConnPool + reusable bool + + cn *Conn + closed bool + mx sync.Mutex +} + +func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool { + return &StickyConnPool{ + pool: pool, + reusable: reusable, + } +} + +func (p *StickyConnPool) First() *Conn { + p.mx.Lock() + cn := p.cn + p.mx.Unlock() + return cn +} + +func (p *StickyConnPool) Get() (cn *Conn, isNew bool, err error) { + defer p.mx.Unlock() + p.mx.Lock() + + if p.closed { + err = errClosed + return + } + if p.cn != nil { + cn = p.cn + return + } + + cn, isNew, err = p.pool.Get() + if err != nil { + return + } + p.cn = cn + return +} + +func (p *StickyConnPool) put() (err error) { + err = p.pool.Put(p.cn) + p.cn = nil + return err +} + +func (p *StickyConnPool) Put(cn *Conn) error { + defer p.mx.Unlock() + p.mx.Lock() + if p.closed { + return errClosed + } + if p.cn != cn { + panic("p.cn != cn") + } + return nil +} + +func (p *StickyConnPool) remove(reason error) error { + err := p.pool.Remove(p.cn, reason) + p.cn = nil + return err +} + +func (p *StickyConnPool) Remove(cn *Conn, reason error) error { + defer p.mx.Unlock() + p.mx.Lock() + if p.closed { + return errClosed + } + if p.cn == nil { + panic("p.cn == nil") + } + if cn != nil && p.cn != cn { + panic("p.cn != cn") + } + return p.remove(reason) +} + +func (p *StickyConnPool) Len() int { + defer p.mx.Unlock() + p.mx.Lock() + if p.cn == nil { + return 0 + } + return 1 +} + +func (p *StickyConnPool) FreeLen() int { + defer p.mx.Unlock() + p.mx.Lock() + if p.cn == nil { + return 1 + } + return 0 +} + +func (p *StickyConnPool) Stats() *PoolStats { return nil } + +func (p *StickyConnPool) Close() error { + defer p.mx.Unlock() + p.mx.Lock() + if p.closed { + return errClosed + } + p.closed = true + var err error + if p.cn != nil { + if p.reusable { + err = p.put() + } else { + reason := errors.New("redis: sticky not reusable connection") + err = p.remove(reason) + } + } + return err +} diff --git a/multi.go b/multi.go index 1a13d047..6b43591f 100644 --- a/multi.go +++ b/multi.go @@ -3,6 +3,8 @@ package redis import ( "errors" "fmt" + + "gopkg.in/redis.v3/internal/pool" ) var errDiscard = errors.New("redis: Discard can be used only inside Exec") @@ -38,7 +40,7 @@ func (c *Client) Multi() *Multi { multi := &Multi{ base: &baseClient{ opt: c.opt, - connPool: newStickyConnPool(c.connPool, true), + connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), }, } multi.commandable.process = multi.process @@ -137,8 +139,8 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) { return retCmds, err } -func (c *Multi) execCmds(cn *conn, cmds []Cmder) error { - err := cn.writeCmds(cmds...) +func (c *Multi) execCmds(cn *pool.Conn, cmds []Cmder) error { + err := writeCmd(cn, cmds...) if err != nil { setCmdsErr(cmds[1:len(cmds)-1], err) return err diff --git a/multi_test.go b/multi_test.go index 459d0a62..fa532d1a 100644 --- a/multi_test.go +++ b/multi_test.go @@ -145,7 +145,7 @@ var _ = Describe("Multi", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.SetNetConn(&badConn{}) + cn.NetConn = &badConn{} err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) @@ -172,7 +172,7 @@ var _ = Describe("Multi", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.SetNetConn(&badConn{}) + cn.NetConn = &badConn{} err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) diff --git a/options.go b/options.go new file mode 100644 index 00000000..95a26663 --- /dev/null +++ b/options.go @@ -0,0 +1,144 @@ +package redis + +import ( + "net" + "time" + + "gopkg.in/redis.v3/internal/pool" +) + +type Options struct { + // The network type, either tcp or unix. + // Default is tcp. + Network string + // host:port address. + Addr string + + // Dialer creates new network connection and has priority over + // Network and Addr options. + Dialer func() (net.Conn, error) + + // An optional password. Must match the password specified in the + // requirepass server configuration option. + Password string + // A database to be selected after connecting to server. + DB int64 + + // The maximum number of retries before giving up. + // Default is to not retry failed commands. + MaxRetries int + + // Sets the deadline for establishing new connections. If reached, + // dial will fail with a timeout. + DialTimeout time.Duration + // Sets the deadline for socket reads. If reached, commands will + // fail with a timeout instead of blocking. + ReadTimeout time.Duration + // Sets the deadline for socket writes. If reached, commands will + // fail with a timeout instead of blocking. + WriteTimeout time.Duration + + // The maximum number of socket connections. + // Default is 10 connections. + PoolSize int + // Specifies amount of time client waits for connection if all + // connections are busy before returning an error. + // Default is 1 seconds. + PoolTimeout time.Duration + // Specifies amount of time after which client closes idle + // connections. Should be less than server's timeout. + // Default is to not close idle connections. + IdleTimeout time.Duration +} + +func (opt *Options) getNetwork() string { + if opt.Network == "" { + return "tcp" + } + return opt.Network +} + +func (opt *Options) getDialer() func() (net.Conn, error) { + if opt.Dialer == nil { + opt.Dialer = func() (net.Conn, error) { + return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout()) + } + } + return opt.Dialer +} + +func (opt *Options) getPoolDialer() func() (*pool.Conn, error) { + dial := opt.getDialer() + return func() (*pool.Conn, error) { + netcn, err := dial() + if err != nil { + return nil, err + } + cn := pool.NewConn(netcn) + return cn, opt.initConn(cn) + } +} + +func (opt *Options) getPoolSize() int { + if opt.PoolSize == 0 { + return 10 + } + return opt.PoolSize +} + +func (opt *Options) getDialTimeout() time.Duration { + if opt.DialTimeout == 0 { + return 5 * time.Second + } + return opt.DialTimeout +} + +func (opt *Options) getPoolTimeout() time.Duration { + if opt.PoolTimeout == 0 { + return 1 * time.Second + } + return opt.PoolTimeout +} + +func (opt *Options) getIdleTimeout() time.Duration { + return opt.IdleTimeout +} + +func (opt *Options) initConn(cn *pool.Conn) error { + if opt.Password == "" && opt.DB == 0 { + return nil + } + + // Temp client for Auth and Select. + client := newClient(opt, pool.NewSingleConnPool(cn)) + + if opt.Password != "" { + if err := client.Auth(opt.Password).Err(); err != nil { + return err + } + } + + if opt.DB > 0 { + if err := client.Select(opt.DB).Err(); err != nil { + return err + } + } + + return nil +} + +func newConnPool(opt *Options) *pool.ConnPool { + return pool.NewConnPool( + opt.getPoolDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout()) +} + +// PoolStats contains pool state information and accumulated stats. +type PoolStats struct { + Requests uint32 // number of times a connection was requested by the pool + Hits uint32 // number of times free connection was found in the pool + Waits uint32 // number of times the pool had to wait for a connection + Timeouts uint32 // number of times a wait timeout occurred + + TotalConns uint32 // the number of total connections in the pool + FreeConns uint32 // the number of free connections in the pool +} diff --git a/parser.go b/parser.go index 758ec8df..2496f66c 100644 --- a/parser.go +++ b/parser.go @@ -6,6 +6,8 @@ import ( "io" "net" "strconv" + + "gopkg.in/redis.v3/internal/pool" ) const ( @@ -16,7 +18,7 @@ const ( arrayReply = '*' ) -type multiBulkParser func(cn *conn, n int64) (interface{}, error) +type multiBulkParser func(cn *pool.Conn, n int64) (interface{}, error) var ( errReaderTooSmall = errors.New("redis: reader is too small") @@ -223,8 +225,8 @@ func scan(b []byte, val interface{}) error { //------------------------------------------------------------------------------ -func readLine(cn *conn) ([]byte, error) { - line, isPrefix, err := cn.rd.ReadLine() +func readLine(cn *pool.Conn) ([]byte, error) { + line, isPrefix, err := cn.Rd.ReadLine() if err != nil { return line, err } @@ -243,28 +245,27 @@ func isNilReply(b []byte) bool { b[1] == '-' && b[2] == '1' } -func readN(cn *conn, n int) ([]byte, error) { - var b []byte - if cap(cn.buf) < n { - b = make([]byte, n) +func readN(cn *pool.Conn, n int) ([]byte, error) { + if d := n - cap(cn.Buf); d > 0 { + cn.Buf = append(cn.Buf, make([]byte, d)...) } else { - b = cn.buf[:n] + cn.Buf = cn.Buf[:n] } - _, err := io.ReadFull(cn.rd, b) - return b, err + _, err := io.ReadFull(cn.Rd, cn.Buf) + return cn.Buf, err } //------------------------------------------------------------------------------ -func parseErrorReply(cn *conn, line []byte) error { +func parseErrorReply(cn *pool.Conn, line []byte) error { return errorf(string(line[1:])) } -func parseStatusReply(cn *conn, line []byte) ([]byte, error) { +func parseStatusReply(cn *pool.Conn, line []byte) ([]byte, error) { return line[1:], nil } -func parseIntReply(cn *conn, line []byte) (int64, error) { +func parseIntReply(cn *pool.Conn, line []byte) (int64, error) { n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64) if err != nil { return 0, err @@ -272,7 +273,7 @@ func parseIntReply(cn *conn, line []byte) (int64, error) { return n, nil } -func readIntReply(cn *conn) (int64, error) { +func readIntReply(cn *pool.Conn) (int64, error) { line, err := readLine(cn) if err != nil { return 0, err @@ -287,7 +288,7 @@ func readIntReply(cn *conn) (int64, error) { } } -func parseBytesReply(cn *conn, line []byte) ([]byte, error) { +func parseBytesReply(cn *pool.Conn, line []byte) ([]byte, error) { if isNilReply(line) { return nil, Nil } @@ -305,7 +306,7 @@ func parseBytesReply(cn *conn, line []byte) ([]byte, error) { return b[:replyLen], nil } -func readBytesReply(cn *conn) ([]byte, error) { +func readBytesReply(cn *pool.Conn) ([]byte, error) { line, err := readLine(cn) if err != nil { return nil, err @@ -322,7 +323,7 @@ func readBytesReply(cn *conn) ([]byte, error) { } } -func readStringReply(cn *conn) (string, error) { +func readStringReply(cn *pool.Conn) (string, error) { b, err := readBytesReply(cn) if err != nil { return "", err @@ -330,7 +331,7 @@ func readStringReply(cn *conn) (string, error) { return string(b), nil } -func readFloatReply(cn *conn) (float64, error) { +func readFloatReply(cn *pool.Conn) (float64, error) { b, err := readBytesReply(cn) if err != nil { return 0, err @@ -338,7 +339,7 @@ func readFloatReply(cn *conn) (float64, error) { return strconv.ParseFloat(bytesToString(b), 64) } -func parseArrayHeader(cn *conn, line []byte) (int64, error) { +func parseArrayHeader(cn *pool.Conn, line []byte) (int64, error) { if isNilReply(line) { return 0, Nil } @@ -350,7 +351,7 @@ func parseArrayHeader(cn *conn, line []byte) (int64, error) { return n, nil } -func parseArrayReply(cn *conn, p multiBulkParser, line []byte) (interface{}, error) { +func parseArrayReply(cn *pool.Conn, p multiBulkParser, line []byte) (interface{}, error) { n, err := parseArrayHeader(cn, line) if err != nil { return nil, err @@ -358,7 +359,7 @@ func parseArrayReply(cn *conn, p multiBulkParser, line []byte) (interface{}, err return p(cn, n) } -func readArrayHeader(cn *conn) (int64, error) { +func readArrayHeader(cn *pool.Conn) (int64, error) { line, err := readLine(cn) if err != nil { return 0, err @@ -373,7 +374,7 @@ func readArrayHeader(cn *conn) (int64, error) { } } -func readArrayReply(cn *conn, p multiBulkParser) (interface{}, error) { +func readArrayReply(cn *pool.Conn, p multiBulkParser) (interface{}, error) { line, err := readLine(cn) if err != nil { return nil, err @@ -388,7 +389,7 @@ func readArrayReply(cn *conn, p multiBulkParser) (interface{}, error) { } } -func readReply(cn *conn, p multiBulkParser) (interface{}, error) { +func readReply(cn *pool.Conn, p multiBulkParser) (interface{}, error) { line, err := readLine(cn) if err != nil { return nil, err @@ -409,7 +410,7 @@ func readReply(cn *conn, p multiBulkParser) (interface{}, error) { return nil, fmt.Errorf("redis: can't parse %.100q", line) } -func readScanReply(cn *conn) ([]string, int64, error) { +func readScanReply(cn *pool.Conn) ([]string, int64, error) { n, err := readArrayHeader(cn) if err != nil { return nil, 0, err @@ -445,7 +446,7 @@ func readScanReply(cn *conn) ([]string, int64, error) { return keys, cursor, err } -func sliceParser(cn *conn, n int64) (interface{}, error) { +func sliceParser(cn *pool.Conn, n int64) (interface{}, error) { vals := make([]interface{}, 0, n) for i := int64(0); i < n; i++ { v, err := readReply(cn, sliceParser) @@ -465,7 +466,7 @@ func sliceParser(cn *conn, n int64) (interface{}, error) { return vals, nil } -func intSliceParser(cn *conn, n int64) (interface{}, error) { +func intSliceParser(cn *pool.Conn, n int64) (interface{}, error) { ints := make([]int64, 0, n) for i := int64(0); i < n; i++ { n, err := readIntReply(cn) @@ -477,7 +478,7 @@ func intSliceParser(cn *conn, n int64) (interface{}, error) { return ints, nil } -func boolSliceParser(cn *conn, n int64) (interface{}, error) { +func boolSliceParser(cn *pool.Conn, n int64) (interface{}, error) { bools := make([]bool, 0, n) for i := int64(0); i < n; i++ { n, err := readIntReply(cn) @@ -489,7 +490,7 @@ func boolSliceParser(cn *conn, n int64) (interface{}, error) { return bools, nil } -func stringSliceParser(cn *conn, n int64) (interface{}, error) { +func stringSliceParser(cn *pool.Conn, n int64) (interface{}, error) { ss := make([]string, 0, n) for i := int64(0); i < n; i++ { s, err := readStringReply(cn) @@ -504,7 +505,7 @@ func stringSliceParser(cn *conn, n int64) (interface{}, error) { return ss, nil } -func floatSliceParser(cn *conn, n int64) (interface{}, error) { +func floatSliceParser(cn *pool.Conn, n int64) (interface{}, error) { nn := make([]float64, 0, n) for i := int64(0); i < n; i++ { n, err := readFloatReply(cn) @@ -516,7 +517,7 @@ func floatSliceParser(cn *conn, n int64) (interface{}, error) { return nn, nil } -func stringStringMapParser(cn *conn, n int64) (interface{}, error) { +func stringStringMapParser(cn *pool.Conn, n int64) (interface{}, error) { m := make(map[string]string, n/2) for i := int64(0); i < n; i += 2 { key, err := readStringReply(cn) @@ -534,7 +535,7 @@ func stringStringMapParser(cn *conn, n int64) (interface{}, error) { return m, nil } -func stringIntMapParser(cn *conn, n int64) (interface{}, error) { +func stringIntMapParser(cn *pool.Conn, n int64) (interface{}, error) { m := make(map[string]int64, n/2) for i := int64(0); i < n; i += 2 { key, err := readStringReply(cn) @@ -552,7 +553,7 @@ func stringIntMapParser(cn *conn, n int64) (interface{}, error) { return m, nil } -func zSliceParser(cn *conn, n int64) (interface{}, error) { +func zSliceParser(cn *pool.Conn, n int64) (interface{}, error) { zz := make([]Z, n/2) for i := int64(0); i < n; i += 2 { var err error @@ -572,7 +573,7 @@ func zSliceParser(cn *conn, n int64) (interface{}, error) { return zz, nil } -func clusterSlotInfoSliceParser(cn *conn, n int64) (interface{}, error) { +func clusterSlotInfoSliceParser(cn *pool.Conn, n int64) (interface{}, error) { infos := make([]ClusterSlotInfo, 0, n) for i := int64(0); i < n; i++ { n, err := readArrayHeader(cn) @@ -638,7 +639,7 @@ func clusterSlotInfoSliceParser(cn *conn, n int64) (interface{}, error) { } func newGeoLocationParser(q *GeoRadiusQuery) multiBulkParser { - return func(cn *conn, n int64) (interface{}, error) { + return func(cn *pool.Conn, n int64) (interface{}, error) { var loc GeoLocation var err error @@ -682,7 +683,7 @@ func newGeoLocationParser(q *GeoRadiusQuery) multiBulkParser { } func newGeoLocationSliceParser(q *GeoRadiusQuery) multiBulkParser { - return func(cn *conn, n int64) (interface{}, error) { + return func(cn *pool.Conn, n int64) (interface{}, error) { locs := make([]GeoLocation, 0, n) for i := int64(0); i < n; i++ { v, err := readReply(cn, newGeoLocationParser(q)) diff --git a/parser_test.go b/parser_test.go index b1c74344..77287a7a 100644 --- a/parser_test.go +++ b/parser_test.go @@ -4,6 +4,8 @@ import ( "bufio" "bytes" "testing" + + "gopkg.in/redis.v3/internal/pool" ) func BenchmarkParseReplyStatus(b *testing.B) { @@ -31,9 +33,9 @@ func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr for i := 0; i < b.N; i++ { buf.WriteString(reply) } - cn := &conn{ - rd: bufio.NewReader(buf), - buf: make([]byte, 0, defaultBufSize), + cn := &pool.Conn{ + Rd: bufio.NewReader(buf), + Buf: make([]byte, 4096), } b.ResetTimer() diff --git a/pipeline.go b/pipeline.go index 8caae6bf..098207c2 100644 --- a/pipeline.go +++ b/pipeline.go @@ -3,6 +3,8 @@ package redis import ( "sync" "sync/atomic" + + "gopkg.in/redis.v3/internal/pool" ) // Pipeline implements pipelining as described in @@ -110,8 +112,8 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { return cmds, retErr } -func execCmds(cn *conn, cmds []Cmder) ([]Cmder, error) { - if err := cn.writeCmds(cmds...); err != nil { +func execCmds(cn *pool.Conn, cmds []Cmder) ([]Cmder, error) { + if err := writeCmd(cn, cmds...); err != nil { setCmdsErr(cmds, err) return cmds, err } diff --git a/pool.go b/pool.go deleted file mode 100644 index 3725c408..00000000 --- a/pool.go +++ /dev/null @@ -1,542 +0,0 @@ -package redis - -import ( - "errors" - "fmt" - "sync" - "sync/atomic" - "time" - - "gopkg.in/bsm/ratelimit.v1" -) - -var ( - errClosed = errors.New("redis: client is closed") - errPoolTimeout = errors.New("redis: connection pool timeout") -) - -// PoolStats contains pool state information and accumulated stats. -type PoolStats struct { - Requests uint32 // number of times a connection was requested by the pool - Hits uint32 // number of times free connection was found in the pool - Waits uint32 // number of times the pool had to wait for a connection - Timeouts uint32 // number of times a wait timeout occurred - - TotalConns uint32 // the number of total connections in the pool - FreeConns uint32 // the number of free connections in the pool -} - -type pool interface { - First() *conn - Get() (*conn, bool, error) - Put(*conn) error - Remove(*conn, error) error - Len() int - FreeLen() int - Close() error - Stats() *PoolStats -} - -type connList struct { - cns []*conn - mx sync.Mutex - len int32 // atomic - size int32 -} - -func newConnList(size int) *connList { - return &connList{ - cns: make([]*conn, 0, size), - size: int32(size), - } -} - -func (l *connList) Len() int { - return int(atomic.LoadInt32(&l.len)) -} - -// Reserve reserves place in the list and returns true on success. The -// caller must add or remove connection if place was reserved. -func (l *connList) Reserve() bool { - len := atomic.AddInt32(&l.len, 1) - reserved := len <= l.size - if !reserved { - atomic.AddInt32(&l.len, -1) - } - return reserved -} - -// Add adds connection to the list. The caller must reserve place first. -func (l *connList) Add(cn *conn) { - l.mx.Lock() - l.cns = append(l.cns, cn) - l.mx.Unlock() -} - -// Remove closes connection and removes it from the list. -func (l *connList) Remove(cn *conn) error { - defer l.mx.Unlock() - l.mx.Lock() - - if cn == nil { - atomic.AddInt32(&l.len, -1) - return nil - } - - for i, c := range l.cns { - if c == cn { - l.cns = append(l.cns[:i], l.cns[i+1:]...) - atomic.AddInt32(&l.len, -1) - return cn.Close() - } - } - - if l.closed() { - return nil - } - panic("conn not found in the list") -} - -func (l *connList) Replace(cn, newcn *conn) error { - defer l.mx.Unlock() - l.mx.Lock() - - for i, c := range l.cns { - if c == cn { - l.cns[i] = newcn - return cn.Close() - } - } - - if l.closed() { - return newcn.Close() - } - panic("conn not found in the list") -} - -func (l *connList) Close() (retErr error) { - l.mx.Lock() - for _, c := range l.cns { - if err := c.Close(); err != nil { - retErr = err - } - } - l.cns = nil - atomic.StoreInt32(&l.len, 0) - l.mx.Unlock() - return retErr -} - -func (l *connList) closed() bool { - return l.cns == nil -} - -type connPool struct { - dialer func() (*conn, error) - - rl *ratelimit.RateLimiter - opt *Options - conns *connList - freeConns chan *conn - stats PoolStats - - _closed int32 - - lastErr atomic.Value -} - -func newConnPool(opt *Options) *connPool { - p := &connPool{ - dialer: newConnDialer(opt), - - rl: ratelimit.New(3*opt.getPoolSize(), time.Second), - opt: opt, - conns: newConnList(opt.getPoolSize()), - freeConns: make(chan *conn, opt.getPoolSize()), - } - if p.opt.getIdleTimeout() > 0 { - go p.reaper() - } - return p -} - -func (p *connPool) closed() bool { - return atomic.LoadInt32(&p._closed) == 1 -} - -func (p *connPool) isIdle(cn *conn) bool { - return p.opt.getIdleTimeout() > 0 && time.Since(cn.UsedAt) > p.opt.getIdleTimeout() -} - -// First returns first non-idle connection from the pool or nil if -// there are no connections. -func (p *connPool) First() *conn { - for { - select { - case cn := <-p.freeConns: - if p.isIdle(cn) { - var err error - cn, err = p.replace(cn) - if err != nil { - Logger.Printf("pool.replace failed: %s", err) - continue - } - } - return cn - default: - return nil - } - } - panic("not reached") -} - -// wait waits for free non-idle connection. It returns nil on timeout. -func (p *connPool) wait() *conn { - deadline := time.After(p.opt.getPoolTimeout()) - for { - select { - case cn := <-p.freeConns: - if p.isIdle(cn) { - var err error - cn, err = p.replace(cn) - if err != nil { - Logger.Printf("pool.replace failed: %s", err) - continue - } - } - return cn - case <-deadline: - return nil - } - } - panic("not reached") -} - -// Establish a new connection -func (p *connPool) new() (*conn, error) { - if p.rl.Limit() { - err := fmt.Errorf( - "redis: you open connections too fast (last_error=%q)", - p.loadLastErr(), - ) - return nil, err - } - - cn, err := p.dialer() - if err != nil { - p.storeLastErr(err.Error()) - return nil, err - } - - return cn, nil -} - -// Get returns existed connection from the pool or creates a new one. -func (p *connPool) Get() (cn *conn, isNew bool, err error) { - if p.closed() { - err = errClosed - return - } - - atomic.AddUint32(&p.stats.Requests, 1) - - // Fetch first non-idle connection, if available. - if cn = p.First(); cn != nil { - atomic.AddUint32(&p.stats.Hits, 1) - return - } - - // Try to create a new one. - if p.conns.Reserve() { - isNew = true - - cn, err = p.new() - if err != nil { - p.conns.Remove(nil) - return - } - p.conns.Add(cn) - return - } - - // Otherwise, wait for the available connection. - atomic.AddUint32(&p.stats.Waits, 1) - if cn = p.wait(); cn != nil { - return - } - - atomic.AddUint32(&p.stats.Timeouts, 1) - err = errPoolTimeout - return -} - -func (p *connPool) Put(cn *conn) error { - if cn.rd.Buffered() != 0 { - b, _ := cn.rd.Peek(cn.rd.Buffered()) - err := fmt.Errorf("connection has unread data: %q", b) - Logger.Print(err) - return p.Remove(cn, err) - } - p.freeConns <- cn - return nil -} - -func (p *connPool) replace(cn *conn) (*conn, error) { - newcn, err := p.new() - if err != nil { - _ = p.conns.Remove(cn) - return nil, err - } - _ = p.conns.Replace(cn, newcn) - return newcn, nil -} - -func (p *connPool) Remove(cn *conn, reason error) error { - p.storeLastErr(reason.Error()) - - // Replace existing connection with new one and unblock waiter. - newcn, err := p.replace(cn) - if err != nil { - return err - } - p.freeConns <- newcn - return nil -} - -// Len returns total number of connections. -func (p *connPool) Len() int { - return p.conns.Len() -} - -// FreeLen returns number of free connections. -func (p *connPool) FreeLen() int { - return len(p.freeConns) -} - -func (p *connPool) Stats() *PoolStats { - stats := p.stats - stats.Requests = atomic.LoadUint32(&p.stats.Requests) - stats.Waits = atomic.LoadUint32(&p.stats.Waits) - stats.Timeouts = atomic.LoadUint32(&p.stats.Timeouts) - stats.TotalConns = uint32(p.Len()) - stats.FreeConns = uint32(p.FreeLen()) - return &stats -} - -func (p *connPool) Close() (retErr error) { - if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { - return errClosed - } - // Wait for app to free connections, but don't close them immediately. - for i := 0; i < p.Len(); i++ { - if cn := p.wait(); cn == nil { - break - } - } - // Close all connections. - if err := p.conns.Close(); err != nil { - retErr = err - } - return retErr -} - -func (p *connPool) reaper() { - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - - for _ = range ticker.C { - if p.closed() { - break - } - - // pool.First removes idle connections from the pool and - // returns first non-idle connection. So just put returned - // connection back. - if cn := p.First(); cn != nil { - p.Put(cn) - } - } -} - -func (p *connPool) storeLastErr(err string) { - p.lastErr.Store(err) -} - -func (p *connPool) loadLastErr() string { - if v := p.lastErr.Load(); v != nil { - return v.(string) - } - return "" -} - -//------------------------------------------------------------------------------ - -type singleConnPool struct { - cn *conn -} - -func newSingleConnPool(cn *conn) *singleConnPool { - return &singleConnPool{ - cn: cn, - } -} - -func (p *singleConnPool) First() *conn { - return p.cn -} - -func (p *singleConnPool) Get() (*conn, bool, error) { - return p.cn, false, nil -} - -func (p *singleConnPool) Put(cn *conn) error { - if p.cn != cn { - panic("p.cn != cn") - } - return nil -} - -func (p *singleConnPool) Remove(cn *conn, _ error) error { - if p.cn != cn { - panic("p.cn != cn") - } - return nil -} - -func (p *singleConnPool) Len() int { - return 1 -} - -func (p *singleConnPool) FreeLen() int { - return 0 -} - -func (p *singleConnPool) Stats() *PoolStats { return nil } - -func (p *singleConnPool) Close() error { - return nil -} - -//------------------------------------------------------------------------------ - -type stickyConnPool struct { - pool pool - reusable bool - - cn *conn - closed bool - mx sync.Mutex -} - -func newStickyConnPool(pool pool, reusable bool) *stickyConnPool { - return &stickyConnPool{ - pool: pool, - reusable: reusable, - } -} - -func (p *stickyConnPool) First() *conn { - p.mx.Lock() - cn := p.cn - p.mx.Unlock() - return cn -} - -func (p *stickyConnPool) Get() (cn *conn, isNew bool, err error) { - defer p.mx.Unlock() - p.mx.Lock() - - if p.closed { - err = errClosed - return - } - if p.cn != nil { - cn = p.cn - return - } - - cn, isNew, err = p.pool.Get() - if err != nil { - return - } - p.cn = cn - return -} - -func (p *stickyConnPool) put() (err error) { - err = p.pool.Put(p.cn) - p.cn = nil - return err -} - -func (p *stickyConnPool) Put(cn *conn) error { - defer p.mx.Unlock() - p.mx.Lock() - if p.closed { - return errClosed - } - if p.cn != cn { - panic("p.cn != cn") - } - return nil -} - -func (p *stickyConnPool) remove(reason error) error { - err := p.pool.Remove(p.cn, reason) - p.cn = nil - return err -} - -func (p *stickyConnPool) Remove(cn *conn, reason error) error { - defer p.mx.Unlock() - p.mx.Lock() - if p.closed { - return errClosed - } - if p.cn == nil { - panic("p.cn == nil") - } - if cn != nil && p.cn != cn { - panic("p.cn != cn") - } - return p.remove(reason) -} - -func (p *stickyConnPool) Len() int { - defer p.mx.Unlock() - p.mx.Lock() - if p.cn == nil { - return 0 - } - return 1 -} - -func (p *stickyConnPool) FreeLen() int { - defer p.mx.Unlock() - p.mx.Lock() - if p.cn == nil { - return 1 - } - return 0 -} - -func (p *stickyConnPool) Stats() *PoolStats { return nil } - -func (p *stickyConnPool) Close() error { - defer p.mx.Unlock() - p.mx.Lock() - if p.closed { - return errClosed - } - p.closed = true - var err error - if p.cn != nil { - if p.reusable { - err = p.put() - } else { - reason := errors.New("redis: sticky not reusable connection") - err = p.remove(reason) - } - } - return err -} diff --git a/pubsub.go b/pubsub.go index c1fb4628..f1c93c8f 100644 --- a/pubsub.go +++ b/pubsub.go @@ -4,6 +4,8 @@ import ( "fmt" "net" "time" + + "gopkg.in/redis.v3/internal/pool" ) // Posts a message to the given channel. @@ -30,7 +32,7 @@ func (c *Client) PubSub() *PubSub { return &PubSub{ base: &baseClient{ opt: c.opt, - connPool: newStickyConnPool(c.connPool, false), + connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false), }, } } @@ -47,19 +49,20 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) { return pubsub, pubsub.PSubscribe(channels...) } -func (c *PubSub) subscribe(cmd string, channels ...string) error { +func (c *PubSub) subscribe(redisCmd string, channels ...string) error { cn, _, err := c.base.conn() if err != nil { return err } args := make([]interface{}, 1+len(channels)) - args[0] = cmd + args[0] = redisCmd for i, channel := range channels { args[1+i] = channel } - req := NewSliceCmd(args...) - return cn.writeCmds(req) + cmd := NewSliceCmd(args...) + + return writeCmd(cn, cmd) } // Subscribes the client to the specified channels. @@ -132,7 +135,7 @@ func (c *PubSub) Ping(payload string) error { args = append(args, payload) } cmd := NewCmd(args...) - return cn.writeCmds(cmd) + return writeCmd(cn, cmd) } // Message received after a successful subscription to channel. @@ -296,7 +299,7 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { } } -func (c *PubSub) putConn(cn *conn, err error) { +func (c *PubSub) putConn(cn *pool.Conn, err error) { if !c.base.putConn(cn, err, true) { c.nsub = 0 } diff --git a/pubsub_test.go b/pubsub_test.go index 669c0737..a8bb610b 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -291,10 +291,10 @@ var _ = Describe("PubSub", func() { expectReceiveMessageOnError := func(pubsub *redis.PubSub) { cn1, _, err := pubsub.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn1.SetNetConn(&badConn{ + cn1.NetConn = &badConn{ readErr: io.EOF, writeErr: io.EOF, - }) + } done := make(chan bool, 1) go func() { diff --git a/redis.go b/redis.go index 5558ad10..da4b41b3 100644 --- a/redis.go +++ b/redis.go @@ -3,15 +3,26 @@ package redis // import "gopkg.in/redis.v3" import ( "fmt" "log" - "net" "os" - "time" + "sync/atomic" + + "gopkg.in/redis.v3/internal/pool" ) -var Logger = log.New(os.Stderr, "redis: ", log.LstdFlags) +// Deprecated. Use SetLogger instead. +var Logger *log.Logger + +func init() { + SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags)) +} + +func SetLogger(logger *log.Logger) { + Logger = logger + pool.Logger = logger +} type baseClient struct { - connPool pool + connPool pool.Pooler opt *Options onClose func() error // hook called when client is closed @@ -21,11 +32,11 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB) } -func (c *baseClient) conn() (*conn, bool, error) { +func (c *baseClient) conn() (*pool.Conn, bool, error) { return c.connPool.Get() } -func (c *baseClient) putConn(cn *conn, err error, allowTimeout bool) bool { +func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { if isBadConn(err, allowTimeout) { err = c.connPool.Remove(cn, err) if err != nil { @@ -61,7 +72,7 @@ func (c *baseClient) process(cmd Cmder) { } cn.WriteTimeout = c.opt.WriteTimeout - if err := cn.writeCmds(cmd); err != nil { + if err := writeCmd(cn, cmd); err != nil { c.putConn(cn, err, false) cmd.setErr(err) if shouldRetry(err) { @@ -99,93 +110,6 @@ func (c *baseClient) Close() error { //------------------------------------------------------------------------------ -type Options struct { - // The network type, either tcp or unix. - // Default is tcp. - Network string - // host:port address. - Addr string - - // Dialer creates new network connection and has priority over - // Network and Addr options. - Dialer func() (net.Conn, error) - - // An optional password. Must match the password specified in the - // requirepass server configuration option. - Password string - // A database to be selected after connecting to server. - DB int64 - - // The maximum number of retries before giving up. - // Default is to not retry failed commands. - MaxRetries int - - // Sets the deadline for establishing new connections. If reached, - // dial will fail with a timeout. - DialTimeout time.Duration - // Sets the deadline for socket reads. If reached, commands will - // fail with a timeout instead of blocking. - ReadTimeout time.Duration - // Sets the deadline for socket writes. If reached, commands will - // fail with a timeout instead of blocking. - WriteTimeout time.Duration - - // The maximum number of socket connections. - // Default is 10 connections. - PoolSize int - // Specifies amount of time client waits for connection if all - // connections are busy before returning an error. - // Default is 1 seconds. - PoolTimeout time.Duration - // Specifies amount of time after which client closes idle - // connections. Should be less than server's timeout. - // Default is to not close idle connections. - IdleTimeout time.Duration -} - -func (opt *Options) getNetwork() string { - if opt.Network == "" { - return "tcp" - } - return opt.Network -} - -func (opt *Options) getDialer() func() (net.Conn, error) { - if opt.Dialer == nil { - opt.Dialer = func() (net.Conn, error) { - return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout()) - } - } - return opt.Dialer -} - -func (opt *Options) getPoolSize() int { - if opt.PoolSize == 0 { - return 10 - } - return opt.PoolSize -} - -func (opt *Options) getDialTimeout() time.Duration { - if opt.DialTimeout == 0 { - return 5 * time.Second - } - return opt.DialTimeout -} - -func (opt *Options) getPoolTimeout() time.Duration { - if opt.PoolTimeout == 0 { - return 1 * time.Second - } - return opt.PoolTimeout -} - -func (opt *Options) getIdleTimeout() time.Duration { - return opt.IdleTimeout -} - -//------------------------------------------------------------------------------ - // Client is a Redis client representing a pool of zero or more // underlying connections. It's safe for concurrent use by multiple // goroutines. @@ -194,7 +118,7 @@ type Client struct { commandable } -func newClient(opt *Options, pool pool) *Client { +func newClient(opt *Options, pool pool.Pooler) *Client { base := baseClient{opt: opt, connPool: pool} return &Client{ baseClient: base, @@ -206,11 +130,19 @@ func newClient(opt *Options, pool pool) *Client { // NewClient returns a client to the Redis Server specified by Options. func NewClient(opt *Options) *Client { - pool := newConnPool(opt) - return newClient(opt, pool) + return newClient(opt, newConnPool(opt)) } -// PoolStats returns connection pool stats +// PoolStats returns connection pool stats. func (c *Client) PoolStats() *PoolStats { - return c.connPool.Stats() + s := c.connPool.Stats() + return &PoolStats{ + Requests: atomic.LoadUint32(&s.Requests), + Hits: atomic.LoadUint32(&s.Hits), + Waits: atomic.LoadUint32(&s.Waits), + Timeouts: atomic.LoadUint32(&s.Timeouts), + + TotalConns: atomic.LoadUint32(&s.TotalConns), + FreeConns: atomic.LoadUint32(&s.FreeConns), + } } diff --git a/redis_test.go b/redis_test.go index 7b5197bc..23c39009 100644 --- a/redis_test.go +++ b/redis_test.go @@ -160,7 +160,7 @@ var _ = Describe("Client", func() { cn, _, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) - cn.SetNetConn(&badConn{}) + cn.NetConn = &badConn{} err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) @@ -174,10 +174,6 @@ var _ = Describe("Client", func() { Expect(cn.UsedAt).NotTo(BeZero()) createdAt := cn.UsedAt - future := time.Now().Add(time.Hour) - redis.SetTime(future) - defer redis.RestoreTime() - err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) Expect(cn.UsedAt.Equal(createdAt)).To(BeTrue()) @@ -187,6 +183,6 @@ var _ = Describe("Client", func() { cn = client.Pool().First() Expect(cn).NotTo(BeNil()) - Expect(cn.UsedAt.Equal(future)).To(BeTrue()) + Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) }) }) diff --git a/ring.go b/ring.go index f1ae8adf..3b88d7ca 100644 --- a/ring.go +++ b/ring.go @@ -8,6 +8,7 @@ import ( "gopkg.in/redis.v3/internal/consistenthash" "gopkg.in/redis.v3/internal/hashtag" + "gopkg.in/redis.v3/internal/pool" ) var ( @@ -200,7 +201,7 @@ func (ring *Ring) heartbeat() { for _, shard := range ring.shards { err := shard.Client.Ping().Err() - if shard.Vote(err == nil || err == errPoolTimeout) { + if shard.Vote(err == nil || err == pool.ErrPoolTimeout) { Logger.Printf("ring shard state changed: %s", shard) rebalance = true } diff --git a/sentinel.go b/sentinel.go index db5db64d..5575e73e 100644 --- a/sentinel.go +++ b/sentinel.go @@ -7,6 +7,8 @@ import ( "strings" "sync" "time" + + "gopkg.in/redis.v3/internal/pool" ) //------------------------------------------------------------------------------ @@ -103,7 +105,7 @@ func (c *sentinelClient) PubSub() *PubSub { return &PubSub{ base: &baseClient{ opt: c.opt, - connPool: newStickyConnPool(c.connPool, false), + connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false), }, } } @@ -126,7 +128,7 @@ type sentinelFailover struct { opt *Options - pool pool + pool *pool.ConnPool poolOnce sync.Once mu sync.RWMutex @@ -145,7 +147,7 @@ func (d *sentinelFailover) dial() (net.Conn, error) { return net.DialTimeout("tcp", addr, d.opt.DialTimeout) } -func (d *sentinelFailover) Pool() pool { +func (d *sentinelFailover) Pool() *pool.ConnPool { d.poolOnce.Do(func() { d.opt.Dialer = d.dial d.pool = newConnPool(d.opt) @@ -252,7 +254,7 @@ func (d *sentinelFailover) closeOldConns(newMaster string) { // Good connections that should be put back to the pool. They // can't be put immediately, because pool.First will return them // again on next iteration. - cnsToPut := make([]*conn, 0) + cnsToPut := make([]*pool.Conn, 0) for { cn := d.pool.First()