diff --git a/Makefile b/Makefile index b7867b48..9ee35b2c 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ all: testdeps - go test ./... -test.cpu=1,2,4 - go test ./... -test.short -test.race + go test ./... + go test ./... -short -race testdeps: testdata/redis/src/redis-server diff --git a/cluster_test.go b/cluster_test.go index 392a898d..7423a7eb 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -66,7 +66,10 @@ func startCluster(scenario *clusterScenario) error { return err } - client := redis.NewClient(&redis.Options{Addr: "127.0.0.1:" + port}) + client := redis.NewClient(&redis.Options{ + Addr: ":" + port, + }) + info, err := client.ClusterNodes().Result() if err != nil { return err diff --git a/command_test.go b/command_test.go index e7ebc602..2b218acf 100644 --- a/command_test.go +++ b/command_test.go @@ -1,35 +1,21 @@ package redis_test import ( - "bytes" - "strconv" - "sync" - "testing" - "time" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "gopkg.in/redis.v3" - "gopkg.in/redis.v3/internal/pool" ) var _ = Describe("Command", func() { var client *redis.Client - connect := func() *redis.Client { - return redis.NewClient(&redis.Options{ - Addr: redisAddr, - PoolTimeout: time.Minute, - }) - } - BeforeEach(func() { - client = connect() + client = redis.NewClient(redisOptions()) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) AfterEach(func() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred()) }) @@ -54,64 +40,6 @@ var _ = Describe("Command", func() { Expect(set.Val()).To(Equal("OK")) }) - It("should escape special chars", func() { - set := client.Set("key", "hello1\r\nhello2\r\n", 0) - Expect(set.Err()).NotTo(HaveOccurred()) - Expect(set.Val()).To(Equal("OK")) - - get := client.Get("key") - Expect(get.Err()).NotTo(HaveOccurred()) - Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n")) - }) - - It("should handle big vals", func() { - bigVal := string(bytes.Repeat([]byte{'*'}, 1<<16)) - - err := client.Set("key", bigVal, 0).Err() - Expect(err).NotTo(HaveOccurred()) - - // Reconnect to get new connection. - Expect(client.Close()).To(BeNil()) - client = connect() - - got, err := client.Get("key").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(len(got)).To(Equal(len(bigVal))) - Expect(got).To(Equal(bigVal)) - }) - - It("should handle many keys #1", func() { - const n = 100000 - for i := 0; i < n; i++ { - client.Set("keys.key"+strconv.Itoa(i), "hello"+strconv.Itoa(i), 0) - } - keys := client.Keys("keys.*") - Expect(keys.Err()).NotTo(HaveOccurred()) - Expect(len(keys.Val())).To(Equal(n)) - }) - - It("should handle many keys #2", func() { - const n = 100000 - - keys := []string{"non-existent-key"} - for i := 0; i < n; i++ { - key := "keys.key" + strconv.Itoa(i) - client.Set(key, "hello"+strconv.Itoa(i), 0) - keys = append(keys, key) - } - keys = append(keys, "non-existent-key") - - mget := client.MGet(keys...) - Expect(mget.Err()).NotTo(HaveOccurred()) - Expect(len(mget.Val())).To(Equal(n + 2)) - vals := mget.Val() - for i := 0; i < n; i++ { - Expect(vals[i+1]).To(Equal("hello" + strconv.Itoa(i))) - } - Expect(vals[0]).To(BeNil()) - Expect(vals[n+1]).To(BeNil()) - }) - It("should convert strings via helpers", func() { set := client.Set("key", "10", 0) Expect(set.Err()).NotTo(HaveOccurred()) @@ -129,126 +57,4 @@ var _ = Describe("Command", func() { Expect(f).To(Equal(float64(10))) }) - It("Cmd should return string", func() { - cmd := redis.NewCmd("PING") - client.Process(cmd) - Expect(cmd.Err()).NotTo(HaveOccurred()) - Expect(cmd.Val()).To(Equal("PONG")) - }) - - Describe("races", func() { - var C, N = 10, 1000 - if testing.Short() { - C = 3 - N = 100 - } - - It("should echo", func() { - perform(C, func() { - for i := 0; i < N; i++ { - msg := "echo" + strconv.Itoa(i) - echo, err := client.Echo(msg).Result() - Expect(err).NotTo(HaveOccurred()) - Expect(echo).To(Equal(msg)) - } - }) - }) - - It("should incr", func() { - key := "TestIncrFromGoroutines" - - perform(C, func() { - for i := 0; i < N; i++ { - err := client.Incr(key).Err() - Expect(err).NotTo(HaveOccurred()) - } - }) - - val, err := client.Get(key).Int64() - Expect(err).NotTo(HaveOccurred()) - Expect(val).To(Equal(int64(C * N))) - }) - - It("should handle big vals", func() { - client2 := connect() - defer client2.Close() - - bigVal := string(bytes.Repeat([]byte{'*'}, 1<<16)) - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - perform(C, func() { - for i := 0; i < N; i++ { - got, err := client.Get("key").Result() - if err == redis.Nil { - continue - } - Expect(got).To(Equal(bigVal)) - } - }) - }() - - go func() { - defer wg.Done() - perform(C, func() { - for i := 0; i < N; i++ { - err := client2.Set("key", bigVal, 0).Err() - Expect(err).NotTo(HaveOccurred()) - } - }) - }() - - wg.Wait() - }) - - It("should PubSub", func() { - connPool := client.Pool() - connPool.(*pool.ConnPool).DialLimiter = nil - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - perform(C, func() { - for i := 0; i < N; i++ { - pubsub, err := client.Subscribe("mychannel") - Expect(err).NotTo(HaveOccurred()) - - go func() { - defer GinkgoRecover() - - time.Sleep(time.Millisecond) - err := pubsub.Close() - Expect(err).NotTo(HaveOccurred()) - }() - - _, err = pubsub.ReceiveMessage() - Expect(err.Error()).To(ContainSubstring("closed")) - } - }) - }() - - go func() { - defer wg.Done() - perform(C, func() { - for i := 0; i < N; i++ { - val := "echo" + strconv.Itoa(i) - echo, err := client.Echo(val).Result() - Expect(err).NotTo(HaveOccurred()) - Expect(echo).To(Equal(val)) - } - }) - }() - - wg.Wait() - - Expect(connPool.Len()).To(Equal(connPool.FreeLen())) - Expect(connPool.Len()).To(BeNumerically("<=", 10)) - }) - }) - }) diff --git a/commands_test.go b/commands_test.go index 28139676..6e2e6c19 100644 --- a/commands_test.go +++ b/commands_test.go @@ -19,14 +19,11 @@ var _ = Describe("Commands", func() { var client *redis.Client BeforeEach(func() { - client = redis.NewClient(&redis.Options{ - Addr: redisAddr, - PoolTimeout: 30 * time.Second, - }) + client = redis.NewClient(redisOptions()) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) AfterEach(func() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred()) }) @@ -299,7 +296,7 @@ var _ = Describe("Commands", func() { }) It("should Move", func() { - move := client.Move("key", 1) + move := client.Move("key", 2) Expect(move.Err()).NotTo(HaveOccurred()) Expect(move.Val()).To(Equal(false)) @@ -307,7 +304,7 @@ var _ = Describe("Commands", func() { Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Val()).To(Equal("OK")) - move = client.Move("key", 1) + move = client.Move("key", 2) Expect(move.Err()).NotTo(HaveOccurred()) Expect(move.Val()).To(Equal(true)) @@ -315,7 +312,7 @@ var _ = Describe("Commands", func() { Expect(get.Err()).To(Equal(redis.Nil)) Expect(get.Val()).To(Equal("")) - sel := client.Select(1) + sel := client.Select(2) Expect(sel.Err()).NotTo(HaveOccurred()) Expect(sel.Val()).To(Equal("OK")) @@ -323,7 +320,7 @@ var _ = Describe("Commands", func() { Expect(get.Err()).NotTo(HaveOccurred()) Expect(get.Val()).To(Equal("hello")) Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) - Expect(client.Select(0).Err()).NotTo(HaveOccurred()) + Expect(client.Select(1).Err()).NotTo(HaveOccurred()) }) It("should Object", func() { diff --git a/example_test.go b/example_test.go index e5cb5b28..dc4e1bde 100644 --- a/example_test.go +++ b/example_test.go @@ -12,10 +12,9 @@ import ( var client *redis.Client func init() { - client = redis.NewClient(&redis.Options{ - Addr: ":6379", - DialTimeout: 10 * time.Second, - }) + opt := redisOptions() + opt.Addr = ":6379" + client = redis.NewClient(opt) client.FlushDb() } diff --git a/internal/pool/conn.go b/internal/pool/conn.go index c3768862..0fbb4193 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -2,6 +2,7 @@ package pool import ( "bufio" + "io" "net" "sync/atomic" "time" @@ -78,6 +79,17 @@ func (cn *Conn) RemoteAddr() net.Addr { return cn.NetConn.RemoteAddr() } +func (cn *Conn) ReadN(n int) ([]byte, error) { + if d := n - cap(cn.Buf); d > 0 { + cn.Buf = cn.Buf[:cap(cn.Buf)] + cn.Buf = append(cn.Buf, make([]byte, d)...) + } else { + cn.Buf = cn.Buf[:n] + } + _, err := io.ReadFull(cn.Rd, cn.Buf) + return cn.Buf, err +} + func (cn *Conn) Close() error { return cn.NetConn.Close() } diff --git a/internal/pool/conn_list.go b/internal/pool/conn_list.go index b3f58704..61bf99ba 100644 --- a/internal/pool/conn_list.go +++ b/internal/pool/conn_list.go @@ -43,7 +43,7 @@ func (l *connList) Add(cn *Conn) { l.mu.Lock() for i, c := range l.cns { if c == nil { - cn.idx = int32(i) + cn.SetIndex(i) l.cns[i] = cn l.mu.Unlock() return @@ -65,22 +65,25 @@ func (l *connList) Remove(idx int) { l.mu.Lock() if l.cns != nil { l.cns[idx] = nil - l.len -= 1 + atomic.AddInt32(&l.len, -1) } l.mu.Unlock() } -func (l *connList) Close() error { +func (l *connList) Reset() []*Conn { l.mu.Lock() - for _, c := range l.cns { - if c == nil { + + for _, cn := range l.cns { + if cn == nil { continue } - c.idx = -1 - c.Close() + cn.SetIndex(-1) } + + cns := l.cns l.cns = nil l.len = 0 + l.mu.Unlock() - return nil + return cns } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 4de11fc6..932146ea 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -5,13 +5,14 @@ import ( "fmt" "log" "net" + "os" "sync/atomic" "time" "gopkg.in/bsm/ratelimit.v1" ) -var Logger *log.Logger +var Logger = log.New(os.Stderr, "pg: ", log.LstdFlags) var ( ErrClosed = errors.New("redis: client is closed") @@ -47,6 +48,7 @@ type dialer func() (net.Conn, error) type ConnPool struct { _dial dialer DialLimiter *ratelimit.RateLimiter + OnClose func(*Conn) error poolTimeout time.Duration idleTimeout time.Duration @@ -74,19 +76,11 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Durati freeConns: newConnStack(poolSize), } if idleTimeout > 0 { - go p.reaper() + go p.reaper(getIdleCheckFrequency()) } return p } -func (p *ConnPool) Closed() bool { - return atomic.LoadInt32(&p._closed) == 1 -} - -func (p *ConnPool) isIdle(cn *Conn) bool { - return p.idleTimeout > 0 && time.Since(cn.UsedAt) > p.idleTimeout -} - func (p *ConnPool) Add(cn *Conn) bool { if !p.conns.Reserve() { return false @@ -266,23 +260,43 @@ func (p *ConnPool) Stats() *PoolStats { return &stats } +func (p *ConnPool) Closed() bool { + return atomic.LoadInt32(&p._closed) == 1 +} + func (p *ConnPool) Close() (retErr error) { if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { return ErrClosed } + // Wait for app to free connections, but don't close them immediately. for i := 0; i < p.Len(); i++ { if cn := p.wait(); cn == nil { break } } + // Close all connections. - if err := p.conns.Close(); err != nil { - retErr = err + cns := p.conns.Reset() + for _, cn := range cns { + if cn == nil { + continue + } + if err := p.closeConn(cn); err != nil && retErr == nil { + retErr = err + } } + return retErr } +func (p *ConnPool) closeConn(cn *Conn) error { + if p.OnClose != nil { + _ = p.OnClose(cn) + } + return cn.Close() +} + func (p *ConnPool) ReapStaleConns() (n int, err error) { for { cn := p.freeConns.ShiftStale(p.idleTimeout) @@ -297,8 +311,8 @@ func (p *ConnPool) ReapStaleConns() (n int, err error) { return } -func (p *ConnPool) reaper() { - ticker := time.NewTicker(time.Minute) +func (p *ConnPool) reaper(frequency time.Duration) { + ticker := time.NewTicker(frequency) defer ticker.Stop() for _ = range ticker.C { @@ -324,3 +338,19 @@ func (p *ConnPool) loadLastErr() string { } return "" } + +//------------------------------------------------------------------------------ + +var idleCheckFrequency atomic.Value + +func SetIdleCheckFrequency(d time.Duration) { + idleCheckFrequency.Store(d) +} + +func getIdleCheckFrequency() time.Duration { + v := idleCheckFrequency.Load() + if v == nil { + return time.Minute + } + return v.(time.Duration) +} diff --git a/main_test.go b/main_test.go index e3e747fe..67905cf8 100644 --- a/main_test.go +++ b/main_test.go @@ -15,6 +15,7 @@ import ( . "github.com/onsi/gomega" "gopkg.in/redis.v3" + "gopkg.in/redis.v3/internal/pool" ) const ( @@ -52,6 +53,8 @@ var cluster = &clusterScenario{ var _ = BeforeSuite(func() { var err error + pool.SetIdleCheckFrequency(time.Second) // be aggressive in tests + redisMain, err = startRedis(redisPort) Expect(err).NotTo(HaveOccurred()) @@ -99,31 +102,49 @@ func TestGinkgoSuite(t *testing.T) { //------------------------------------------------------------------------------ -func perform(n int, cb func()) { +func redisOptions() *redis.Options { + return &redis.Options{ + Addr: redisAddr, + DB: 15, + DialTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + PoolSize: 10, + PoolTimeout: 30 * time.Second, + IdleTimeout: time.Second, // be aggressive in tests + } +} + +func perform(n int, cb func(int)) { var wg sync.WaitGroup for i := 0; i < n; i++ { wg.Add(1) - go func() { + go func(i int) { defer GinkgoRecover() defer wg.Done() - cb() - }() + cb(i) + }(i) } wg.Wait() } func eventually(fn func() error, timeout time.Duration) error { - done := make(chan struct{}) var exit int32 - var err error + var retErr error + var mu sync.Mutex + done := make(chan struct{}) + go func() { for atomic.LoadInt32(&exit) == 0 { - err = fn() + err := fn() if err == nil { close(done) return } + mu.Lock() + retErr = err + mu.Unlock() time.Sleep(timeout / 100) } }() @@ -133,6 +154,9 @@ func eventually(fn func() error, timeout time.Duration) error { return nil case <-time.After(timeout): atomic.StoreInt32(&exit, 1) + mu.Lock() + err := retErr + mu.Unlock() return err } } diff --git a/multi_test.go b/multi_test.go index a82a347a..e76c2b33 100644 --- a/multi_test.go +++ b/multi_test.go @@ -14,13 +14,11 @@ var _ = Describe("Multi", func() { var client *redis.Client BeforeEach(func() { - client = redis.NewClient(&redis.Options{ - Addr: redisAddr, - }) + client = redis.NewClient(redisOptions()) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) AfterEach(func() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred()) }) @@ -54,6 +52,7 @@ var _ = Describe("Multi", func() { for i := 0; i < 100; i++ { wg.Add(1) go func() { + defer GinkgoRecover() defer wg.Done() err := incr("key") diff --git a/options.go b/options.go index de91d4e8..935e7564 100644 --- a/options.go +++ b/options.go @@ -60,12 +60,12 @@ func (opt *Options) getNetwork() string { } func (opt *Options) getDialer() func() (net.Conn, error) { - if opt.Dialer == nil { - opt.Dialer = func() (net.Conn, error) { - return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout()) - } + if opt.Dialer != nil { + return opt.Dialer + } + return func() (net.Conn, error) { + return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout()) } - return opt.Dialer } func (opt *Options) getPoolSize() int { diff --git a/parser.go b/parser.go index 3d8742c9..07988579 100644 --- a/parser.go +++ b/parser.go @@ -3,7 +3,6 @@ package redis import ( "errors" "fmt" - "io" "net" "strconv" @@ -245,17 +244,6 @@ func isNilReply(b []byte) bool { b[1] == '-' && b[2] == '1' } -func readN(cn *pool.Conn, n int) ([]byte, error) { - if d := n - cap(cn.Buf); d > 0 { - cn.Buf = cn.Buf[:cap(cn.Buf)] - cn.Buf = append(cn.Buf, make([]byte, d)...) - } else { - cn.Buf = cn.Buf[:n] - } - _, err := io.ReadFull(cn.Rd, cn.Buf) - return cn.Buf, err -} - //------------------------------------------------------------------------------ func parseErrorReply(cn *pool.Conn, line []byte) error { @@ -299,7 +287,7 @@ func parseBytesReply(cn *pool.Conn, line []byte) ([]byte, error) { return nil, err } - b, err := readN(cn, replyLen+2) + b, err := cn.ReadN(replyLen + 2) if err != nil { return nil, err } diff --git a/pipeline_test.go b/pipeline_test.go index ed01baf4..cfbed6e5 100644 --- a/pipeline_test.go +++ b/pipeline_test.go @@ -14,13 +14,11 @@ var _ = Describe("Pipelining", func() { var client *redis.Client BeforeEach(func() { - client = redis.NewClient(&redis.Options{ - Addr: redisAddr, - }) + client = redis.NewClient(redisOptions()) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) AfterEach(func() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred()) }) diff --git a/pool_test.go b/pool_test.go index 006ab0be..bf1ae4aa 100644 --- a/pool_test.go +++ b/pool_test.go @@ -8,16 +8,15 @@ import ( . "github.com/onsi/gomega" "gopkg.in/redis.v3" + "gopkg.in/redis.v3/internal/pool" ) var _ = Describe("pool", func() { var client *redis.Client BeforeEach(func() { - client = redis.NewClient(&redis.Options{ - Addr: redisAddr, - PoolSize: 10, - }) + client = redis.NewClient(redisOptions()) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) AfterEach(func() { @@ -25,7 +24,7 @@ var _ = Describe("pool", func() { }) It("should respect max size", func() { - perform(1000, func() { + perform(1000, func(id int) { val, err := client.Ping().Result() Expect(err).NotTo(HaveOccurred()) Expect(val).To(Equal("PONG")) @@ -38,7 +37,7 @@ var _ = Describe("pool", func() { }) It("should respect max on multi", func() { - perform(1000, func() { + perform(1000, func(id int) { var ping *redis.StatusCmd multi := client.Multi() @@ -60,7 +59,7 @@ var _ = Describe("pool", func() { }) It("should respect max on pipelines", func() { - perform(1000, func() { + perform(1000, func(id int) { pipe := client.Pipeline() ping := pipe.Ping() cmds, err := pipe.Exec() @@ -78,16 +77,17 @@ var _ = Describe("pool", func() { }) It("should respect max on pubsub", func() { - perform(10, func() { + connPool := client.Pool() + connPool.(*pool.ConnPool).DialLimiter = nil + + perform(1000, func(id int) { pubsub := client.PubSub() Expect(pubsub.Subscribe()).NotTo(HaveOccurred()) Expect(pubsub.Close()).NotTo(HaveOccurred()) }) - pool := client.Pool() - Expect(pool.Len()).To(BeNumerically("<=", 10)) - Expect(pool.FreeLen()).To(BeNumerically("<=", 10)) - Expect(pool.Len()).To(Equal(pool.FreeLen())) + Expect(connPool.Len()).To(Equal(connPool.FreeLen())) + Expect(connPool.Len()).To(BeNumerically("<=", 10)) }) It("should remove broken connections", func() { @@ -108,8 +108,8 @@ var _ = Describe("pool", func() { Expect(pool.FreeLen()).To(Equal(1)) stats := pool.Stats() - Expect(stats.Requests).To(Equal(uint32(3))) - Expect(stats.Hits).To(Equal(uint32(2))) + Expect(stats.Requests).To(Equal(uint32(4))) + Expect(stats.Hits).To(Equal(uint32(3))) Expect(stats.Waits).To(Equal(uint32(0))) Expect(stats.Timeouts).To(Equal(uint32(0))) }) @@ -126,8 +126,8 @@ var _ = Describe("pool", func() { Expect(pool.FreeLen()).To(Equal(1)) stats := pool.Stats() - Expect(stats.Requests).To(Equal(uint32(100))) - Expect(stats.Hits).To(Equal(uint32(99))) + Expect(stats.Requests).To(Equal(uint32(101))) + Expect(stats.Hits).To(Equal(uint32(100))) Expect(stats.Waits).To(Equal(uint32(0))) Expect(stats.Timeouts).To(Equal(uint32(0))) }) diff --git a/pubsub_test.go b/pubsub_test.go index 835d7c1a..ca1cdb25 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -14,13 +14,9 @@ import ( var _ = Describe("PubSub", func() { var client *redis.Client - readTimeout := 3 * time.Second BeforeEach(func() { - client = redis.NewClient(&redis.Options{ - Addr: redisAddr, - ReadTimeout: readTimeout, - }) + client = redis.NewClient(redisOptions()) Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) diff --git a/race_test.go b/race_test.go new file mode 100644 index 00000000..0b942e11 --- /dev/null +++ b/race_test.go @@ -0,0 +1,213 @@ +package redis_test + +import ( + "bytes" + "fmt" + "net" + "strconv" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3" + "gopkg.in/redis.v3/internal/pool" +) + +var _ = Describe("races", func() { + var client *redis.Client + + var C, N = 10, 1000 + if testing.Short() { + C = 4 + N = 100 + } + + BeforeEach(func() { + client = redis.NewClient(redisOptions()) + Expect(client.FlushDb().Err()).To(BeNil()) + }) + + AfterEach(func() { + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + + It("should echo", func() { + perform(C, func(id int) { + for i := 0; i < N; i++ { + msg := fmt.Sprintf("echo %d %d", id, i) + echo, err := client.Echo(msg).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(echo).To(Equal(msg)) + } + }) + }) + + It("should incr", func() { + key := "TestIncrFromGoroutines" + + perform(C, func(id int) { + for i := 0; i < N; i++ { + err := client.Incr(key).Err() + Expect(err).NotTo(HaveOccurred()) + } + }) + + val, err := client.Get(key).Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal(int64(C * N))) + }) + + It("should handle many keys", func() { + perform(C, func(id int) { + for i := 0; i < N; i++ { + err := client.Set( + fmt.Sprintf("keys.key-%d-%d", id, i), + fmt.Sprintf("hello-%d-%d", id, i), + 0, + ).Err() + Expect(err).NotTo(HaveOccurred()) + } + }) + + keys := client.Keys("keys.*") + Expect(keys.Err()).NotTo(HaveOccurred()) + Expect(len(keys.Val())).To(Equal(C * N)) + }) + + It("should handle many keys 2", func() { + perform(C, func(id int) { + keys := []string{"non-existent-key"} + for i := 0; i < N; i++ { + key := fmt.Sprintf("keys.key-%d", i) + keys = append(keys, key) + + err := client.Set(key, fmt.Sprintf("hello-%d", i), 0).Err() + Expect(err).NotTo(HaveOccurred()) + } + keys = append(keys, "non-existent-key") + + vals, err := client.MGet(keys...).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(vals)).To(Equal(N + 2)) + + for i := 0; i < N; i++ { + Expect(vals[i+1]).To(Equal(fmt.Sprintf("hello-%d", i))) + } + + Expect(vals[0]).To(BeNil()) + Expect(vals[N+1]).To(BeNil()) + }) + }) + + It("should handle big vals in Get", func() { + bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb + + err := client.Set("key", bigVal, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + // Reconnect to get new connection. + Expect(client.Close()).To(BeNil()) + client = redis.NewClient(redisOptions()) + + perform(C, func(id int) { + for i := 0; i < N; i++ { + got, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(got).To(Equal(bigVal)) + } + }) + + }) + + It("should handle big vals in Set", func() { + bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb + + perform(C, func(id int) { + for i := 0; i < N; i++ { + err := client.Set("key", bigVal, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + got, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(got).To(Equal(bigVal)) + } + }) + }) + + It("should PubSub", func() { + connPool := client.Pool() + connPool.(*pool.ConnPool).DialLimiter = nil + + perform(C, func(id int) { + for i := 0; i < N; i++ { + pubsub, err := client.Subscribe(fmt.Sprintf("mychannel%d", id)) + Expect(err).NotTo(HaveOccurred()) + + go func() { + defer GinkgoRecover() + + time.Sleep(time.Millisecond) + err := pubsub.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + _, err = pubsub.ReceiveMessage() + Expect(err.Error()).To(ContainSubstring("closed")) + + val := "echo" + strconv.Itoa(i) + echo, err := client.Echo(val).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(echo).To(Equal(val)) + } + }) + + Expect(connPool.Len()).To(Equal(connPool.FreeLen())) + Expect(connPool.Len()).To(BeNumerically("<=", 10)) + }) + + It("should select db", func() { + err := client.Set("db", 1, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + perform(C, func(id int) { + opt := redisOptions() + opt.DB = int64(id) + client := redis.NewClient(opt) + for i := 0; i < N; i++ { + err := client.Set("db", id, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + n, err := client.Get("db").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(id))) + } + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + + n, err := client.Get("db").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + }) + + It("should select DB with read timeout", func() { + perform(C, func(id int) { + opt := redisOptions() + opt.DB = int64(id) + opt.ReadTimeout = time.Nanosecond + client := redis.NewClient(opt) + + perform(C, func(id int) { + err := client.Ping().Err() + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + }) + + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + }) +}) diff --git a/redis_test.go b/redis_test.go index 8b3d8dbd..de297738 100644 --- a/redis_test.go +++ b/redis_test.go @@ -1,8 +1,8 @@ package redis_test import ( + "bytes" "net" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -14,9 +14,8 @@ var _ = Describe("Client", func() { var client *redis.Client BeforeEach(func() { - client = redis.NewClient(&redis.Options{ - Addr: redisAddr, - }) + client = redis.NewClient(redisOptions()) + Expect(client.FlushDb().Err()).To(BeNil()) }) AfterEach(func() { @@ -24,7 +23,7 @@ var _ = Describe("Client", func() { }) It("should Stringer", func() { - Expect(client.String()).To(Equal("Redis<:6380 db:0>")) + Expect(client.String()).To(Equal("Redis<:6380 db:15>")) }) It("should ping", func() { @@ -39,6 +38,7 @@ var _ = Describe("Client", func() { It("should support custom dialers", func() { custom := redis.NewClient(&redis.Options{ + Addr: ":1234", Dialer: func() (net.Conn, error) { return net.Dial("tcp", redisAddr) }, @@ -107,45 +107,30 @@ var _ = Describe("Client", func() { Expect(pipeline.Close()).NotTo(HaveOccurred()) }) - It("should support idle-timeouts", func() { - idle := redis.NewClient(&redis.Options{ - Addr: redisAddr, - IdleTimeout: 100 * time.Microsecond, - }) - defer idle.Close() - - Expect(idle.Ping().Err()).NotTo(HaveOccurred()) - time.Sleep(time.Millisecond) - Expect(idle.Ping().Err()).NotTo(HaveOccurred()) - }) - - It("should support DB selection", func() { - db1 := redis.NewClient(&redis.Options{ + It("should select DB", func() { + db2 := redis.NewClient(&redis.Options{ Addr: redisAddr, - DB: 1, + DB: 2, }) - defer db1.Close() + Expect(db2.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(db2.Get("db").Err()).To(Equal(redis.Nil)) + Expect(db2.Set("db", 2, 0).Err()).NotTo(HaveOccurred()) - Expect(db1.Get("key").Err()).To(Equal(redis.Nil)) - Expect(db1.Set("key", "value", 0).Err()).NotTo(HaveOccurred()) + n, err := db2.Get("db").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(2))) - Expect(client.Get("key").Err()).To(Equal(redis.Nil)) - Expect(db1.Get("key").Val()).To(Equal("value")) - Expect(db1.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(client.Get("db").Err()).To(Equal(redis.Nil)) + + Expect(db2.FlushDb().Err()).NotTo(HaveOccurred()) + Expect(db2.Close()).NotTo(HaveOccurred()) }) - It("should support DB selection with read timeout (issue #135)", func() { - for i := 0; i < 100; i++ { - db1 := redis.NewClient(&redis.Options{ - Addr: redisAddr, - DB: 1, - ReadTimeout: time.Nanosecond, - }) - - err := db1.Ping().Err() - Expect(err).To(HaveOccurred()) - Expect(err.(net.Error).Timeout()).To(BeTrue()) - } + It("should process custom commands", func() { + cmd := redis.NewCmd("PING") + client.Process(cmd) + Expect(cmd.Err()).NotTo(HaveOccurred()) + Expect(cmd.Val()).To(Equal("PONG")) }) It("should retry command on network error", func() { @@ -168,7 +153,7 @@ var _ = Describe("Client", func() { Expect(err).NotTo(HaveOccurred()) }) - It("should maintain conn.UsedAt", func() { + It("should update conn.UsedAt on read/write", func() { cn, err := client.Pool().Get() Expect(err).NotTo(HaveOccurred()) Expect(cn.UsedAt).NotTo(BeZero()) @@ -185,4 +170,31 @@ var _ = Describe("Client", func() { Expect(cn).NotTo(BeNil()) Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) }) + + It("should escape special chars", func() { + set := client.Set("key", "hello1\r\nhello2\r\n", 0) + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal("OK")) + + get := client.Get("key") + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n")) + }) + + It("should handle big vals", func() { + bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb + + err := client.Set("key", bigVal, 0).Err() + Expect(err).NotTo(HaveOccurred()) + + // Reconnect to get new connection. + Expect(client.Close()).To(BeNil()) + client = redis.NewClient(redisOptions()) + + got, err := client.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(got)).To(Equal(len(bigVal))) + Expect(got).To(Equal(bigVal)) + }) + }) diff --git a/sentinel.go b/sentinel.go index 694dd602..887c5eb3 100644 --- a/sentinel.go +++ b/sentinel.go @@ -206,7 +206,7 @@ func (d *sentinelFailover) MasterAddr() (string, error) { func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) { d.discoverSentinels(sentinel) d.sentinel = sentinel - go d.listen() + go d.listen(sentinel) } func (d *sentinelFailover) resetSentinel() error { @@ -278,11 +278,11 @@ func (d *sentinelFailover) closeOldConns(newMaster string) { } } -func (d *sentinelFailover) listen() { +func (d *sentinelFailover) listen(sentinel *sentinelClient) { var pubsub *PubSub for { if pubsub == nil { - pubsub = d.sentinel.PubSub() + pubsub = sentinel.PubSub() if err := pubsub.Subscribe("+switch-master"); err != nil { Logger.Printf("sentinel: Subscribe failed: %s", err) d.resetSentinel() diff --git a/sentinel_test.go b/sentinel_test.go index 14dcf834..693b9574 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -15,6 +15,7 @@ var _ = Describe("Sentinel", func() { MasterName: sentinelName, SentinelAddrs: []string{":" + sentinelPort}, }) + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) }) AfterEach(func() {