Merge pull request #273 from go-redis/fix/extract-pool-package

Extract pool package. Add pool benchmark.
Vladimir Mihailenco 2016-03-12 11:40:35 +02:00
commit 0d67a1f70f
25 changed files with 969 additions and 898 deletions

View File

@ -2,12 +2,15 @@ package redis_test
import (
"bytes"
"errors"
"net"
"testing"
"time"
redigo "github.com/garyburd/redigo/redis"
"gopkg.in/redis.v3"
"gopkg.in/redis.v3/internal/pool"
)
func benchmarkRedisClient(poolSize int) *redis.Client {
@ -274,11 +277,11 @@ func BenchmarkZAdd(b *testing.B) {
})
}
func BenchmarkPool(b *testing.B) {
client := benchmarkRedisClient(10)
defer client.Close()
pool := client.Pool()
func benchmarkPoolGetPut(b *testing.B, poolSize int) {
dial := func() (*pool.Conn, error) {
return pool.NewConn(&net.TCPConn{}), nil
}
pool := pool.NewConnPool(dial, poolSize, time.Second, 0)
b.ResetTimer()
@ -294,3 +297,49 @@ func BenchmarkPool(b *testing.B) {
}
})
}
func BenchmarkPoolGetPut10Conns(b *testing.B) {
benchmarkPoolGetPut(b, 10)
}
func BenchmarkPoolGetPut100Conns(b *testing.B) {
benchmarkPoolGetPut(b, 100)
}
func BenchmarkPoolGetPut1000Conns(b *testing.B) {
benchmarkPoolGetPut(b, 1000)
}
func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
dial := func() (*pool.Conn, error) {
return pool.NewConn(&net.TCPConn{}), nil
}
pool := pool.NewConnPool(dial, poolSize, time.Second, 0)
removeReason := errors.New("benchmark")
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn, _, err := pool.Get()
if err != nil {
b.Fatalf("no error expected on pool.Get but received: %s", err.Error())
}
if err = pool.Remove(conn, removeReason); err != nil {
b.Fatalf("no error expected on pool.Remove but received: %s", err.Error())
}
}
})
}
func BenchmarkPoolGetRemove10Conns(b *testing.B) {
benchmarkPoolGetRemove(b, 10)
}
func BenchmarkPoolGetRemove100Conns(b *testing.B) {
benchmarkPoolGetRemove(b, 100)
}
func BenchmarkPoolGetRemove1000Conns(b *testing.B) {
benchmarkPoolGetRemove(b, 1000)
}
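The hunk header above (`@ -294,3 +297,49`) elides the body of the Get/Put loop itself. A minimal sketch of that hot path, assuming it lives in the repository's own `redis_test` package (the `internal/pool` package cannot be imported from outside the go-redis tree) and using the hypothetical name `benchmarkPoolGetPutSketch`:

```go
package redis_test

import (
	"net"
	"testing"
	"time"

	"gopkg.in/redis.v3/internal/pool"
)

// Sketch of the Get/Put hot path measured by BenchmarkPoolGetPut*.
// The dialer hands out unconnected *net.TCPConn values, so no network
// I/O happens; only the pool's bookkeeping is exercised.
func benchmarkPoolGetPutSketch(b *testing.B, poolSize int) {
	dial := func() (*pool.Conn, error) {
		return pool.NewConn(&net.TCPConn{}), nil
	}
	connPool := pool.NewConnPool(dial, poolSize, time.Second, 0)

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			cn, _, err := connPool.Get()
			if err != nil {
				b.Fatal(err)
			}
			if err := connPool.Put(cn); err != nil {
				b.Fatal(err)
			}
		}
	})
}
```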

View File

@ -2,6 +2,7 @@ package redis
import (
"gopkg.in/redis.v3/internal/hashtag"
"gopkg.in/redis.v3/internal/pool"
)
// ClusterPipeline is not thread-safe.
@ -96,9 +97,9 @@ func (pipe *ClusterPipeline) Close() error {
}
func (pipe *ClusterPipeline) execClusterCmds(
cn *conn, cmds []Cmder, failedCmds map[string][]Cmder,
cn *pool.Conn, cmds []Cmder, failedCmds map[string][]Cmder,
) (map[string][]Cmder, error) {
if err := cn.writeCmds(cmds...); err != nil {
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return failedCmds, err
}

View File

@ -6,6 +6,8 @@ import (
"strconv"
"strings"
"time"
"gopkg.in/redis.v3/internal/pool"
)
var (
@ -28,7 +30,7 @@ var (
type Cmder interface {
args() []interface{}
readReply(*conn) error
readReply(*pool.Conn) error
setErr(error)
reset()
@ -51,6 +53,20 @@ func resetCmds(cmds []Cmder) {
}
}
func writeCmd(cn *pool.Conn, cmds ...Cmder) error {
cn.Buf = cn.Buf[:0]
for _, cmd := range cmds {
var err error
cn.Buf, err = appendArgs(cn.Buf, cmd.args())
if err != nil {
return err
}
}
_, err := cn.Write(cn.Buf)
return err
}
func cmdString(cmd Cmder, val interface{}) string {
var ss []string
for _, arg := range cmd.args() {
@ -143,7 +159,7 @@ func (cmd *Cmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *Cmd) readReply(cn *conn) error {
func (cmd *Cmd) readReply(cn *pool.Conn) error {
val, err := readReply(cn, sliceParser)
if err != nil {
cmd.err = err
@ -188,7 +204,7 @@ func (cmd *SliceCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *SliceCmd) readReply(cn *conn) error {
func (cmd *SliceCmd) readReply(cn *pool.Conn) error {
v, err := readArrayReply(cn, sliceParser)
if err != nil {
cmd.err = err
@ -231,7 +247,7 @@ func (cmd *StatusCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *StatusCmd) readReply(cn *conn) error {
func (cmd *StatusCmd) readReply(cn *pool.Conn) error {
cmd.val, cmd.err = readStringReply(cn)
return cmd.err
}
@ -265,7 +281,7 @@ func (cmd *IntCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *IntCmd) readReply(cn *conn) error {
func (cmd *IntCmd) readReply(cn *pool.Conn) error {
cmd.val, cmd.err = readIntReply(cn)
return cmd.err
}
@ -303,7 +319,7 @@ func (cmd *DurationCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *DurationCmd) readReply(cn *conn) error {
func (cmd *DurationCmd) readReply(cn *pool.Conn) error {
n, err := readIntReply(cn)
if err != nil {
cmd.err = err
@ -344,7 +360,7 @@ func (cmd *BoolCmd) String() string {
var ok = []byte("OK")
func (cmd *BoolCmd) readReply(cn *conn) error {
func (cmd *BoolCmd) readReply(cn *pool.Conn) error {
v, err := readReply(cn, nil)
// `SET key value NX` returns nil when key already exists. But
// `SETNX key value` returns bool (0/1). So convert nil to bool.
@ -430,13 +446,17 @@ func (cmd *StringCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *StringCmd) readReply(cn *conn) error {
func (cmd *StringCmd) readReply(cn *pool.Conn) error {
b, err := readBytesReply(cn)
if err != nil {
cmd.err = err
return err
}
cmd.val = cn.copyBuf(b)
new := make([]byte, len(b))
copy(new, b)
cmd.val = new
return nil
}
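The StringCmd change above replaces cn.copyBuf with an unconditional copy (the old conn.copyBuf helper is deleted further down). A small self-contained sketch, plain Go with illustrative strings, of why the copy is needed: the slice returned by readN aliases the connection's reusable buffer.

```go
package main

import "fmt"

// Why StringCmd.readReply copies the reply bytes: readN hands back a slice
// that aliases the connection's reusable Buf, so the data is clobbered as
// soon as the buffer is reused for the next command.
func main() {
	buf := make([]byte, 0, 64) // stands in for cn.Buf

	buf = append(buf[:0], "first reply"...)
	reply := buf // aliases buf, like the slice returned by readN

	buf = append(buf[:0], "next command"...)
	fmt.Printf("aliased reply: %q\n", reply) // no longer "first reply"

	buf = append(buf[:0], "first reply"...)
	safe := make([]byte, len(buf))
	copy(safe, buf) // what the new readReply does
	buf = append(buf[:0], "next command"...)
	fmt.Printf("copied reply:  %q\n", safe) // still "first reply"
}
```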
@ -469,7 +489,7 @@ func (cmd *FloatCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *FloatCmd) readReply(cn *conn) error {
func (cmd *FloatCmd) readReply(cn *pool.Conn) error {
cmd.val, cmd.err = readFloatReply(cn)
return cmd.err
}
@ -503,7 +523,7 @@ func (cmd *StringSliceCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *StringSliceCmd) readReply(cn *conn) error {
func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error {
v, err := readArrayReply(cn, stringSliceParser)
if err != nil {
cmd.err = err
@ -542,7 +562,7 @@ func (cmd *BoolSliceCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *BoolSliceCmd) readReply(cn *conn) error {
func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error {
v, err := readArrayReply(cn, boolSliceParser)
if err != nil {
cmd.err = err
@ -581,7 +601,7 @@ func (cmd *StringStringMapCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *StringStringMapCmd) readReply(cn *conn) error {
func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error {
v, err := readArrayReply(cn, stringStringMapParser)
if err != nil {
cmd.err = err
@ -620,7 +640,7 @@ func (cmd *StringIntMapCmd) reset() {
cmd.err = nil
}
func (cmd *StringIntMapCmd) readReply(cn *conn) error {
func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error {
v, err := readArrayReply(cn, stringIntMapParser)
if err != nil {
cmd.err = err
@ -659,7 +679,7 @@ func (cmd *ZSliceCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *ZSliceCmd) readReply(cn *conn) error {
func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error {
v, err := readArrayReply(cn, zSliceParser)
if err != nil {
cmd.err = err
@ -703,7 +723,7 @@ func (cmd *ScanCmd) String() string {
return cmdString(cmd, cmd.keys)
}
func (cmd *ScanCmd) readReply(cn *conn) error {
func (cmd *ScanCmd) readReply(cn *pool.Conn) error {
keys, cursor, err := readScanReply(cn)
if err != nil {
cmd.err = err
@ -751,7 +771,7 @@ func (cmd *ClusterSlotCmd) reset() {
cmd.err = nil
}
func (cmd *ClusterSlotCmd) readReply(cn *conn) error {
func (cmd *ClusterSlotCmd) readReply(cn *pool.Conn) error {
v, err := readArrayReply(cn, clusterSlotInfoSliceParser)
if err != nil {
cmd.err = err
@ -838,7 +858,7 @@ func (cmd *GeoLocationCmd) String() string {
return cmdString(cmd, cmd.locations)
}
func (cmd *GeoLocationCmd) readReply(cn *conn) error {
func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error {
reply, err := readArrayReply(cn, newGeoLocationSliceParser(cmd.q))
if err != nil {
cmd.err = err

conn.go
View File

@ -1,120 +0,0 @@
package redis
import (
"bufio"
"net"
"time"
)
const defaultBufSize = 4096
var noTimeout = time.Time{}
// Stubbed in tests.
var now = time.Now
type conn struct {
netcn net.Conn
rd *bufio.Reader
buf []byte
UsedAt time.Time
ReadTimeout time.Duration
WriteTimeout time.Duration
}
func newConnDialer(opt *Options) func() (*conn, error) {
dialer := opt.getDialer()
return func() (*conn, error) {
netcn, err := dialer()
if err != nil {
return nil, err
}
cn := &conn{
netcn: netcn,
buf: make([]byte, defaultBufSize),
UsedAt: now(),
}
cn.rd = bufio.NewReader(cn)
return cn, cn.init(opt)
}
}
func (cn *conn) init(opt *Options) error {
if opt.Password == "" && opt.DB == 0 {
return nil
}
// Temp client for Auth and Select.
client := newClient(opt, newSingleConnPool(cn))
if opt.Password != "" {
if err := client.Auth(opt.Password).Err(); err != nil {
return err
}
}
if opt.DB > 0 {
if err := client.Select(opt.DB).Err(); err != nil {
return err
}
}
return nil
}
func (cn *conn) writeCmds(cmds ...Cmder) error {
cn.buf = cn.buf[:0]
for _, cmd := range cmds {
var err error
cn.buf, err = appendArgs(cn.buf, cmd.args())
if err != nil {
return err
}
}
_, err := cn.Write(cn.buf)
return err
}
func (cn *conn) Read(b []byte) (int, error) {
cn.UsedAt = now()
if cn.ReadTimeout != 0 {
cn.netcn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout))
} else {
cn.netcn.SetReadDeadline(noTimeout)
}
return cn.netcn.Read(b)
}
func (cn *conn) Write(b []byte) (int, error) {
cn.UsedAt = now()
if cn.WriteTimeout != 0 {
cn.netcn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout))
} else {
cn.netcn.SetWriteDeadline(noTimeout)
}
return cn.netcn.Write(b)
}
func (cn *conn) RemoteAddr() net.Addr {
return cn.netcn.RemoteAddr()
}
func (cn *conn) Close() error {
return cn.netcn.Close()
}
func isSameSlice(s1, s2 []byte) bool {
return len(s1) > 0 && len(s2) > 0 && &s1[0] == &s2[0]
}
func (cn *conn) copyBuf(b []byte) []byte {
if isSameSlice(b, cn.buf) {
new := make([]byte, len(b))
copy(new, b)
return new
}
return b
}

View File

@ -1,25 +0,0 @@
package redis_test
import (
"net"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"gopkg.in/redis.v3"
)
var _ = Describe("newConnDialer with bad connection", func() {
It("should return an error", func() {
dialer := redis.NewConnDialer(&redis.Options{
Dialer: func() (net.Conn, error) {
return &badConn{}, nil
},
MaxRetries: 3,
Password: "password",
DB: 1,
})
_, err := dialer()
Expect(err).To(MatchError("bad connection"))
})
})

View File

@ -1,12 +1,15 @@
package redis
import (
"errors"
"fmt"
"io"
"net"
"strings"
)
var errClosed = errors.New("redis: client is closed")
// Redis nil reply, e.g. when key does not exist.
var Nil = errorf("redis: nil")

View File

@ -1,30 +1,11 @@
package redis
import (
"net"
"time"
)
import "gopkg.in/redis.v3/internal/pool"
func (c *baseClient) Pool() pool {
func (c *baseClient) Pool() pool.Pooler {
return c.connPool
}
func (c *PubSub) Pool() pool {
func (c *PubSub) Pool() pool.Pooler {
return c.base.connPool
}
var NewConnDialer = newConnDialer
func (cn *conn) SetNetConn(netcn net.Conn) {
cn.netcn = netcn
}
func SetTime(tm time.Time) {
now = func() time.Time {
return tm
}
}
func RestoreTime() {
now = time.Now
}

internal/pool/conn.go
View File

@ -0,0 +1,60 @@
package pool
import (
"bufio"
"net"
"time"
)
const defaultBufSize = 4096
var noTimeout = time.Time{}
type Conn struct {
NetConn net.Conn
Rd *bufio.Reader
Buf []byte
UsedAt time.Time
ReadTimeout time.Duration
WriteTimeout time.Duration
}
func NewConn(netConn net.Conn) *Conn {
cn := &Conn{
NetConn: netConn,
Buf: make([]byte, defaultBufSize),
UsedAt: time.Now(),
}
cn.Rd = bufio.NewReader(cn)
return cn
}
func (cn *Conn) Read(b []byte) (int, error) {
cn.UsedAt = time.Now()
if cn.ReadTimeout != 0 {
cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout))
} else {
cn.NetConn.SetReadDeadline(noTimeout)
}
return cn.NetConn.Read(b)
}
func (cn *Conn) Write(b []byte) (int, error) {
cn.UsedAt = time.Now()
if cn.WriteTimeout != 0 {
cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout))
} else {
cn.NetConn.SetWriteDeadline(noTimeout)
}
return cn.NetConn.Write(b)
}
func (cn *Conn) RemoteAddr() net.Addr {
return cn.NetConn.RemoteAddr()
}
func (cn *Conn) Close() error {
return cn.NetConn.Close()
}
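A minimal sketch of the new pool.Conn on its own, using net.Pipe as a stand-in server. It is not part of the PR and assumes it compiles inside the go-redis tree, since internal packages cannot be imported from outside the repository:

```go
package main

// pool.Conn refreshes UsedAt and (re)applies deadlines on every Read/Write,
// and exposes a bufio.Reader (Rd) plus a reusable write buffer (Buf).

import (
	"fmt"
	"net"
	"time"

	"gopkg.in/redis.v3/internal/pool"
)

func main() {
	client, server := net.Pipe()
	defer server.Close()

	// Fake server: read the request, answer with a RESP status reply.
	go func() {
		buf := make([]byte, 64)
		server.Read(buf)
		server.Write([]byte("+OK\r\n"))
	}()

	cn := pool.NewConn(client)
	cn.WriteTimeout = time.Second // applied as a write deadline on each Write

	if _, err := cn.Write([]byte("PING\r\n")); err != nil {
		panic(err)
	}
	line, _, err := cn.Rd.ReadLine() // buffered reads go through cn.Read
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s, used at %s\n", line, cn.UsedAt.Format(time.RFC3339))
	_ = cn.Close()
}
```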

internal/pool/conn_list.go
View File

@ -0,0 +1,100 @@
package pool
import (
"sync"
"sync/atomic"
)
type connList struct {
cns []*Conn
mx sync.Mutex
len int32 // atomic
size int32
}
func newConnList(size int) *connList {
return &connList{
cns: make([]*Conn, 0, size),
size: int32(size),
}
}
func (l *connList) Len() int {
return int(atomic.LoadInt32(&l.len))
}
// Reserve reserves a place in the list and returns true on success. The
// caller must add or remove a connection if a place was reserved.
func (l *connList) Reserve() bool {
len := atomic.AddInt32(&l.len, 1)
reserved := len <= l.size
if !reserved {
atomic.AddInt32(&l.len, -1)
}
return reserved
}
// Add adds a connection to the list. The caller must reserve a place first.
func (l *connList) Add(cn *Conn) {
l.mx.Lock()
l.cns = append(l.cns, cn)
l.mx.Unlock()
}
// Remove closes connection and removes it from the list.
func (l *connList) Remove(cn *Conn) error {
defer l.mx.Unlock()
l.mx.Lock()
if cn == nil {
atomic.AddInt32(&l.len, -1)
return nil
}
for i, c := range l.cns {
if c == cn {
l.cns = append(l.cns[:i], l.cns[i+1:]...)
atomic.AddInt32(&l.len, -1)
return cn.Close()
}
}
if l.closed() {
return nil
}
panic("conn not found in the list")
}
func (l *connList) Replace(cn, newcn *Conn) error {
defer l.mx.Unlock()
l.mx.Lock()
for i, c := range l.cns {
if c == cn {
l.cns[i] = newcn
return cn.Close()
}
}
if l.closed() {
return newcn.Close()
}
panic("conn not found in the list")
}
func (l *connList) Close() (retErr error) {
l.mx.Lock()
for _, c := range l.cns {
if err := c.Close(); err != nil {
retErr = err
}
}
l.cns = nil
atomic.StoreInt32(&l.len, 0)
l.mx.Unlock()
return retErr
}
func (l *connList) closed() bool {
return l.cns == nil
}
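The Reserve comment above describes a contract rather than showing it. A short sketch of the pairing it requires, which is the same pattern ConnPool.Get uses below; connList is unexported, so this assumes a file inside the pool package, and reserveAndDial is a hypothetical helper:

```go
package pool

// Every successful Reserve must be paired with either Add (on success)
// or Remove(nil) (on failure); otherwise the list's length counter
// leaks a slot and the pool shrinks permanently.
func reserveAndDial(l *connList, dial func() (*Conn, error)) (*Conn, error) {
	if !l.Reserve() {
		return nil, nil // list is full; the caller should wait instead
	}
	cn, err := dial()
	if err != nil {
		_ = l.Remove(nil) // release the reserved slot
		return nil, err
	}
	l.Add(cn)
	return cn, nil
}
```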

internal/pool/pool.go
View File

@ -0,0 +1,284 @@
package pool
import (
"errors"
"fmt"
"log"
"sync/atomic"
"time"
"gopkg.in/bsm/ratelimit.v1"
)
var Logger *log.Logger
var (
errClosed = errors.New("redis: client is closed")
ErrPoolTimeout = errors.New("redis: connection pool timeout")
)
// PoolStats contains pool state information and accumulated stats.
type PoolStats struct {
Requests uint32 // number of times a connection was requested by the pool
Hits uint32 // number of times free connection was found in the pool
Waits uint32 // number of times the pool had to wait for a connection
Timeouts uint32 // number of times a wait timeout occurred
TotalConns uint32 // the number of total connections in the pool
FreeConns uint32 // the number of free connections in the pool
}
type Pooler interface {
First() *Conn
Get() (*Conn, bool, error)
Put(*Conn) error
Remove(*Conn, error) error
Len() int
FreeLen() int
Close() error
Stats() *PoolStats
}
type dialer func() (*Conn, error)
type ConnPool struct {
dial dialer
poolTimeout time.Duration
idleTimeout time.Duration
rl *ratelimit.RateLimiter
conns *connList
freeConns chan *Conn
stats PoolStats
_closed int32
lastErr atomic.Value
}
func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool {
p := &ConnPool{
dial: dial,
poolTimeout: poolTimeout,
idleTimeout: idleTimeout,
rl: ratelimit.New(3*poolSize, time.Second),
conns: newConnList(poolSize),
freeConns: make(chan *Conn, poolSize),
}
if idleTimeout > 0 {
go p.reaper()
}
return p
}
func (p *ConnPool) closed() bool {
return atomic.LoadInt32(&p._closed) == 1
}
func (p *ConnPool) isIdle(cn *Conn) bool {
return p.idleTimeout > 0 && time.Since(cn.UsedAt) > p.idleTimeout
}
// First returns the first non-idle connection from the pool or nil if
// there are no free connections.
func (p *ConnPool) First() *Conn {
for {
select {
case cn := <-p.freeConns:
if p.isIdle(cn) {
var err error
cn, err = p.replace(cn)
if err != nil {
Logger.Printf("pool.replace failed: %s", err)
continue
}
}
return cn
default:
return nil
}
}
panic("not reached")
}
// wait waits for a free non-idle connection. It returns nil on timeout.
func (p *ConnPool) wait() *Conn {
deadline := time.After(p.poolTimeout)
for {
select {
case cn := <-p.freeConns:
if p.isIdle(cn) {
var err error
cn, err = p.replace(cn)
if err != nil {
Logger.Printf("pool.replace failed: %s", err)
continue
}
}
return cn
case <-deadline:
return nil
}
}
panic("not reached")
}
// Establish a new connection
func (p *ConnPool) new() (*Conn, error) {
if p.rl.Limit() {
err := fmt.Errorf(
"redis: you open connections too fast (last_error=%q)",
p.loadLastErr(),
)
return nil, err
}
cn, err := p.dial()
if err != nil {
p.storeLastErr(err.Error())
return nil, err
}
return cn, nil
}
// Get returns an existing connection from the pool or creates a new one.
func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) {
if p.closed() {
err = errClosed
return
}
atomic.AddUint32(&p.stats.Requests, 1)
// Fetch first non-idle connection, if available.
if cn = p.First(); cn != nil {
atomic.AddUint32(&p.stats.Hits, 1)
return
}
// Try to create a new one.
if p.conns.Reserve() {
isNew = true
cn, err = p.new()
if err != nil {
p.conns.Remove(nil)
return
}
p.conns.Add(cn)
return
}
// Otherwise, wait for an available connection.
atomic.AddUint32(&p.stats.Waits, 1)
if cn = p.wait(); cn != nil {
return
}
atomic.AddUint32(&p.stats.Timeouts, 1)
err = ErrPoolTimeout
return
}
func (p *ConnPool) Put(cn *Conn) error {
if cn.Rd.Buffered() != 0 {
b, _ := cn.Rd.Peek(cn.Rd.Buffered())
err := fmt.Errorf("connection has unread data: %q", b)
Logger.Print(err)
return p.Remove(cn, err)
}
p.freeConns <- cn
return nil
}
func (p *ConnPool) replace(cn *Conn) (*Conn, error) {
newcn, err := p.new()
if err != nil {
_ = p.conns.Remove(cn)
return nil, err
}
_ = p.conns.Replace(cn, newcn)
return newcn, nil
}
func (p *ConnPool) Remove(cn *Conn, reason error) error {
p.storeLastErr(reason.Error())
// Replace existing connection with new one and unblock waiter.
newcn, err := p.replace(cn)
if err != nil {
return err
}
p.freeConns <- newcn
return nil
}
// Len returns total number of connections.
func (p *ConnPool) Len() int {
return p.conns.Len()
}
// FreeLen returns number of free connections.
func (p *ConnPool) FreeLen() int {
return len(p.freeConns)
}
func (p *ConnPool) Stats() *PoolStats {
stats := p.stats
stats.Requests = atomic.LoadUint32(&p.stats.Requests)
stats.Waits = atomic.LoadUint32(&p.stats.Waits)
stats.Timeouts = atomic.LoadUint32(&p.stats.Timeouts)
stats.TotalConns = uint32(p.Len())
stats.FreeConns = uint32(p.FreeLen())
return &stats
}
func (p *ConnPool) Close() (retErr error) {
if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) {
return errClosed
}
// Wait for app to free connections, but don't close them immediately.
for i := 0; i < p.Len(); i++ {
if cn := p.wait(); cn == nil {
break
}
}
// Close all connections.
if err := p.conns.Close(); err != nil {
retErr = err
}
return retErr
}
func (p *ConnPool) reaper() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for _ = range ticker.C {
if p.closed() {
break
}
// pool.First removes idle connections from the pool and
// returns first non-idle connection. So just put returned
// connection back.
if cn := p.First(); cn != nil {
p.Put(cn)
}
}
}
func (p *ConnPool) storeLastErr(err string) {
p.lastErr.Store(err)
}
func (p *ConnPool) loadLastErr() string {
if v := p.lastErr.Load(); v != nil {
return v.(string)
}
return ""
}
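A hedged sketch of how a caller drives the exported Pooler API that baseClient now consumes: Get, then Put on success or Remove on error, then Stats. The stub dialer mirrors the benchmark above; useOnce is a hypothetical helper, and the example assumes it compiles inside the go-redis tree:

```go
package main

import (
	"errors"
	"fmt"
	"net"
	"time"

	"gopkg.in/redis.v3/internal/pool"
)

func useOnce(p pool.Pooler) error {
	cn, _, err := p.Get()
	if err != nil {
		return err
	}
	// Pretend the connection went bad: Remove replaces it with a freshly
	// dialed one and pushes that onto the free list for the next caller.
	return p.Remove(cn, errors.New("simulated protocol error"))
}

func main() {
	dial := func() (*pool.Conn, error) {
		return pool.NewConn(&net.TCPConn{}), nil
	}
	p := pool.NewConnPool(dial, 4, time.Second, 0)
	defer p.Close()

	if err := useOnce(p); err != nil {
		fmt.Println("useOnce:", err)
		return
	}
	s := p.Stats()
	fmt.Printf("requests=%d hits=%d timeouts=%d total=%d free=%d\n",
		s.Requests, s.Hits, s.Timeouts, s.TotalConns, s.FreeConns)
}
```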

View File

@ -0,0 +1,47 @@
package pool
type SingleConnPool struct {
cn *Conn
}
func NewSingleConnPool(cn *Conn) *SingleConnPool {
return &SingleConnPool{
cn: cn,
}
}
func (p *SingleConnPool) First() *Conn {
return p.cn
}
func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.cn, false, nil
}
func (p *SingleConnPool) Put(cn *Conn) error {
if p.cn != cn {
panic("p.cn != cn")
}
return nil
}
func (p *SingleConnPool) Remove(cn *Conn, _ error) error {
if p.cn != cn {
panic("p.cn != cn")
}
return nil
}
func (p *SingleConnPool) Len() int {
return 1
}
func (p *SingleConnPool) FreeLen() int {
return 0
}
func (p *SingleConnPool) Stats() *PoolStats { return nil }
func (p *SingleConnPool) Close() error {
return nil
}

View File

@ -0,0 +1,128 @@
package pool
import (
"errors"
"sync"
)
type StickyConnPool struct {
pool *ConnPool
reusable bool
cn *Conn
closed bool
mx sync.Mutex
}
func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool {
return &StickyConnPool{
pool: pool,
reusable: reusable,
}
}
func (p *StickyConnPool) First() *Conn {
p.mx.Lock()
cn := p.cn
p.mx.Unlock()
return cn
}
func (p *StickyConnPool) Get() (cn *Conn, isNew bool, err error) {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
err = errClosed
return
}
if p.cn != nil {
cn = p.cn
return
}
cn, isNew, err = p.pool.Get()
if err != nil {
return
}
p.cn = cn
return
}
func (p *StickyConnPool) put() (err error) {
err = p.pool.Put(p.cn)
p.cn = nil
return err
}
func (p *StickyConnPool) Put(cn *Conn) error {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
return errClosed
}
if p.cn != cn {
panic("p.cn != cn")
}
return nil
}
func (p *StickyConnPool) remove(reason error) error {
err := p.pool.Remove(p.cn, reason)
p.cn = nil
return err
}
func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
return errClosed
}
if p.cn == nil {
panic("p.cn == nil")
}
if cn != nil && p.cn != cn {
panic("p.cn != cn")
}
return p.remove(reason)
}
func (p *StickyConnPool) Len() int {
defer p.mx.Unlock()
p.mx.Lock()
if p.cn == nil {
return 0
}
return 1
}
func (p *StickyConnPool) FreeLen() int {
defer p.mx.Unlock()
p.mx.Lock()
if p.cn == nil {
return 1
}
return 0
}
func (p *StickyConnPool) Stats() *PoolStats { return nil }
func (p *StickyConnPool) Close() error {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
return errClosed
}
p.closed = true
var err error
if p.cn != nil {
if p.reusable {
err = p.put()
} else {
reason := errors.New("redis: sticky not reusable connection")
err = p.remove(reason)
}
}
return err
}

View File

@ -3,6 +3,8 @@ package redis
import (
"errors"
"fmt"
"gopkg.in/redis.v3/internal/pool"
)
var errDiscard = errors.New("redis: Discard can be used only inside Exec")
@ -38,7 +40,7 @@ func (c *Client) Multi() *Multi {
multi := &Multi{
base: &baseClient{
opt: c.opt,
connPool: newStickyConnPool(c.connPool, true),
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true),
},
}
multi.commandable.process = multi.process
@ -137,8 +139,8 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) {
return retCmds, err
}
func (c *Multi) execCmds(cn *conn, cmds []Cmder) error {
err := cn.writeCmds(cmds...)
func (c *Multi) execCmds(cn *pool.Conn, cmds []Cmder) error {
err := writeCmd(cn, cmds...)
if err != nil {
setCmdsErr(cmds[1:len(cmds)-1], err)
return err

View File

@ -145,7 +145,7 @@ var _ = Describe("Multi", func() {
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
cn.NetConn = &badConn{}
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())
@ -172,7 +172,7 @@ var _ = Describe("Multi", func() {
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
cn.NetConn = &badConn{}
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())

options.go
View File

@ -0,0 +1,144 @@
package redis
import (
"net"
"time"
"gopkg.in/redis.v3/internal/pool"
)
type Options struct {
// The network type, either tcp or unix.
// Default is tcp.
Network string
// host:port address.
Addr string
// Dialer creates new network connection and has priority over
// Network and Addr options.
Dialer func() (net.Conn, error)
// An optional password. Must match the password specified in the
// requirepass server configuration option.
Password string
// A database to be selected after connecting to server.
DB int64
// The maximum number of retries before giving up.
// Default is to not retry failed commands.
MaxRetries int
// Sets the deadline for establishing new connections. If reached,
// dial will fail with a timeout.
DialTimeout time.Duration
// Sets the deadline for socket reads. If reached, commands will
// fail with a timeout instead of blocking.
ReadTimeout time.Duration
// Sets the deadline for socket writes. If reached, commands will
// fail with a timeout instead of blocking.
WriteTimeout time.Duration
// The maximum number of socket connections.
// Default is 10 connections.
PoolSize int
// Specifies amount of time client waits for connection if all
// connections are busy before returning an error.
// Default is 1 second.
PoolTimeout time.Duration
// Specifies amount of time after which client closes idle
// connections. Should be less than server's timeout.
// Default is to not close idle connections.
IdleTimeout time.Duration
}
func (opt *Options) getNetwork() string {
if opt.Network == "" {
return "tcp"
}
return opt.Network
}
func (opt *Options) getDialer() func() (net.Conn, error) {
if opt.Dialer == nil {
opt.Dialer = func() (net.Conn, error) {
return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout())
}
}
return opt.Dialer
}
func (opt *Options) getPoolDialer() func() (*pool.Conn, error) {
dial := opt.getDialer()
return func() (*pool.Conn, error) {
netcn, err := dial()
if err != nil {
return nil, err
}
cn := pool.NewConn(netcn)
return cn, opt.initConn(cn)
}
}
func (opt *Options) getPoolSize() int {
if opt.PoolSize == 0 {
return 10
}
return opt.PoolSize
}
func (opt *Options) getDialTimeout() time.Duration {
if opt.DialTimeout == 0 {
return 5 * time.Second
}
return opt.DialTimeout
}
func (opt *Options) getPoolTimeout() time.Duration {
if opt.PoolTimeout == 0 {
return 1 * time.Second
}
return opt.PoolTimeout
}
func (opt *Options) getIdleTimeout() time.Duration {
return opt.IdleTimeout
}
func (opt *Options) initConn(cn *pool.Conn) error {
if opt.Password == "" && opt.DB == 0 {
return nil
}
// Temp client for Auth and Select.
client := newClient(opt, pool.NewSingleConnPool(cn))
if opt.Password != "" {
if err := client.Auth(opt.Password).Err(); err != nil {
return err
}
}
if opt.DB > 0 {
if err := client.Select(opt.DB).Err(); err != nil {
return err
}
}
return nil
}
func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool(
opt.getPoolDialer(), opt.getPoolSize(), opt.getPoolTimeout(), opt.getIdleTimeout())
}
// PoolStats contains pool state information and accumulated stats.
type PoolStats struct {
Requests uint32 // number of times a connection was requested by the pool
Hits uint32 // number of times free connection was found in the pool
Waits uint32 // number of times the pool had to wait for a connection
Timeouts uint32 // number of times a wait timeout occurred
TotalConns uint32 // the number of total connections in the pool
FreeConns uint32 // the number of free connections in the pool
}
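For reference, a sketch of how these Options defaults and the new public PoolStats surface to users of the package; the Addr and pool settings are illustrative and assume a local Redis server:

```go
package main

import (
	"fmt"
	"time"

	"gopkg.in/redis.v3"
)

func main() {
	client := redis.NewClient(&redis.Options{
		Addr:        "localhost:6379", // host:port; Network defaults to "tcp"
		PoolSize:    20,               // default would be 10
		PoolTimeout: time.Second,      // default is 1 second
		IdleTimeout: 5 * time.Minute,  // 0 disables the reaper goroutine
	})
	defer client.Close()

	if err := client.Ping().Err(); err != nil {
		fmt.Println("ping failed:", err)
		return
	}

	stats := client.PoolStats()
	fmt.Printf("requests=%d hits=%d total=%d free=%d\n",
		stats.Requests, stats.Hits, stats.TotalConns, stats.FreeConns)
}
```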

View File

@ -6,6 +6,8 @@ import (
"io"
"net"
"strconv"
"gopkg.in/redis.v3/internal/pool"
)
const (
@ -16,7 +18,7 @@ const (
arrayReply = '*'
)
type multiBulkParser func(cn *conn, n int64) (interface{}, error)
type multiBulkParser func(cn *pool.Conn, n int64) (interface{}, error)
var (
errReaderTooSmall = errors.New("redis: reader is too small")
@ -223,8 +225,8 @@ func scan(b []byte, val interface{}) error {
//------------------------------------------------------------------------------
func readLine(cn *conn) ([]byte, error) {
line, isPrefix, err := cn.rd.ReadLine()
func readLine(cn *pool.Conn) ([]byte, error) {
line, isPrefix, err := cn.Rd.ReadLine()
if err != nil {
return line, err
}
@ -243,28 +245,27 @@ func isNilReply(b []byte) bool {
b[1] == '-' && b[2] == '1'
}
func readN(cn *conn, n int) ([]byte, error) {
var b []byte
if cap(cn.buf) < n {
b = make([]byte, n)
func readN(cn *pool.Conn, n int) ([]byte, error) {
if d := n - cap(cn.Buf); d > 0 {
cn.Buf = append(cn.Buf, make([]byte, d)...)
} else {
b = cn.buf[:n]
cn.Buf = cn.Buf[:n]
}
_, err := io.ReadFull(cn.rd, b)
return b, err
_, err := io.ReadFull(cn.Rd, cn.Buf)
return cn.Buf, err
}
//------------------------------------------------------------------------------
func parseErrorReply(cn *conn, line []byte) error {
func parseErrorReply(cn *pool.Conn, line []byte) error {
return errorf(string(line[1:]))
}
func parseStatusReply(cn *conn, line []byte) ([]byte, error) {
func parseStatusReply(cn *pool.Conn, line []byte) ([]byte, error) {
return line[1:], nil
}
func parseIntReply(cn *conn, line []byte) (int64, error) {
func parseIntReply(cn *pool.Conn, line []byte) (int64, error) {
n, err := strconv.ParseInt(bytesToString(line[1:]), 10, 64)
if err != nil {
return 0, err
@ -272,7 +273,7 @@ func parseIntReply(cn *conn, line []byte) (int64, error) {
return n, nil
}
func readIntReply(cn *conn) (int64, error) {
func readIntReply(cn *pool.Conn) (int64, error) {
line, err := readLine(cn)
if err != nil {
return 0, err
@ -287,7 +288,7 @@ func readIntReply(cn *conn) (int64, error) {
}
}
func parseBytesReply(cn *conn, line []byte) ([]byte, error) {
func parseBytesReply(cn *pool.Conn, line []byte) ([]byte, error) {
if isNilReply(line) {
return nil, Nil
}
@ -305,7 +306,7 @@ func parseBytesReply(cn *conn, line []byte) ([]byte, error) {
return b[:replyLen], nil
}
func readBytesReply(cn *conn) ([]byte, error) {
func readBytesReply(cn *pool.Conn) ([]byte, error) {
line, err := readLine(cn)
if err != nil {
return nil, err
@ -322,7 +323,7 @@ func readBytesReply(cn *conn) ([]byte, error) {
}
}
func readStringReply(cn *conn) (string, error) {
func readStringReply(cn *pool.Conn) (string, error) {
b, err := readBytesReply(cn)
if err != nil {
return "", err
@ -330,7 +331,7 @@ func readStringReply(cn *conn) (string, error) {
return string(b), nil
}
func readFloatReply(cn *conn) (float64, error) {
func readFloatReply(cn *pool.Conn) (float64, error) {
b, err := readBytesReply(cn)
if err != nil {
return 0, err
@ -338,7 +339,7 @@ func readFloatReply(cn *conn) (float64, error) {
return strconv.ParseFloat(bytesToString(b), 64)
}
func parseArrayHeader(cn *conn, line []byte) (int64, error) {
func parseArrayHeader(cn *pool.Conn, line []byte) (int64, error) {
if isNilReply(line) {
return 0, Nil
}
@ -350,7 +351,7 @@ func parseArrayHeader(cn *conn, line []byte) (int64, error) {
return n, nil
}
func parseArrayReply(cn *conn, p multiBulkParser, line []byte) (interface{}, error) {
func parseArrayReply(cn *pool.Conn, p multiBulkParser, line []byte) (interface{}, error) {
n, err := parseArrayHeader(cn, line)
if err != nil {
return nil, err
@ -358,7 +359,7 @@ func parseArrayReply(cn *conn, p multiBulkParser, line []byte) (interface{}, err
return p(cn, n)
}
func readArrayHeader(cn *conn) (int64, error) {
func readArrayHeader(cn *pool.Conn) (int64, error) {
line, err := readLine(cn)
if err != nil {
return 0, err
@ -373,7 +374,7 @@ func readArrayHeader(cn *conn) (int64, error) {
}
}
func readArrayReply(cn *conn, p multiBulkParser) (interface{}, error) {
func readArrayReply(cn *pool.Conn, p multiBulkParser) (interface{}, error) {
line, err := readLine(cn)
if err != nil {
return nil, err
@ -388,7 +389,7 @@ func readArrayReply(cn *conn, p multiBulkParser) (interface{}, error) {
}
}
func readReply(cn *conn, p multiBulkParser) (interface{}, error) {
func readReply(cn *pool.Conn, p multiBulkParser) (interface{}, error) {
line, err := readLine(cn)
if err != nil {
return nil, err
@ -409,7 +410,7 @@ func readReply(cn *conn, p multiBulkParser) (interface{}, error) {
return nil, fmt.Errorf("redis: can't parse %.100q", line)
}
func readScanReply(cn *conn) ([]string, int64, error) {
func readScanReply(cn *pool.Conn) ([]string, int64, error) {
n, err := readArrayHeader(cn)
if err != nil {
return nil, 0, err
@ -445,7 +446,7 @@ func readScanReply(cn *conn) ([]string, int64, error) {
return keys, cursor, err
}
func sliceParser(cn *conn, n int64) (interface{}, error) {
func sliceParser(cn *pool.Conn, n int64) (interface{}, error) {
vals := make([]interface{}, 0, n)
for i := int64(0); i < n; i++ {
v, err := readReply(cn, sliceParser)
@ -465,7 +466,7 @@ func sliceParser(cn *conn, n int64) (interface{}, error) {
return vals, nil
}
func intSliceParser(cn *conn, n int64) (interface{}, error) {
func intSliceParser(cn *pool.Conn, n int64) (interface{}, error) {
ints := make([]int64, 0, n)
for i := int64(0); i < n; i++ {
n, err := readIntReply(cn)
@ -477,7 +478,7 @@ func intSliceParser(cn *conn, n int64) (interface{}, error) {
return ints, nil
}
func boolSliceParser(cn *conn, n int64) (interface{}, error) {
func boolSliceParser(cn *pool.Conn, n int64) (interface{}, error) {
bools := make([]bool, 0, n)
for i := int64(0); i < n; i++ {
n, err := readIntReply(cn)
@ -489,7 +490,7 @@ func boolSliceParser(cn *conn, n int64) (interface{}, error) {
return bools, nil
}
func stringSliceParser(cn *conn, n int64) (interface{}, error) {
func stringSliceParser(cn *pool.Conn, n int64) (interface{}, error) {
ss := make([]string, 0, n)
for i := int64(0); i < n; i++ {
s, err := readStringReply(cn)
@ -504,7 +505,7 @@ func stringSliceParser(cn *conn, n int64) (interface{}, error) {
return ss, nil
}
func floatSliceParser(cn *conn, n int64) (interface{}, error) {
func floatSliceParser(cn *pool.Conn, n int64) (interface{}, error) {
nn := make([]float64, 0, n)
for i := int64(0); i < n; i++ {
n, err := readFloatReply(cn)
@ -516,7 +517,7 @@ func floatSliceParser(cn *conn, n int64) (interface{}, error) {
return nn, nil
}
func stringStringMapParser(cn *conn, n int64) (interface{}, error) {
func stringStringMapParser(cn *pool.Conn, n int64) (interface{}, error) {
m := make(map[string]string, n/2)
for i := int64(0); i < n; i += 2 {
key, err := readStringReply(cn)
@ -534,7 +535,7 @@ func stringStringMapParser(cn *conn, n int64) (interface{}, error) {
return m, nil
}
func stringIntMapParser(cn *conn, n int64) (interface{}, error) {
func stringIntMapParser(cn *pool.Conn, n int64) (interface{}, error) {
m := make(map[string]int64, n/2)
for i := int64(0); i < n; i += 2 {
key, err := readStringReply(cn)
@ -552,7 +553,7 @@ func stringIntMapParser(cn *conn, n int64) (interface{}, error) {
return m, nil
}
func zSliceParser(cn *conn, n int64) (interface{}, error) {
func zSliceParser(cn *pool.Conn, n int64) (interface{}, error) {
zz := make([]Z, n/2)
for i := int64(0); i < n; i += 2 {
var err error
@ -572,7 +573,7 @@ func zSliceParser(cn *conn, n int64) (interface{}, error) {
return zz, nil
}
func clusterSlotInfoSliceParser(cn *conn, n int64) (interface{}, error) {
func clusterSlotInfoSliceParser(cn *pool.Conn, n int64) (interface{}, error) {
infos := make([]ClusterSlotInfo, 0, n)
for i := int64(0); i < n; i++ {
n, err := readArrayHeader(cn)
@ -638,7 +639,7 @@ func clusterSlotInfoSliceParser(cn *conn, n int64) (interface{}, error) {
}
func newGeoLocationParser(q *GeoRadiusQuery) multiBulkParser {
return func(cn *conn, n int64) (interface{}, error) {
return func(cn *pool.Conn, n int64) (interface{}, error) {
var loc GeoLocation
var err error
@ -682,7 +683,7 @@ func newGeoLocationParser(q *GeoRadiusQuery) multiBulkParser {
}
func newGeoLocationSliceParser(q *GeoRadiusQuery) multiBulkParser {
return func(cn *conn, n int64) (interface{}, error) {
return func(cn *pool.Conn, n int64) (interface{}, error) {
locs := make([]GeoLocation, 0, n)
for i := int64(0); i < n; i++ {
v, err := readReply(cn, newGeoLocationParser(q))

View File

@ -4,6 +4,8 @@ import (
"bufio"
"bytes"
"testing"
"gopkg.in/redis.v3/internal/pool"
)
func BenchmarkParseReplyStatus(b *testing.B) {
@ -31,9 +33,9 @@ func benchmarkParseReply(b *testing.B, reply string, p multiBulkParser, wanterr
for i := 0; i < b.N; i++ {
buf.WriteString(reply)
}
cn := &conn{
rd: bufio.NewReader(buf),
buf: make([]byte, 0, defaultBufSize),
cn := &pool.Conn{
Rd: bufio.NewReader(buf),
Buf: make([]byte, 4096),
}
b.ResetTimer()

View File

@ -3,6 +3,8 @@ package redis
import (
"sync"
"sync/atomic"
"gopkg.in/redis.v3/internal/pool"
)
// Pipeline implements pipelining as described in
@ -110,8 +112,8 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) {
return cmds, retErr
}
func execCmds(cn *conn, cmds []Cmder) ([]Cmder, error) {
if err := cn.writeCmds(cmds...); err != nil {
func execCmds(cn *pool.Conn, cmds []Cmder) ([]Cmder, error) {
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return cmds, err
}

pool.go
View File

@ -1,542 +0,0 @@
package redis
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"gopkg.in/bsm/ratelimit.v1"
)
var (
errClosed = errors.New("redis: client is closed")
errPoolTimeout = errors.New("redis: connection pool timeout")
)
// PoolStats contains pool state information and accumulated stats.
type PoolStats struct {
Requests uint32 // number of times a connection was requested by the pool
Hits uint32 // number of times free connection was found in the pool
Waits uint32 // number of times the pool had to wait for a connection
Timeouts uint32 // number of times a wait timeout occurred
TotalConns uint32 // the number of total connections in the pool
FreeConns uint32 // the number of free connections in the pool
}
type pool interface {
First() *conn
Get() (*conn, bool, error)
Put(*conn) error
Remove(*conn, error) error
Len() int
FreeLen() int
Close() error
Stats() *PoolStats
}
type connList struct {
cns []*conn
mx sync.Mutex
len int32 // atomic
size int32
}
func newConnList(size int) *connList {
return &connList{
cns: make([]*conn, 0, size),
size: int32(size),
}
}
func (l *connList) Len() int {
return int(atomic.LoadInt32(&l.len))
}
// Reserve reserves place in the list and returns true on success. The
// caller must add or remove connection if place was reserved.
func (l *connList) Reserve() bool {
len := atomic.AddInt32(&l.len, 1)
reserved := len <= l.size
if !reserved {
atomic.AddInt32(&l.len, -1)
}
return reserved
}
// Add adds connection to the list. The caller must reserve place first.
func (l *connList) Add(cn *conn) {
l.mx.Lock()
l.cns = append(l.cns, cn)
l.mx.Unlock()
}
// Remove closes connection and removes it from the list.
func (l *connList) Remove(cn *conn) error {
defer l.mx.Unlock()
l.mx.Lock()
if cn == nil {
atomic.AddInt32(&l.len, -1)
return nil
}
for i, c := range l.cns {
if c == cn {
l.cns = append(l.cns[:i], l.cns[i+1:]...)
atomic.AddInt32(&l.len, -1)
return cn.Close()
}
}
if l.closed() {
return nil
}
panic("conn not found in the list")
}
func (l *connList) Replace(cn, newcn *conn) error {
defer l.mx.Unlock()
l.mx.Lock()
for i, c := range l.cns {
if c == cn {
l.cns[i] = newcn
return cn.Close()
}
}
if l.closed() {
return newcn.Close()
}
panic("conn not found in the list")
}
func (l *connList) Close() (retErr error) {
l.mx.Lock()
for _, c := range l.cns {
if err := c.Close(); err != nil {
retErr = err
}
}
l.cns = nil
atomic.StoreInt32(&l.len, 0)
l.mx.Unlock()
return retErr
}
func (l *connList) closed() bool {
return l.cns == nil
}
type connPool struct {
dialer func() (*conn, error)
rl *ratelimit.RateLimiter
opt *Options
conns *connList
freeConns chan *conn
stats PoolStats
_closed int32
lastErr atomic.Value
}
func newConnPool(opt *Options) *connPool {
p := &connPool{
dialer: newConnDialer(opt),
rl: ratelimit.New(3*opt.getPoolSize(), time.Second),
opt: opt,
conns: newConnList(opt.getPoolSize()),
freeConns: make(chan *conn, opt.getPoolSize()),
}
if p.opt.getIdleTimeout() > 0 {
go p.reaper()
}
return p
}
func (p *connPool) closed() bool {
return atomic.LoadInt32(&p._closed) == 1
}
func (p *connPool) isIdle(cn *conn) bool {
return p.opt.getIdleTimeout() > 0 && time.Since(cn.UsedAt) > p.opt.getIdleTimeout()
}
// First returns first non-idle connection from the pool or nil if
// there are no connections.
func (p *connPool) First() *conn {
for {
select {
case cn := <-p.freeConns:
if p.isIdle(cn) {
var err error
cn, err = p.replace(cn)
if err != nil {
Logger.Printf("pool.replace failed: %s", err)
continue
}
}
return cn
default:
return nil
}
}
panic("not reached")
}
// wait waits for free non-idle connection. It returns nil on timeout.
func (p *connPool) wait() *conn {
deadline := time.After(p.opt.getPoolTimeout())
for {
select {
case cn := <-p.freeConns:
if p.isIdle(cn) {
var err error
cn, err = p.replace(cn)
if err != nil {
Logger.Printf("pool.replace failed: %s", err)
continue
}
}
return cn
case <-deadline:
return nil
}
}
panic("not reached")
}
// Establish a new connection
func (p *connPool) new() (*conn, error) {
if p.rl.Limit() {
err := fmt.Errorf(
"redis: you open connections too fast (last_error=%q)",
p.loadLastErr(),
)
return nil, err
}
cn, err := p.dialer()
if err != nil {
p.storeLastErr(err.Error())
return nil, err
}
return cn, nil
}
// Get returns existed connection from the pool or creates a new one.
func (p *connPool) Get() (cn *conn, isNew bool, err error) {
if p.closed() {
err = errClosed
return
}
atomic.AddUint32(&p.stats.Requests, 1)
// Fetch first non-idle connection, if available.
if cn = p.First(); cn != nil {
atomic.AddUint32(&p.stats.Hits, 1)
return
}
// Try to create a new one.
if p.conns.Reserve() {
isNew = true
cn, err = p.new()
if err != nil {
p.conns.Remove(nil)
return
}
p.conns.Add(cn)
return
}
// Otherwise, wait for the available connection.
atomic.AddUint32(&p.stats.Waits, 1)
if cn = p.wait(); cn != nil {
return
}
atomic.AddUint32(&p.stats.Timeouts, 1)
err = errPoolTimeout
return
}
func (p *connPool) Put(cn *conn) error {
if cn.rd.Buffered() != 0 {
b, _ := cn.rd.Peek(cn.rd.Buffered())
err := fmt.Errorf("connection has unread data: %q", b)
Logger.Print(err)
return p.Remove(cn, err)
}
p.freeConns <- cn
return nil
}
func (p *connPool) replace(cn *conn) (*conn, error) {
newcn, err := p.new()
if err != nil {
_ = p.conns.Remove(cn)
return nil, err
}
_ = p.conns.Replace(cn, newcn)
return newcn, nil
}
func (p *connPool) Remove(cn *conn, reason error) error {
p.storeLastErr(reason.Error())
// Replace existing connection with new one and unblock waiter.
newcn, err := p.replace(cn)
if err != nil {
return err
}
p.freeConns <- newcn
return nil
}
// Len returns total number of connections.
func (p *connPool) Len() int {
return p.conns.Len()
}
// FreeLen returns number of free connections.
func (p *connPool) FreeLen() int {
return len(p.freeConns)
}
func (p *connPool) Stats() *PoolStats {
stats := p.stats
stats.Requests = atomic.LoadUint32(&p.stats.Requests)
stats.Waits = atomic.LoadUint32(&p.stats.Waits)
stats.Timeouts = atomic.LoadUint32(&p.stats.Timeouts)
stats.TotalConns = uint32(p.Len())
stats.FreeConns = uint32(p.FreeLen())
return &stats
}
func (p *connPool) Close() (retErr error) {
if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) {
return errClosed
}
// Wait for app to free connections, but don't close them immediately.
for i := 0; i < p.Len(); i++ {
if cn := p.wait(); cn == nil {
break
}
}
// Close all connections.
if err := p.conns.Close(); err != nil {
retErr = err
}
return retErr
}
func (p *connPool) reaper() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for _ = range ticker.C {
if p.closed() {
break
}
// pool.First removes idle connections from the pool and
// returns first non-idle connection. So just put returned
// connection back.
if cn := p.First(); cn != nil {
p.Put(cn)
}
}
}
func (p *connPool) storeLastErr(err string) {
p.lastErr.Store(err)
}
func (p *connPool) loadLastErr() string {
if v := p.lastErr.Load(); v != nil {
return v.(string)
}
return ""
}
//------------------------------------------------------------------------------
type singleConnPool struct {
cn *conn
}
func newSingleConnPool(cn *conn) *singleConnPool {
return &singleConnPool{
cn: cn,
}
}
func (p *singleConnPool) First() *conn {
return p.cn
}
func (p *singleConnPool) Get() (*conn, bool, error) {
return p.cn, false, nil
}
func (p *singleConnPool) Put(cn *conn) error {
if p.cn != cn {
panic("p.cn != cn")
}
return nil
}
func (p *singleConnPool) Remove(cn *conn, _ error) error {
if p.cn != cn {
panic("p.cn != cn")
}
return nil
}
func (p *singleConnPool) Len() int {
return 1
}
func (p *singleConnPool) FreeLen() int {
return 0
}
func (p *singleConnPool) Stats() *PoolStats { return nil }
func (p *singleConnPool) Close() error {
return nil
}
//------------------------------------------------------------------------------
type stickyConnPool struct {
pool pool
reusable bool
cn *conn
closed bool
mx sync.Mutex
}
func newStickyConnPool(pool pool, reusable bool) *stickyConnPool {
return &stickyConnPool{
pool: pool,
reusable: reusable,
}
}
func (p *stickyConnPool) First() *conn {
p.mx.Lock()
cn := p.cn
p.mx.Unlock()
return cn
}
func (p *stickyConnPool) Get() (cn *conn, isNew bool, err error) {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
err = errClosed
return
}
if p.cn != nil {
cn = p.cn
return
}
cn, isNew, err = p.pool.Get()
if err != nil {
return
}
p.cn = cn
return
}
func (p *stickyConnPool) put() (err error) {
err = p.pool.Put(p.cn)
p.cn = nil
return err
}
func (p *stickyConnPool) Put(cn *conn) error {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
return errClosed
}
if p.cn != cn {
panic("p.cn != cn")
}
return nil
}
func (p *stickyConnPool) remove(reason error) error {
err := p.pool.Remove(p.cn, reason)
p.cn = nil
return err
}
func (p *stickyConnPool) Remove(cn *conn, reason error) error {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
return errClosed
}
if p.cn == nil {
panic("p.cn == nil")
}
if cn != nil && p.cn != cn {
panic("p.cn != cn")
}
return p.remove(reason)
}
func (p *stickyConnPool) Len() int {
defer p.mx.Unlock()
p.mx.Lock()
if p.cn == nil {
return 0
}
return 1
}
func (p *stickyConnPool) FreeLen() int {
defer p.mx.Unlock()
p.mx.Lock()
if p.cn == nil {
return 1
}
return 0
}
func (p *stickyConnPool) Stats() *PoolStats { return nil }
func (p *stickyConnPool) Close() error {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
return errClosed
}
p.closed = true
var err error
if p.cn != nil {
if p.reusable {
err = p.put()
} else {
reason := errors.New("redis: sticky not reusable connection")
err = p.remove(reason)
}
}
return err
}

View File

@ -4,6 +4,8 @@ import (
"fmt"
"net"
"time"
"gopkg.in/redis.v3/internal/pool"
)
// Posts a message to the given channel.
@ -30,7 +32,7 @@ func (c *Client) PubSub() *PubSub {
return &PubSub{
base: &baseClient{
opt: c.opt,
connPool: newStickyConnPool(c.connPool, false),
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false),
},
}
}
@ -47,19 +49,20 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
return pubsub, pubsub.PSubscribe(channels...)
}
func (c *PubSub) subscribe(cmd string, channels ...string) error {
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, _, err := c.base.conn()
if err != nil {
return err
}
args := make([]interface{}, 1+len(channels))
args[0] = cmd
args[0] = redisCmd
for i, channel := range channels {
args[1+i] = channel
}
req := NewSliceCmd(args...)
return cn.writeCmds(req)
cmd := NewSliceCmd(args...)
return writeCmd(cn, cmd)
}
// Subscribes the client to the specified channels.
@ -132,7 +135,7 @@ func (c *PubSub) Ping(payload string) error {
args = append(args, payload)
}
cmd := NewCmd(args...)
return cn.writeCmds(cmd)
return writeCmd(cn, cmd)
}
// Message received after a successful subscription to channel.
@ -296,7 +299,7 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
}
}
func (c *PubSub) putConn(cn *conn, err error) {
func (c *PubSub) putConn(cn *pool.Conn, err error) {
if !c.base.putConn(cn, err, true) {
c.nsub = 0
}
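A sketch of the public PubSub flow that the refactored code above serves (a sticky single-connection pool plus the shared writeCmd helper); the channel name and the sleep are illustrative, and it assumes a local Redis server:

```go
package main

import (
	"fmt"
	"time"

	"gopkg.in/redis.v3"
)

func main() {
	client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	defer client.Close()

	// Subscribe pins a single connection via the new StickyConnPool.
	pubsub, err := client.Subscribe("mychannel")
	if err != nil {
		panic(err)
	}
	defer pubsub.Close()

	// Crude, but good enough for a sketch: give the SUBSCRIBE command time
	// to reach the server before publishing on the other connection.
	time.Sleep(100 * time.Millisecond)

	if err := client.Publish("mychannel", "hello").Err(); err != nil {
		panic(err)
	}

	msg, err := pubsub.ReceiveMessage()
	if err != nil {
		panic(err)
	}
	fmt.Println(msg.Channel, msg.Payload)
}
```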

View File

@ -291,10 +291,10 @@ var _ = Describe("PubSub", func() {
expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn1.SetNetConn(&badConn{
cn1.NetConn = &badConn{
readErr: io.EOF,
writeErr: io.EOF,
})
}
done := make(chan bool, 1)
go func() {

redis.go
View File

@ -3,15 +3,26 @@ package redis // import "gopkg.in/redis.v3"
import (
"fmt"
"log"
"net"
"os"
"time"
"sync/atomic"
"gopkg.in/redis.v3/internal/pool"
)
var Logger = log.New(os.Stderr, "redis: ", log.LstdFlags)
// Deprecated. Use SetLogger instead.
var Logger *log.Logger
func init() {
SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags))
}
func SetLogger(logger *log.Logger) {
Logger = logger
pool.Logger = logger
}
type baseClient struct {
connPool pool
connPool pool.Pooler
opt *Options
onClose func() error // hook called when client is closed
@ -21,11 +32,11 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB)
}
func (c *baseClient) conn() (*conn, bool, error) {
func (c *baseClient) conn() (*pool.Conn, bool, error) {
return c.connPool.Get()
}
func (c *baseClient) putConn(cn *conn, err error, allowTimeout bool) bool {
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
if isBadConn(err, allowTimeout) {
err = c.connPool.Remove(cn, err)
if err != nil {
@ -61,7 +72,7 @@ func (c *baseClient) process(cmd Cmder) {
}
cn.WriteTimeout = c.opt.WriteTimeout
if err := cn.writeCmds(cmd); err != nil {
if err := writeCmd(cn, cmd); err != nil {
c.putConn(cn, err, false)
cmd.setErr(err)
if shouldRetry(err) {
@ -99,93 +110,6 @@ func (c *baseClient) Close() error {
//------------------------------------------------------------------------------
type Options struct {
// The network type, either tcp or unix.
// Default is tcp.
Network string
// host:port address.
Addr string
// Dialer creates new network connection and has priority over
// Network and Addr options.
Dialer func() (net.Conn, error)
// An optional password. Must match the password specified in the
// requirepass server configuration option.
Password string
// A database to be selected after connecting to server.
DB int64
// The maximum number of retries before giving up.
// Default is to not retry failed commands.
MaxRetries int
// Sets the deadline for establishing new connections. If reached,
// dial will fail with a timeout.
DialTimeout time.Duration
// Sets the deadline for socket reads. If reached, commands will
// fail with a timeout instead of blocking.
ReadTimeout time.Duration
// Sets the deadline for socket writes. If reached, commands will
// fail with a timeout instead of blocking.
WriteTimeout time.Duration
// The maximum number of socket connections.
// Default is 10 connections.
PoolSize int
// Specifies amount of time client waits for connection if all
// connections are busy before returning an error.
// Default is 1 seconds.
PoolTimeout time.Duration
// Specifies amount of time after which client closes idle
// connections. Should be less than server's timeout.
// Default is to not close idle connections.
IdleTimeout time.Duration
}
func (opt *Options) getNetwork() string {
if opt.Network == "" {
return "tcp"
}
return opt.Network
}
func (opt *Options) getDialer() func() (net.Conn, error) {
if opt.Dialer == nil {
opt.Dialer = func() (net.Conn, error) {
return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout())
}
}
return opt.Dialer
}
func (opt *Options) getPoolSize() int {
if opt.PoolSize == 0 {
return 10
}
return opt.PoolSize
}
func (opt *Options) getDialTimeout() time.Duration {
if opt.DialTimeout == 0 {
return 5 * time.Second
}
return opt.DialTimeout
}
func (opt *Options) getPoolTimeout() time.Duration {
if opt.PoolTimeout == 0 {
return 1 * time.Second
}
return opt.PoolTimeout
}
func (opt *Options) getIdleTimeout() time.Duration {
return opt.IdleTimeout
}
//------------------------------------------------------------------------------
// Client is a Redis client representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
@ -194,7 +118,7 @@ type Client struct {
commandable
}
func newClient(opt *Options, pool pool) *Client {
func newClient(opt *Options, pool pool.Pooler) *Client {
base := baseClient{opt: opt, connPool: pool}
return &Client{
baseClient: base,
@ -206,11 +130,19 @@ func newClient(opt *Options, pool pool) *Client {
// NewClient returns a client to the Redis Server specified by Options.
func NewClient(opt *Options) *Client {
pool := newConnPool(opt)
return newClient(opt, pool)
return newClient(opt, newConnPool(opt))
}
// PoolStats returns connection pool stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
return c.connPool.Stats()
s := c.connPool.Stats()
return &PoolStats{
Requests: atomic.LoadUint32(&s.Requests),
Hits: atomic.LoadUint32(&s.Hits),
Waits: atomic.LoadUint32(&s.Waits),
Timeouts: atomic.LoadUint32(&s.Timeouts),
TotalConns: atomic.LoadUint32(&s.TotalConns),
FreeConns: atomic.LoadUint32(&s.FreeConns),
}
}
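A sketch of the new SetLogger hook introduced above: it swaps the package-level Logger and propagates it to internal/pool, so pool warnings such as "pool.replace failed" and "connection has unread data" go to the same destination. The log file path is illustrative:

```go
package main

import (
	"log"
	"os"

	"gopkg.in/redis.v3"
)

func main() {
	f, err := os.OpenFile("redis.log", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Replaces both redis.Logger and pool.Logger in one call.
	redis.SetLogger(log.New(f, "redis: ", log.LstdFlags|log.Lshortfile))

	client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	defer client.Close()
	_ = client.Ping().Err() // any pool-level warnings now land in redis.log
}
```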

View File

@ -160,7 +160,7 @@ var _ = Describe("Client", func() {
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
cn.NetConn = &badConn{}
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())
@ -174,10 +174,6 @@ var _ = Describe("Client", func() {
Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt
future := time.Now().Add(time.Hour)
redis.SetTime(future)
defer redis.RestoreTime()
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt.Equal(createdAt)).To(BeTrue())
@ -187,6 +183,6 @@ var _ = Describe("Client", func() {
cn = client.Pool().First()
Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt.Equal(future)).To(BeTrue())
Expect(cn.UsedAt.After(createdAt)).To(BeTrue())
})
})

View File

@ -8,6 +8,7 @@ import (
"gopkg.in/redis.v3/internal/consistenthash"
"gopkg.in/redis.v3/internal/hashtag"
"gopkg.in/redis.v3/internal/pool"
)
var (
@ -200,7 +201,7 @@ func (ring *Ring) heartbeat() {
for _, shard := range ring.shards {
err := shard.Client.Ping().Err()
if shard.Vote(err == nil || err == errPoolTimeout) {
if shard.Vote(err == nil || err == pool.ErrPoolTimeout) {
Logger.Printf("ring shard state changed: %s", shard)
rebalance = true
}

View File

@ -7,6 +7,8 @@ import (
"strings"
"sync"
"time"
"gopkg.in/redis.v3/internal/pool"
)
//------------------------------------------------------------------------------
@ -103,7 +105,7 @@ func (c *sentinelClient) PubSub() *PubSub {
return &PubSub{
base: &baseClient{
opt: c.opt,
connPool: newStickyConnPool(c.connPool, false),
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false),
},
}
}
@ -126,7 +128,7 @@ type sentinelFailover struct {
opt *Options
pool pool
pool *pool.ConnPool
poolOnce sync.Once
mu sync.RWMutex
@ -145,7 +147,7 @@ func (d *sentinelFailover) dial() (net.Conn, error) {
return net.DialTimeout("tcp", addr, d.opt.DialTimeout)
}
func (d *sentinelFailover) Pool() pool {
func (d *sentinelFailover) Pool() *pool.ConnPool {
d.poolOnce.Do(func() {
d.opt.Dialer = d.dial
d.pool = newConnPool(d.opt)
@ -252,7 +254,7 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
// Good connections that should be put back to the pool. They
// can't be put immediately, because pool.First will return them
// again on next iteration.
cnsToPut := make([]*conn, 0)
cnsToPut := make([]*pool.Conn, 0)
for {
cn := d.pool.First()