Merge pull request #280 from go-redis/fix/big-vals-race-test

Add race test for big vals.
This commit is contained in:
Vladimir Mihailenco 2016-03-14 20:08:06 +02:00
commit 7f594cdbe1
25 changed files with 278 additions and 160 deletions

View File

@ -1,6 +1,6 @@
all: testdeps all: testdeps
go test ./... -test.v -test.cpu=1,2,4 go test ./... -test.cpu=1,2,4
go test ./... -test.v -test.short -test.race go test ./... -test.short -test.race
testdeps: testdata/redis/src/redis-server testdeps: testdata/redis/src/redis-server

View File

@ -7,6 +7,7 @@ import (
"time" "time"
"gopkg.in/redis.v3/internal/hashtag" "gopkg.in/redis.v3/internal/hashtag"
"gopkg.in/redis.v3/internal/pool"
) )
// ClusterClient is a Redis Cluster client representing a pool of zero // ClusterClient is a Redis Cluster client representing a pool of zero
@ -80,7 +81,7 @@ func (c *ClusterClient) Close() error {
c.clientsMx.Lock() c.clientsMx.Lock()
if c.closed { if c.closed {
return errClosed return pool.ErrClosed
} }
c.closed = true c.closed = true
c.resetClients() c.resetClients()
@ -105,7 +106,7 @@ func (c *ClusterClient) getClient(addr string) (*Client, error) {
c.clientsMx.Lock() c.clientsMx.Lock()
if c.closed { if c.closed {
c.clientsMx.Unlock() c.clientsMx.Unlock()
return nil, errClosed return nil, pool.ErrClosed
} }
client, ok = c.clients[addr] client, ok = c.clients[addr]

View File

@ -34,7 +34,7 @@ func (pipe *ClusterPipeline) process(cmd Cmder) {
// Discard resets the pipeline and discards queued commands. // Discard resets the pipeline and discards queued commands.
func (pipe *ClusterPipeline) Discard() error { func (pipe *ClusterPipeline) Discard() error {
if pipe.closed { if pipe.closed {
return errClosed return pool.ErrClosed
} }
pipe.cmds = pipe.cmds[:0] pipe.cmds = pipe.cmds[:0]
return nil return nil
@ -42,7 +42,7 @@ func (pipe *ClusterPipeline) Discard() error {
func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) {
if pipe.closed { if pipe.closed {
return nil, errClosed return nil, pool.ErrClosed
} }
if len(pipe.cmds) == 0 { if len(pipe.cmds) == 0 {
return []Cmder{}, nil return []Cmder{}, nil

View File

@ -139,7 +139,7 @@ func startCluster(scenario *clusterScenario) error {
return fmt.Errorf("cluster did not reach consistent state (%v)", res) return fmt.Errorf("cluster did not reach consistent state (%v)", res)
} }
return nil return nil
}, 10*time.Second) }, 30*time.Second)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,11 +5,13 @@ import (
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
"time"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"gopkg.in/redis.v3" "gopkg.in/redis.v3"
"gopkg.in/redis.v3/internal/pool"
) )
var _ = Describe("Command", func() { var _ = Describe("Command", func() {
@ -17,7 +19,8 @@ var _ = Describe("Command", func() {
connect := func() *redis.Client { connect := func() *redis.Client {
return redis.NewClient(&redis.Options{ return redis.NewClient(&redis.Options{
Addr: redisAddr, Addr: redisAddr,
PoolTimeout: time.Minute,
}) })
} }
@ -62,19 +65,19 @@ var _ = Describe("Command", func() {
}) })
It("should handle big vals", func() { It("should handle big vals", func() {
val := string(bytes.Repeat([]byte{'*'}, 1<<16)) bigVal := string(bytes.Repeat([]byte{'*'}, 1<<16))
set := client.Set("key", val, 0)
Expect(set.Err()).NotTo(HaveOccurred()) err := client.Set("key", bigVal, 0).Err()
Expect(set.Val()).To(Equal("OK")) Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection. // Reconnect to get new connection.
Expect(client.Close()).To(BeNil()) Expect(client.Close()).To(BeNil())
client = connect() client = connect()
get := client.Get("key") got, err := client.Get("key").Result()
Expect(get.Err()).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(len(get.Val())).To(Equal(len(val))) Expect(len(got)).To(Equal(len(bigVal)))
Expect(get.Val()).To(Equal(val)) Expect(got).To(Equal(bigVal))
}) })
It("should handle many keys #1", func() { It("should handle many keys #1", func() {
@ -136,52 +139,116 @@ var _ = Describe("Command", func() {
Describe("races", func() { Describe("races", func() {
var C, N = 10, 1000 var C, N = 10, 1000
if testing.Short() { if testing.Short() {
C = 3
N = 100 N = 100
} }
It("should echo", func() { It("should echo", func() {
wg := &sync.WaitGroup{} perform(C, func() {
for i := 0; i < C; i++ { for i := 0; i < N; i++ {
wg.Add(1) msg := "echo" + strconv.Itoa(i)
echo, err := client.Echo(msg).Result()
go func(i int) { Expect(err).NotTo(HaveOccurred())
defer GinkgoRecover() Expect(echo).To(Equal(msg))
defer wg.Done() }
})
for j := 0; j < N; j++ {
msg := "echo" + strconv.Itoa(i)
echo := client.Echo(msg)
Expect(echo.Err()).NotTo(HaveOccurred())
Expect(echo.Val()).To(Equal(msg))
}
}(i)
}
wg.Wait()
}) })
It("should incr", func() { It("should incr", func() {
key := "TestIncrFromGoroutines" key := "TestIncrFromGoroutines"
wg := &sync.WaitGroup{}
for i := 0; i < C; i++ {
wg.Add(1)
go func() { perform(C, func() {
defer GinkgoRecover() for i := 0; i < N; i++ {
defer wg.Done() err := client.Incr(key).Err()
Expect(err).NotTo(HaveOccurred())
for j := 0; j < N; j++ { }
err := client.Incr(key).Err() })
Expect(err).NotTo(HaveOccurred())
}
}()
}
wg.Wait()
val, err := client.Get(key).Int64() val, err := client.Get(key).Int64()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N))) Expect(val).To(Equal(int64(C * N)))
}) })
It("should handle big vals", func() {
client2 := connect()
defer client2.Close()
bigVal := string(bytes.Repeat([]byte{'*'}, 1<<16))
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
perform(C, func() {
for i := 0; i < N; i++ {
got, err := client.Get("key").Result()
if err == redis.Nil {
continue
}
Expect(got).To(Equal(bigVal))
}
})
}()
go func() {
defer wg.Done()
perform(C, func() {
for i := 0; i < N; i++ {
err := client2.Set("key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
}
})
}()
wg.Wait()
})
It("should PubSub", func() {
connPool := client.Pool()
connPool.(*pool.ConnPool).DialLimiter = nil
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
perform(C, func() {
for i := 0; i < N; i++ {
pubsub, err := client.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred())
go func() {
defer GinkgoRecover()
time.Sleep(time.Millisecond)
err := pubsub.Close()
Expect(err).NotTo(HaveOccurred())
}()
_, err = pubsub.ReceiveMessage()
Expect(err.Error()).To(ContainSubstring("closed"))
}
})
}()
go func() {
defer wg.Done()
perform(C, func() {
for i := 0; i < N; i++ {
val := "echo" + strconv.Itoa(i)
echo, err := client.Echo(val).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(val))
}
})
}()
wg.Wait()
Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(connPool.Len()).To(BeNumerically("<=", 10))
})
}) })
}) })

View File

@ -57,16 +57,20 @@ var _ = Describe("Commands", func() {
}) })
It("should BgRewriteAOF", func() { It("should BgRewriteAOF", func() {
r := client.BgRewriteAOF() Skip("flaky test")
Expect(r.Err()).NotTo(HaveOccurred())
Expect(r.Val()).To(ContainSubstring("Background append only file rewriting")) val, err := client.BgRewriteAOF().Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(ContainSubstring("Background append only file rewriting"))
}) })
It("should BgSave", func() { It("should BgSave", func() {
Skip("flaky test")
// workaround for "ERR Can't BGSAVE while AOF log rewriting is in progress" // workaround for "ERR Can't BGSAVE while AOF log rewriting is in progress"
Eventually(func() string { Eventually(func() string {
return client.BgSave().Val() return client.BgSave().Val()
}, "10s").Should(Equal("Background saving started")) }, "30s").Should(Equal("Background saving started"))
}) })
It("should ClientKill", func() { It("should ClientKill", func() {

View File

@ -1,15 +1,12 @@
package redis package redis
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"strings" "strings"
) )
var errClosed = errors.New("redis: client is closed")
// Redis nil reply, .e.g. when key does not exist. // Redis nil reply, .e.g. when key does not exist.
var Nil = errorf("redis: nil") var Nil = errorf("redis: nil")

View File

@ -13,7 +13,8 @@ var client *redis.Client
func init() { func init() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(&redis.Options{
Addr: ":6379", Addr: ":6379",
DialTimeout: 10 * time.Second,
}) })
client.FlushDb() client.FlushDb()
} }
@ -220,13 +221,13 @@ func ExampleClient_Watch() {
} }
func ExamplePubSub() { func ExamplePubSub() {
pubsub, err := client.Subscribe("mychannel") pubsub, err := client.Subscribe("mychannel1")
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer pubsub.Close() defer pubsub.Close()
err = client.Publish("mychannel", "hello").Err() err = client.Publish("mychannel1", "hello").Err()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -237,40 +238,42 @@ func ExamplePubSub() {
} }
fmt.Println(msg.Channel, msg.Payload) fmt.Println(msg.Channel, msg.Payload)
// Output: mychannel hello // Output: mychannel1 hello
} }
func ExamplePubSub_Receive() { func ExamplePubSub_Receive() {
pubsub, err := client.Subscribe("mychannel") pubsub, err := client.Subscribe("mychannel2")
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer pubsub.Close() defer pubsub.Close()
err = client.Publish("mychannel", "hello").Err() n, err := client.Publish("mychannel2", "hello").Result()
if err != nil { if err != nil {
panic(err) panic(err)
} }
fmt.Println(n, "clients received message")
for i := 0; i < 2; i++ { for {
// ReceiveTimeout is a low level API. Use ReceiveMessage instead. // ReceiveTimeout is a low level API. Use ReceiveMessage instead.
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(time.Second)
if err != nil { if err != nil {
panic(err) break
} }
switch msg := msgi.(type) { switch msg := msgi.(type) {
case *redis.Subscription: case *redis.Subscription:
fmt.Println(msg.Kind, msg.Channel) fmt.Println("subscribed to", msg.Channel)
case *redis.Message: case *redis.Message:
fmt.Println(msg.Channel, msg.Payload) fmt.Println("received", msg.Payload, "from", msg.Channel)
default: default:
panic(fmt.Sprintf("unknown message: %#v", msgi)) panic(fmt.Errorf("unknown message: %#v", msgi))
} }
} }
// Output: subscribe mychannel // Output: 1 clients received message
// mychannel hello // subscribed to mychannel2
// received hello from mychannel2
} }
func ExampleScript() { func ExampleScript() {

View File

@ -3,6 +3,7 @@ package pool
import ( import (
"bufio" "bufio"
"net" "net"
"sync/atomic"
"time" "time"
) )
@ -11,9 +12,9 @@ const defaultBufSize = 4096
var noDeadline = time.Time{} var noDeadline = time.Time{}
type Conn struct { type Conn struct {
idx int idx int32
netConn net.Conn NetConn net.Conn
Rd *bufio.Reader Rd *bufio.Reader
Buf []byte Buf []byte
@ -26,7 +27,7 @@ func NewConn(netConn net.Conn) *Conn {
cn := &Conn{ cn := &Conn{
idx: -1, idx: -1,
netConn: netConn, NetConn: netConn,
Buf: make([]byte, defaultBufSize), Buf: make([]byte, defaultBufSize),
UsedAt: time.Now(), UsedAt: time.Now(),
@ -35,39 +36,47 @@ func NewConn(netConn net.Conn) *Conn {
return cn return cn
} }
func (cn *Conn) IsStale(timeout time.Duration) bool { func (cn *Conn) Index() int {
return timeout > 0 && time.Since(cn.UsedAt) > timeout return int(atomic.LoadInt32(&cn.idx))
} }
func (cn *Conn) SetNetConn(netConn net.Conn) { func (cn *Conn) SetIndex(idx int) {
cn.netConn = netConn atomic.StoreInt32(&cn.idx, int32(idx))
cn.UsedAt = time.Now() }
func (cn *Conn) IsStale(timeout time.Duration) bool {
return timeout > 0 && time.Since(cn.UsedAt) > timeout
} }
func (cn *Conn) Read(b []byte) (int, error) { func (cn *Conn) Read(b []byte) (int, error) {
cn.UsedAt = time.Now() cn.UsedAt = time.Now()
if cn.ReadTimeout != 0 { if cn.ReadTimeout != 0 {
cn.netConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout)) cn.NetConn.SetReadDeadline(cn.UsedAt.Add(cn.ReadTimeout))
} else { } else {
cn.netConn.SetReadDeadline(noDeadline) cn.NetConn.SetReadDeadline(noDeadline)
} }
return cn.netConn.Read(b) return cn.NetConn.Read(b)
} }
func (cn *Conn) Write(b []byte) (int, error) { func (cn *Conn) Write(b []byte) (int, error) {
cn.UsedAt = time.Now() cn.UsedAt = time.Now()
if cn.WriteTimeout != 0 { if cn.WriteTimeout != 0 {
cn.netConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout)) cn.NetConn.SetWriteDeadline(cn.UsedAt.Add(cn.WriteTimeout))
} else { } else {
cn.netConn.SetWriteDeadline(noDeadline) cn.NetConn.SetWriteDeadline(noDeadline)
} }
return cn.netConn.Write(b) return cn.NetConn.Write(b)
} }
func (cn *Conn) RemoteAddr() net.Addr { func (cn *Conn) RemoteAddr() net.Addr {
return cn.netConn.RemoteAddr() return cn.NetConn.RemoteAddr()
} }
func (cn *Conn) Close() error { func (cn *Conn) Close() int {
return cn.netConn.Close() idx := cn.Index()
if !atomic.CompareAndSwapInt32(&cn.idx, int32(idx), -1) {
return -1
}
_ = cn.NetConn.Close()
return idx
} }

View File

@ -24,7 +24,7 @@ func (l *connList) Len() int {
} }
// Reserve reserves place in the list and returns true on success. // Reserve reserves place in the list and returns true on success.
// The caller must add or remove connection if place was reserved. // The caller must add connection or cancel reservation if it was reserved.
func (l *connList) Reserve() bool { func (l *connList) Reserve() bool {
len := atomic.AddInt32(&l.len, 1) len := atomic.AddInt32(&l.len, 1)
reserved := len <= l.size reserved := len <= l.size
@ -34,12 +34,16 @@ func (l *connList) Reserve() bool {
return reserved return reserved
} }
func (l *connList) CancelReservation() {
atomic.AddInt32(&l.len, -1)
}
// Add adds connection to the list. The caller must reserve place first. // Add adds connection to the list. The caller must reserve place first.
func (l *connList) Add(cn *Conn) { func (l *connList) Add(cn *Conn) {
l.mu.Lock() l.mu.Lock()
for i, c := range l.cns { for i, c := range l.cns {
if c == nil { if c == nil {
cn.idx = i cn.SetIndex(i)
l.cns[i] = cn l.cns[i] = cn
l.mu.Unlock() l.mu.Unlock()
return return
@ -48,37 +52,34 @@ func (l *connList) Add(cn *Conn) {
panic("not reached") panic("not reached")
} }
// Remove closes connection and removes it from the list. func (l *connList) Replace(cn *Conn) {
func (l *connList) Remove(cn *Conn) error {
atomic.AddInt32(&l.len, -1)
if cn == nil { // free reserved place
return nil
}
l.mu.Lock() l.mu.Lock()
if l.cns != nil { if l.cns != nil {
l.cns[cn.idx] = nil l.cns[cn.idx] = cn
cn.idx = -1
} }
l.mu.Unlock() l.mu.Unlock()
}
return nil // Remove closes connection and removes it from the list.
func (l *connList) Remove(idx int) {
l.mu.Lock()
if l.cns != nil {
l.cns[idx] = nil
l.len -= 1
}
l.mu.Unlock()
} }
func (l *connList) Close() error { func (l *connList) Close() error {
var retErr error
l.mu.Lock() l.mu.Lock()
for _, c := range l.cns { for _, c := range l.cns {
if c == nil { if c == nil {
continue continue
} }
if err := c.Close(); err != nil && retErr == nil { c.Close()
retErr = err
}
} }
l.cns = nil l.cns = nil
atomic.StoreInt32(&l.len, 0) l.len = 0
l.mu.Unlock() l.mu.Unlock()
return retErr return nil
} }

View File

@ -14,7 +14,8 @@ import (
var Logger *log.Logger var Logger *log.Logger
var ( var (
errClosed = errors.New("redis: client is closed") ErrClosed = errors.New("redis: client is closed")
errConnClosed = errors.New("redis: connection is closed")
ErrPoolTimeout = errors.New("redis: connection pool timeout") ErrPoolTimeout = errors.New("redis: connection pool timeout")
) )
@ -36,8 +37,9 @@ type Pooler interface {
Replace(*Conn, error) error Replace(*Conn, error) error
Len() int Len() int
FreeLen() int FreeLen() int
Close() error
Stats() *PoolStats Stats() *PoolStats
Close() error
Closed() bool
} }
type dialer func() (net.Conn, error) type dialer func() (net.Conn, error)
@ -58,6 +60,8 @@ type ConnPool struct {
lastErr atomic.Value lastErr atomic.Value
} }
var _ Pooler = (*ConnPool)(nil)
func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool { func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Duration) *ConnPool {
p := &ConnPool{ p := &ConnPool{
_dial: dial, _dial: dial,
@ -75,7 +79,7 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Durati
return p return p
} }
func (p *ConnPool) closed() bool { func (p *ConnPool) Closed() bool {
return atomic.LoadInt32(&p._closed) == 1 return atomic.LoadInt32(&p._closed) == 1
} }
@ -152,8 +156,8 @@ func (p *ConnPool) newConn() (*Conn, error) {
// Get returns existed connection from the pool or creates a new one. // Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) { func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) {
if p.closed() { if p.Closed() {
err = errClosed err = ErrClosed
return return
} }
@ -171,7 +175,7 @@ func (p *ConnPool) Get() (cn *Conn, isNew bool, err error) {
cn, err = p.newConn() cn, err = p.newConn()
if err != nil { if err != nil {
p.conns.Remove(nil) p.conns.CancelReservation()
return return
} }
p.conns.Add(cn) p.conns.Add(cn)
@ -201,14 +205,20 @@ func (p *ConnPool) Put(cn *Conn) error {
} }
func (p *ConnPool) replace(cn *Conn) (*Conn, error) { func (p *ConnPool) replace(cn *Conn) (*Conn, error) {
_ = cn.Close() idx := cn.Close()
if idx == -1 {
return nil, errConnClosed
}
netConn, err := p.dial() netConn, err := p.dial()
if err != nil { if err != nil {
_ = p.conns.Remove(cn) p.conns.Remove(idx)
return nil, err return nil, err
} }
cn.SetNetConn(netConn)
cn = NewConn(netConn)
cn.SetIndex(idx)
p.conns.Replace(cn)
return cn, nil return cn, nil
} }
@ -226,9 +236,14 @@ func (p *ConnPool) Replace(cn *Conn, reason error) error {
} }
func (p *ConnPool) Remove(cn *Conn, reason error) error { func (p *ConnPool) Remove(cn *Conn, reason error) error {
idx := cn.Close()
if idx == -1 {
return errConnClosed
}
p.storeLastErr(reason.Error()) p.storeLastErr(reason.Error())
_ = cn.Close() p.conns.Remove(idx)
return p.conns.Remove(cn) return nil
} }
// Len returns total number of connections. // Len returns total number of connections.
@ -253,7 +268,7 @@ func (p *ConnPool) Stats() *PoolStats {
func (p *ConnPool) Close() (retErr error) { func (p *ConnPool) Close() (retErr error) {
if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) {
return errClosed return ErrClosed
} }
// Wait for app to free connections, but don't close them immediately. // Wait for app to free connections, but don't close them immediately.
for i := 0; i < p.Len(); i++ { for i := 0; i < p.Len(); i++ {
@ -287,7 +302,7 @@ func (p *ConnPool) reaper() {
defer ticker.Stop() defer ticker.Stop()
for _ = range ticker.C { for _ = range ticker.C {
if p.closed() { if p.Closed() {
break break
} }
n, err := p.ReapStaleConns() n, err := p.ReapStaleConns()

View File

@ -4,6 +4,8 @@ type SingleConnPool struct {
cn *Conn cn *Conn
} }
var _ Pooler = (*SingleConnPool)(nil)
func NewSingleConnPool(cn *Conn) *SingleConnPool { func NewSingleConnPool(cn *Conn) *SingleConnPool {
return &SingleConnPool{ return &SingleConnPool{
cn: cn, cn: cn,
@ -40,8 +42,14 @@ func (p *SingleConnPool) FreeLen() int {
return 0 return 0
} }
func (p *SingleConnPool) Stats() *PoolStats { return nil } func (p *SingleConnPool) Stats() *PoolStats {
return nil
}
func (p *SingleConnPool) Close() error { func (p *SingleConnPool) Close() error {
return nil return nil
} }
func (p *SingleConnPool) Closed() bool {
return false
}

View File

@ -14,6 +14,8 @@ type StickyConnPool struct {
mx sync.Mutex mx sync.Mutex
} }
var _ Pooler = (*StickyConnPool)(nil)
func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool { func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool {
return &StickyConnPool{ return &StickyConnPool{
pool: pool, pool: pool,
@ -33,7 +35,7 @@ func (p *StickyConnPool) Get() (cn *Conn, isNew bool, err error) {
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
err = errClosed err = ErrClosed
return return
} }
if p.cn != nil { if p.cn != nil {
@ -59,7 +61,7 @@ func (p *StickyConnPool) Put(cn *Conn) error {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
return errClosed return ErrClosed
} }
if p.cn != cn { if p.cn != cn {
panic("p.cn != cn") panic("p.cn != cn")
@ -77,7 +79,7 @@ func (p *StickyConnPool) Replace(cn *Conn, reason error) error {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
return errClosed return nil
} }
if p.cn == nil { if p.cn == nil {
panic("p.cn == nil") panic("p.cn == nil")
@ -112,7 +114,7 @@ func (p *StickyConnPool) Close() error {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
return errClosed return ErrClosed
} }
p.closed = true p.closed = true
var err error var err error
@ -126,3 +128,10 @@ func (p *StickyConnPool) Close() error {
} }
return err return err
} }
func (p *StickyConnPool) Closed() bool {
p.mx.Lock()
closed := p.closed
p.mx.Unlock()
return closed
}

View File

@ -31,12 +31,14 @@ var _ = Describe("conns reapser", func() {
cn := pool.NewConn(&net.TCPConn{}) cn := pool.NewConn(&net.TCPConn{})
cn.UsedAt = time.Now().Add(-2 * time.Minute) cn.UsedAt = time.Now().Add(-2 * time.Minute)
Expect(connPool.Add(cn)).To(BeTrue()) Expect(connPool.Add(cn)).To(BeTrue())
Expect(cn.Index()).To(Equal(i))
} }
// add fresh connections // add fresh connections
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn := pool.NewConn(&net.TCPConn{}) cn := pool.NewConn(&net.TCPConn{})
Expect(connPool.Add(cn)).To(BeTrue()) Expect(connPool.Add(cn)).To(BeTrue())
Expect(cn.Index()).To(Equal(3 + i))
} }
Expect(connPool.Len()).To(Equal(6)) Expect(connPool.Len()).To(Equal(6))

View File

@ -6,6 +6,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -98,6 +99,20 @@ func TestGinkgoSuite(t *testing.T) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func perform(n int, cb func()) {
var wg sync.WaitGroup
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer GinkgoRecover()
defer wg.Done()
cb()
}()
}
wg.Wait()
}
func eventually(fn func() error, timeout time.Duration) error { func eventually(fn func() error, timeout time.Duration) error {
done := make(chan struct{}) done := make(chan struct{})
var exit int32 var exit int32
@ -138,7 +153,7 @@ func connectTo(port string) (*redis.Client, error) {
err := eventually(func() error { err := eventually(func() error {
return client.Ping().Err() return client.Ping().Err()
}, 10*time.Second) }, 30*time.Second)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -109,7 +109,7 @@ func (c *Multi) Discard() error {
// failed command or nil. // failed command or nil.
func (c *Multi) Exec(f func() error) ([]Cmder, error) { func (c *Multi) Exec(f func() error) ([]Cmder, error) {
if c.closed { if c.closed {
return nil, errClosed return nil, pool.ErrClosed
} }
c.cmds = []Cmder{NewStatusCmd("MULTI")} c.cmds = []Cmder{NewStatusCmd("MULTI")}

View File

@ -145,7 +145,7 @@ var _ = Describe("Multi", func() {
cn, _, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.NetConn = &badConn{}
err = client.Pool().Put(cn) err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -172,7 +172,7 @@ var _ = Describe("Multi", func() {
cn, _, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.NetConn = &badConn{}
err = client.Pool().Put(cn) err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())

View File

@ -30,6 +30,7 @@ type Options struct {
// Sets the deadline for establishing new connections. If reached, // Sets the deadline for establishing new connections. If reached,
// dial will fail with a timeout. // dial will fail with a timeout.
// Default is 5 seconds.
DialTimeout time.Duration DialTimeout time.Duration
// Sets the deadline for socket reads. If reached, commands will // Sets the deadline for socket reads. If reached, commands will
// fail with a timeout instead of blocking. // fail with a timeout instead of blocking.
@ -43,7 +44,7 @@ type Options struct {
PoolSize int PoolSize int
// Specifies amount of time client waits for connection if all // Specifies amount of time client waits for connection if all
// connections are busy before returning an error. // connections are busy before returning an error.
// Default is 1 seconds. // Default is 1 second.
PoolTimeout time.Duration PoolTimeout time.Duration
// Specifies amount of time after which client closes idle // Specifies amount of time after which client closes idle
// connections. Should be less than server's timeout. // connections. Should be less than server's timeout.

View File

@ -62,7 +62,7 @@ func (pipe *Pipeline) Discard() error {
defer pipe.mu.Unlock() defer pipe.mu.Unlock()
pipe.mu.Lock() pipe.mu.Lock()
if pipe.isClosed() { if pipe.isClosed() {
return errClosed return pool.ErrClosed
} }
pipe.cmds = pipe.cmds[:0] pipe.cmds = pipe.cmds[:0]
return nil return nil
@ -75,7 +75,7 @@ func (pipe *Pipeline) Discard() error {
// command if any. // command if any.
func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) {
if pipe.isClosed() { if pipe.isClosed() {
return nil, errClosed return nil, pool.ErrClosed
} }
defer pipe.mu.Unlock() defer pipe.mu.Unlock()

View File

@ -2,7 +2,6 @@ package redis_test
import ( import (
"errors" "errors"
"sync"
"time" "time"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -14,20 +13,6 @@ import (
var _ = Describe("pool", func() { var _ = Describe("pool", func() {
var client *redis.Client var client *redis.Client
var perform = func(n int, cb func()) {
wg := &sync.WaitGroup{}
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer GinkgoRecover()
defer wg.Done()
cb()
}()
}
wg.Wait()
}
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(&redis.Options{
Addr: redisAddr, Addr: redisAddr,
@ -108,12 +93,11 @@ var _ = Describe("pool", func() {
It("should remove broken connections", func() { It("should remove broken connections", func() {
cn, _, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn.Close()).NotTo(HaveOccurred()) cn.NetConn = &badConn{}
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
err = client.Ping().Err() err = client.Ping().Err()
Expect(err).To(HaveOccurred()) Expect(err).To(MatchError("bad connection"))
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
val, err := client.Ping().Result() val, err := client.Ping().Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())

View File

@ -54,6 +54,7 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
if err != nil { if err != nil {
return err return err
} }
c.putConn(cn, err)
args := make([]interface{}, 1+len(channels)) args := make([]interface{}, 1+len(channels))
args[0] = redisCmd args[0] = redisCmd
@ -306,6 +307,9 @@ func (c *PubSub) putConn(cn *pool.Conn, err error) {
} }
func (c *PubSub) resubscribe() { func (c *PubSub) resubscribe() {
if c.base.closed() {
return
}
if len(c.channels) > 0 { if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil { if err := c.Subscribe(c.channels...); err != nil {
Logger.Printf("Subscribe failed: %s", err) Logger.Printf("Subscribe failed: %s", err)

View File

@ -291,10 +291,10 @@ var _ = Describe("PubSub", func() {
expectReceiveMessageOnError := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, _, err := pubsub.Pool().Get() cn1, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn1.SetNetConn(&badConn{ cn1.NetConn = &badConn{
readErr: io.EOF, readErr: io.EOF,
writeErr: io.EOF, writeErr: io.EOF,
}) }
done := make(chan bool, 1) done := make(chan bool, 1)
go func() { go func() {

View File

@ -45,17 +45,11 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) {
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
if isBadConn(err, allowTimeout) { if isBadConn(err, allowTimeout) {
err = c.connPool.Replace(cn, err) _ = c.connPool.Replace(cn, err)
if err != nil {
Logger.Printf("pool.Remove failed: %s", err)
}
return false return false
} }
err = c.connPool.Put(cn) _ = c.connPool.Put(cn)
if err != nil {
Logger.Printf("pool.Put failed: %s", err)
}
return true return true
} }
@ -121,6 +115,10 @@ func (c *baseClient) process(cmd Cmder) {
} }
} }
func (c *baseClient) closed() bool {
return c.connPool.Closed()
}
// Close closes the client, releasing any open resources. // Close closes the client, releasing any open resources.
// //
// It is rare to Close a Client, as the Client is meant to be // It is rare to Close a Client, as the Client is meant to be

View File

@ -160,7 +160,7 @@ var _ = Describe("Client", func() {
cn, _, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{}) cn.NetConn = &badConn{}
err = client.Pool().Put(cn) err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())

View File

@ -149,7 +149,7 @@ func (ring *Ring) getClient(key string) (*Client, error) {
ring.mx.RLock() ring.mx.RLock()
if ring.closed { if ring.closed {
return nil, errClosed return nil, pool.ErrClosed
} }
name := ring.hash.Get(hashtag.Key(key)) name := ring.hash.Get(hashtag.Key(key))
@ -277,7 +277,7 @@ func (pipe *RingPipeline) process(cmd Cmder) {
// Discard resets the pipeline and discards queued commands. // Discard resets the pipeline and discards queued commands.
func (pipe *RingPipeline) Discard() error { func (pipe *RingPipeline) Discard() error {
if pipe.closed { if pipe.closed {
return errClosed return pool.ErrClosed
} }
pipe.cmds = pipe.cmds[:0] pipe.cmds = pipe.cmds[:0]
return nil return nil
@ -287,7 +287,7 @@ func (pipe *RingPipeline) Discard() error {
// command if any. // command if any.
func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) {
if pipe.closed { if pipe.closed {
return nil, errClosed return nil, pool.ErrClosed
} }
if len(pipe.cmds) == 0 { if len(pipe.cmds) == 0 {
return pipe.cmds, nil return pipe.cmds, nil