Merge pull request #283 from go-redis/fix/race-tests

Extract race tests to separate file. Add more race tests.
This commit is contained in:
Vladimir Mihailenco 2016-03-17 09:46:52 +03:00
commit 998148be40
19 changed files with 411 additions and 330 deletions

View File

@ -1,6 +1,6 @@
all: testdeps all: testdeps
go test ./... -test.cpu=1,2,4 go test ./...
go test ./... -test.short -test.race go test ./... -short -race
testdeps: testdata/redis/src/redis-server testdeps: testdata/redis/src/redis-server

View File

@ -66,7 +66,10 @@ func startCluster(scenario *clusterScenario) error {
return err return err
} }
client := redis.NewClient(&redis.Options{Addr: "127.0.0.1:" + port}) client := redis.NewClient(&redis.Options{
Addr: ":" + port,
})
info, err := client.ClusterNodes().Result() info, err := client.ClusterNodes().Result()
if err != nil { if err != nil {
return err return err

View File

@ -1,35 +1,21 @@
package redis_test package redis_test
import ( import (
"bytes"
"strconv"
"sync"
"testing"
"time"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"gopkg.in/redis.v3" "gopkg.in/redis.v3"
"gopkg.in/redis.v3/internal/pool"
) )
var _ = Describe("Command", func() { var _ = Describe("Command", func() {
var client *redis.Client var client *redis.Client
connect := func() *redis.Client {
return redis.NewClient(&redis.Options{
Addr: redisAddr,
PoolTimeout: time.Minute,
})
}
BeforeEach(func() { BeforeEach(func() {
client = connect() client = redis.NewClient(redisOptions())
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
}) })
@ -54,64 +40,6 @@ var _ = Describe("Command", func() {
Expect(set.Val()).To(Equal("OK")) Expect(set.Val()).To(Equal("OK"))
}) })
It("should escape special chars", func() {
set := client.Set("key", "hello1\r\nhello2\r\n", 0)
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
get := client.Get("key")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n"))
})
It("should handle big vals", func() {
bigVal := string(bytes.Repeat([]byte{'*'}, 1<<16))
err := client.Set("key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection.
Expect(client.Close()).To(BeNil())
client = connect()
got, err := client.Get("key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(got)).To(Equal(len(bigVal)))
Expect(got).To(Equal(bigVal))
})
It("should handle many keys #1", func() {
const n = 100000
for i := 0; i < n; i++ {
client.Set("keys.key"+strconv.Itoa(i), "hello"+strconv.Itoa(i), 0)
}
keys := client.Keys("keys.*")
Expect(keys.Err()).NotTo(HaveOccurred())
Expect(len(keys.Val())).To(Equal(n))
})
It("should handle many keys #2", func() {
const n = 100000
keys := []string{"non-existent-key"}
for i := 0; i < n; i++ {
key := "keys.key" + strconv.Itoa(i)
client.Set(key, "hello"+strconv.Itoa(i), 0)
keys = append(keys, key)
}
keys = append(keys, "non-existent-key")
mget := client.MGet(keys...)
Expect(mget.Err()).NotTo(HaveOccurred())
Expect(len(mget.Val())).To(Equal(n + 2))
vals := mget.Val()
for i := 0; i < n; i++ {
Expect(vals[i+1]).To(Equal("hello" + strconv.Itoa(i)))
}
Expect(vals[0]).To(BeNil())
Expect(vals[n+1]).To(BeNil())
})
It("should convert strings via helpers", func() { It("should convert strings via helpers", func() {
set := client.Set("key", "10", 0) set := client.Set("key", "10", 0)
Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Err()).NotTo(HaveOccurred())
@ -129,126 +57,4 @@ var _ = Describe("Command", func() {
Expect(f).To(Equal(float64(10))) Expect(f).To(Equal(float64(10)))
}) })
It("Cmd should return string", func() {
cmd := redis.NewCmd("PING")
client.Process(cmd)
Expect(cmd.Err()).NotTo(HaveOccurred())
Expect(cmd.Val()).To(Equal("PONG"))
})
Describe("races", func() {
var C, N = 10, 1000
if testing.Short() {
C = 3
N = 100
}
It("should echo", func() {
perform(C, func() {
for i := 0; i < N; i++ {
msg := "echo" + strconv.Itoa(i)
echo, err := client.Echo(msg).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(msg))
}
})
})
It("should incr", func() {
key := "TestIncrFromGoroutines"
perform(C, func() {
for i := 0; i < N; i++ {
err := client.Incr(key).Err()
Expect(err).NotTo(HaveOccurred())
}
})
val, err := client.Get(key).Int64()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N)))
})
It("should handle big vals", func() {
client2 := connect()
defer client2.Close()
bigVal := string(bytes.Repeat([]byte{'*'}, 1<<16))
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
perform(C, func() {
for i := 0; i < N; i++ {
got, err := client.Get("key").Result()
if err == redis.Nil {
continue
}
Expect(got).To(Equal(bigVal))
}
})
}()
go func() {
defer wg.Done()
perform(C, func() {
for i := 0; i < N; i++ {
err := client2.Set("key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
}
})
}()
wg.Wait()
})
It("should PubSub", func() {
connPool := client.Pool()
connPool.(*pool.ConnPool).DialLimiter = nil
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
perform(C, func() {
for i := 0; i < N; i++ {
pubsub, err := client.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred())
go func() {
defer GinkgoRecover()
time.Sleep(time.Millisecond)
err := pubsub.Close()
Expect(err).NotTo(HaveOccurred())
}()
_, err = pubsub.ReceiveMessage()
Expect(err.Error()).To(ContainSubstring("closed"))
}
})
}()
go func() {
defer wg.Done()
perform(C, func() {
for i := 0; i < N; i++ {
val := "echo" + strconv.Itoa(i)
echo, err := client.Echo(val).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(val))
}
})
}()
wg.Wait()
Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(connPool.Len()).To(BeNumerically("<=", 10))
})
})
}) })

View File

@ -19,14 +19,11 @@ var _ = Describe("Commands", func() {
var client *redis.Client var client *redis.Client
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(redisOptions())
Addr: redisAddr, Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
PoolTimeout: 30 * time.Second,
})
}) })
AfterEach(func() { AfterEach(func() {
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
}) })
@ -299,7 +296,7 @@ var _ = Describe("Commands", func() {
}) })
It("should Move", func() { It("should Move", func() {
move := client.Move("key", 1) move := client.Move("key", 2)
Expect(move.Err()).NotTo(HaveOccurred()) Expect(move.Err()).NotTo(HaveOccurred())
Expect(move.Val()).To(Equal(false)) Expect(move.Val()).To(Equal(false))
@ -307,7 +304,7 @@ var _ = Describe("Commands", func() {
Expect(set.Err()).NotTo(HaveOccurred()) Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK")) Expect(set.Val()).To(Equal("OK"))
move = client.Move("key", 1) move = client.Move("key", 2)
Expect(move.Err()).NotTo(HaveOccurred()) Expect(move.Err()).NotTo(HaveOccurred())
Expect(move.Val()).To(Equal(true)) Expect(move.Val()).To(Equal(true))
@ -315,7 +312,7 @@ var _ = Describe("Commands", func() {
Expect(get.Err()).To(Equal(redis.Nil)) Expect(get.Err()).To(Equal(redis.Nil))
Expect(get.Val()).To(Equal("")) Expect(get.Val()).To(Equal(""))
sel := client.Select(1) sel := client.Select(2)
Expect(sel.Err()).NotTo(HaveOccurred()) Expect(sel.Err()).NotTo(HaveOccurred())
Expect(sel.Val()).To(Equal("OK")) Expect(sel.Val()).To(Equal("OK"))
@ -323,7 +320,7 @@ var _ = Describe("Commands", func() {
Expect(get.Err()).NotTo(HaveOccurred()) Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello")) Expect(get.Val()).To(Equal("hello"))
Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
Expect(client.Select(0).Err()).NotTo(HaveOccurred()) Expect(client.Select(1).Err()).NotTo(HaveOccurred())
}) })
It("should Object", func() { It("should Object", func() {

View File

@ -12,10 +12,9 @@ import (
var client *redis.Client var client *redis.Client
func init() { func init() {
client = redis.NewClient(&redis.Options{ opt := redisOptions()
Addr: ":6379", opt.Addr = ":6379"
DialTimeout: 10 * time.Second, client = redis.NewClient(opt)
})
client.FlushDb() client.FlushDb()
} }

View File

@ -2,6 +2,7 @@ package pool
import ( import (
"bufio" "bufio"
"io"
"net" "net"
"sync/atomic" "sync/atomic"
"time" "time"
@ -78,6 +79,17 @@ func (cn *Conn) RemoteAddr() net.Addr {
return cn.NetConn.RemoteAddr() return cn.NetConn.RemoteAddr()
} }
func (cn *Conn) ReadN(n int) ([]byte, error) {
if d := n - cap(cn.Buf); d > 0 {
cn.Buf = cn.Buf[:cap(cn.Buf)]
cn.Buf = append(cn.Buf, make([]byte, d)...)
} else {
cn.Buf = cn.Buf[:n]
}
_, err := io.ReadFull(cn.Rd, cn.Buf)
return cn.Buf, err
}
func (cn *Conn) Close() error { func (cn *Conn) Close() error {
return cn.NetConn.Close() return cn.NetConn.Close()
} }

View File

@ -43,7 +43,7 @@ func (l *connList) Add(cn *Conn) {
l.mu.Lock() l.mu.Lock()
for i, c := range l.cns { for i, c := range l.cns {
if c == nil { if c == nil {
cn.idx = int32(i) cn.SetIndex(i)
l.cns[i] = cn l.cns[i] = cn
l.mu.Unlock() l.mu.Unlock()
return return
@ -65,22 +65,25 @@ func (l *connList) Remove(idx int) {
l.mu.Lock() l.mu.Lock()
if l.cns != nil { if l.cns != nil {
l.cns[idx] = nil l.cns[idx] = nil
l.len -= 1 atomic.AddInt32(&l.len, -1)
} }
l.mu.Unlock() l.mu.Unlock()
} }
func (l *connList) Close() error { func (l *connList) Reset() []*Conn {
l.mu.Lock() l.mu.Lock()
for _, c := range l.cns {
if c == nil { for _, cn := range l.cns {
if cn == nil {
continue continue
} }
c.idx = -1 cn.SetIndex(-1)
c.Close()
} }
cns := l.cns
l.cns = nil l.cns = nil
l.len = 0 l.len = 0
l.mu.Unlock() l.mu.Unlock()
return nil return cns
} }

View File

@ -5,13 +5,14 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"os"
"sync/atomic" "sync/atomic"
"time" "time"
"gopkg.in/bsm/ratelimit.v1" "gopkg.in/bsm/ratelimit.v1"
) )
var Logger *log.Logger var Logger = log.New(os.Stderr, "pg: ", log.LstdFlags)
var ( var (
ErrClosed = errors.New("redis: client is closed") ErrClosed = errors.New("redis: client is closed")
@ -47,6 +48,7 @@ type dialer func() (net.Conn, error)
type ConnPool struct { type ConnPool struct {
_dial dialer _dial dialer
DialLimiter *ratelimit.RateLimiter DialLimiter *ratelimit.RateLimiter
OnClose func(*Conn) error
poolTimeout time.Duration poolTimeout time.Duration
idleTimeout time.Duration idleTimeout time.Duration
@ -74,19 +76,11 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout time.Durati
freeConns: newConnStack(poolSize), freeConns: newConnStack(poolSize),
} }
if idleTimeout > 0 { if idleTimeout > 0 {
go p.reaper() go p.reaper(getIdleCheckFrequency())
} }
return p return p
} }
func (p *ConnPool) Closed() bool {
return atomic.LoadInt32(&p._closed) == 1
}
func (p *ConnPool) isIdle(cn *Conn) bool {
return p.idleTimeout > 0 && time.Since(cn.UsedAt) > p.idleTimeout
}
func (p *ConnPool) Add(cn *Conn) bool { func (p *ConnPool) Add(cn *Conn) bool {
if !p.conns.Reserve() { if !p.conns.Reserve() {
return false return false
@ -266,23 +260,43 @@ func (p *ConnPool) Stats() *PoolStats {
return &stats return &stats
} }
func (p *ConnPool) Closed() bool {
return atomic.LoadInt32(&p._closed) == 1
}
func (p *ConnPool) Close() (retErr error) { func (p *ConnPool) Close() (retErr error) {
if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) {
return ErrClosed return ErrClosed
} }
// Wait for app to free connections, but don't close them immediately. // Wait for app to free connections, but don't close them immediately.
for i := 0; i < p.Len(); i++ { for i := 0; i < p.Len(); i++ {
if cn := p.wait(); cn == nil { if cn := p.wait(); cn == nil {
break break
} }
} }
// Close all connections. // Close all connections.
if err := p.conns.Close(); err != nil { cns := p.conns.Reset()
for _, cn := range cns {
if cn == nil {
continue
}
if err := p.closeConn(cn); err != nil && retErr == nil {
retErr = err retErr = err
} }
}
return retErr return retErr
} }
func (p *ConnPool) closeConn(cn *Conn) error {
if p.OnClose != nil {
_ = p.OnClose(cn)
}
return cn.Close()
}
func (p *ConnPool) ReapStaleConns() (n int, err error) { func (p *ConnPool) ReapStaleConns() (n int, err error) {
for { for {
cn := p.freeConns.ShiftStale(p.idleTimeout) cn := p.freeConns.ShiftStale(p.idleTimeout)
@ -297,8 +311,8 @@ func (p *ConnPool) ReapStaleConns() (n int, err error) {
return return
} }
func (p *ConnPool) reaper() { func (p *ConnPool) reaper(frequency time.Duration) {
ticker := time.NewTicker(time.Minute) ticker := time.NewTicker(frequency)
defer ticker.Stop() defer ticker.Stop()
for _ = range ticker.C { for _ = range ticker.C {
@ -324,3 +338,19 @@ func (p *ConnPool) loadLastErr() string {
} }
return "" return ""
} }
//------------------------------------------------------------------------------
var idleCheckFrequency atomic.Value
func SetIdleCheckFrequency(d time.Duration) {
idleCheckFrequency.Store(d)
}
func getIdleCheckFrequency() time.Duration {
v := idleCheckFrequency.Load()
if v == nil {
return time.Minute
}
return v.(time.Duration)
}

View File

@ -15,6 +15,7 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"gopkg.in/redis.v3" "gopkg.in/redis.v3"
"gopkg.in/redis.v3/internal/pool"
) )
const ( const (
@ -52,6 +53,8 @@ var cluster = &clusterScenario{
var _ = BeforeSuite(func() { var _ = BeforeSuite(func() {
var err error var err error
pool.SetIdleCheckFrequency(time.Second) // be aggressive in tests
redisMain, err = startRedis(redisPort) redisMain, err = startRedis(redisPort)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -99,31 +102,49 @@ func TestGinkgoSuite(t *testing.T) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func perform(n int, cb func()) { func redisOptions() *redis.Options {
return &redis.Options{
Addr: redisAddr,
DB: 15,
DialTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
PoolSize: 10,
PoolTimeout: 30 * time.Second,
IdleTimeout: time.Second, // be aggressive in tests
}
}
func perform(n int, cb func(int)) {
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
wg.Add(1) wg.Add(1)
go func() { go func(i int) {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()
cb() cb(i)
}() }(i)
} }
wg.Wait() wg.Wait()
} }
func eventually(fn func() error, timeout time.Duration) error { func eventually(fn func() error, timeout time.Duration) error {
done := make(chan struct{})
var exit int32 var exit int32
var err error var retErr error
var mu sync.Mutex
done := make(chan struct{})
go func() { go func() {
for atomic.LoadInt32(&exit) == 0 { for atomic.LoadInt32(&exit) == 0 {
err = fn() err := fn()
if err == nil { if err == nil {
close(done) close(done)
return return
} }
mu.Lock()
retErr = err
mu.Unlock()
time.Sleep(timeout / 100) time.Sleep(timeout / 100)
} }
}() }()
@ -133,6 +154,9 @@ func eventually(fn func() error, timeout time.Duration) error {
return nil return nil
case <-time.After(timeout): case <-time.After(timeout):
atomic.StoreInt32(&exit, 1) atomic.StoreInt32(&exit, 1)
mu.Lock()
err := retErr
mu.Unlock()
return err return err
} }
} }

View File

@ -14,13 +14,11 @@ var _ = Describe("Multi", func() {
var client *redis.Client var client *redis.Client
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(redisOptions())
Addr: redisAddr, Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
})
}) })
AfterEach(func() { AfterEach(func() {
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
}) })
@ -54,6 +52,7 @@ var _ = Describe("Multi", func() {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer GinkgoRecover()
defer wg.Done() defer wg.Done()
err := incr("key") err := incr("key")

View File

@ -60,12 +60,12 @@ func (opt *Options) getNetwork() string {
} }
func (opt *Options) getDialer() func() (net.Conn, error) { func (opt *Options) getDialer() func() (net.Conn, error) {
if opt.Dialer == nil { if opt.Dialer != nil {
opt.Dialer = func() (net.Conn, error) { return opt.Dialer
}
return func() (net.Conn, error) {
return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout()) return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout())
} }
}
return opt.Dialer
} }
func (opt *Options) getPoolSize() int { func (opt *Options) getPoolSize() int {

View File

@ -3,7 +3,6 @@ package redis
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"strconv" "strconv"
@ -245,17 +244,6 @@ func isNilReply(b []byte) bool {
b[1] == '-' && b[2] == '1' b[1] == '-' && b[2] == '1'
} }
func readN(cn *pool.Conn, n int) ([]byte, error) {
if d := n - cap(cn.Buf); d > 0 {
cn.Buf = cn.Buf[:cap(cn.Buf)]
cn.Buf = append(cn.Buf, make([]byte, d)...)
} else {
cn.Buf = cn.Buf[:n]
}
_, err := io.ReadFull(cn.Rd, cn.Buf)
return cn.Buf, err
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func parseErrorReply(cn *pool.Conn, line []byte) error { func parseErrorReply(cn *pool.Conn, line []byte) error {
@ -299,7 +287,7 @@ func parseBytesReply(cn *pool.Conn, line []byte) ([]byte, error) {
return nil, err return nil, err
} }
b, err := readN(cn, replyLen+2) b, err := cn.ReadN(replyLen + 2)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -14,13 +14,11 @@ var _ = Describe("Pipelining", func() {
var client *redis.Client var client *redis.Client
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(redisOptions())
Addr: redisAddr, Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
})
}) })
AfterEach(func() { AfterEach(func() {
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
}) })

View File

@ -8,16 +8,15 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"gopkg.in/redis.v3" "gopkg.in/redis.v3"
"gopkg.in/redis.v3/internal/pool"
) )
var _ = Describe("pool", func() { var _ = Describe("pool", func() {
var client *redis.Client var client *redis.Client
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(redisOptions())
Addr: redisAddr, Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
PoolSize: 10,
})
}) })
AfterEach(func() { AfterEach(func() {
@ -25,7 +24,7 @@ var _ = Describe("pool", func() {
}) })
It("should respect max size", func() { It("should respect max size", func() {
perform(1000, func() { perform(1000, func(id int) {
val, err := client.Ping().Result() val, err := client.Ping().Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("PONG")) Expect(val).To(Equal("PONG"))
@ -38,7 +37,7 @@ var _ = Describe("pool", func() {
}) })
It("should respect max on multi", func() { It("should respect max on multi", func() {
perform(1000, func() { perform(1000, func(id int) {
var ping *redis.StatusCmd var ping *redis.StatusCmd
multi := client.Multi() multi := client.Multi()
@ -60,7 +59,7 @@ var _ = Describe("pool", func() {
}) })
It("should respect max on pipelines", func() { It("should respect max on pipelines", func() {
perform(1000, func() { perform(1000, func(id int) {
pipe := client.Pipeline() pipe := client.Pipeline()
ping := pipe.Ping() ping := pipe.Ping()
cmds, err := pipe.Exec() cmds, err := pipe.Exec()
@ -78,16 +77,17 @@ var _ = Describe("pool", func() {
}) })
It("should respect max on pubsub", func() { It("should respect max on pubsub", func() {
perform(10, func() { connPool := client.Pool()
connPool.(*pool.ConnPool).DialLimiter = nil
perform(1000, func(id int) {
pubsub := client.PubSub() pubsub := client.PubSub()
Expect(pubsub.Subscribe()).NotTo(HaveOccurred()) Expect(pubsub.Subscribe()).NotTo(HaveOccurred())
Expect(pubsub.Close()).NotTo(HaveOccurred()) Expect(pubsub.Close()).NotTo(HaveOccurred())
}) })
pool := client.Pool() Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(pool.Len()).To(BeNumerically("<=", 10)) Expect(connPool.Len()).To(BeNumerically("<=", 10))
Expect(pool.FreeLen()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(Equal(pool.FreeLen()))
}) })
It("should remove broken connections", func() { It("should remove broken connections", func() {
@ -108,8 +108,8 @@ var _ = Describe("pool", func() {
Expect(pool.FreeLen()).To(Equal(1)) Expect(pool.FreeLen()).To(Equal(1))
stats := pool.Stats() stats := pool.Stats()
Expect(stats.Requests).To(Equal(uint32(3))) Expect(stats.Requests).To(Equal(uint32(4)))
Expect(stats.Hits).To(Equal(uint32(2))) Expect(stats.Hits).To(Equal(uint32(3)))
Expect(stats.Waits).To(Equal(uint32(0))) Expect(stats.Waits).To(Equal(uint32(0)))
Expect(stats.Timeouts).To(Equal(uint32(0))) Expect(stats.Timeouts).To(Equal(uint32(0)))
}) })
@ -126,8 +126,8 @@ var _ = Describe("pool", func() {
Expect(pool.FreeLen()).To(Equal(1)) Expect(pool.FreeLen()).To(Equal(1))
stats := pool.Stats() stats := pool.Stats()
Expect(stats.Requests).To(Equal(uint32(100))) Expect(stats.Requests).To(Equal(uint32(101)))
Expect(stats.Hits).To(Equal(uint32(99))) Expect(stats.Hits).To(Equal(uint32(100)))
Expect(stats.Waits).To(Equal(uint32(0))) Expect(stats.Waits).To(Equal(uint32(0)))
Expect(stats.Timeouts).To(Equal(uint32(0))) Expect(stats.Timeouts).To(Equal(uint32(0)))
}) })

View File

@ -14,13 +14,9 @@ import (
var _ = Describe("PubSub", func() { var _ = Describe("PubSub", func() {
var client *redis.Client var client *redis.Client
readTimeout := 3 * time.Second
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(redisOptions())
Addr: redisAddr,
ReadTimeout: readTimeout,
})
Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
}) })

213
race_test.go Normal file
View File

@ -0,0 +1,213 @@
package redis_test
import (
"bytes"
"fmt"
"net"
"strconv"
"testing"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"gopkg.in/redis.v3"
"gopkg.in/redis.v3/internal/pool"
)
var _ = Describe("races", func() {
var client *redis.Client
var C, N = 10, 1000
if testing.Short() {
C = 4
N = 100
}
BeforeEach(func() {
client = redis.NewClient(redisOptions())
Expect(client.FlushDb().Err()).To(BeNil())
})
AfterEach(func() {
err := client.Close()
Expect(err).NotTo(HaveOccurred())
})
It("should echo", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
msg := fmt.Sprintf("echo %d %d", id, i)
echo, err := client.Echo(msg).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(msg))
}
})
})
It("should incr", func() {
key := "TestIncrFromGoroutines"
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.Incr(key).Err()
Expect(err).NotTo(HaveOccurred())
}
})
val, err := client.Get(key).Int64()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal(int64(C * N)))
})
It("should handle many keys", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.Set(
fmt.Sprintf("keys.key-%d-%d", id, i),
fmt.Sprintf("hello-%d-%d", id, i),
0,
).Err()
Expect(err).NotTo(HaveOccurred())
}
})
keys := client.Keys("keys.*")
Expect(keys.Err()).NotTo(HaveOccurred())
Expect(len(keys.Val())).To(Equal(C * N))
})
It("should handle many keys 2", func() {
perform(C, func(id int) {
keys := []string{"non-existent-key"}
for i := 0; i < N; i++ {
key := fmt.Sprintf("keys.key-%d", i)
keys = append(keys, key)
err := client.Set(key, fmt.Sprintf("hello-%d", i), 0).Err()
Expect(err).NotTo(HaveOccurred())
}
keys = append(keys, "non-existent-key")
vals, err := client.MGet(keys...).Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(vals)).To(Equal(N + 2))
for i := 0; i < N; i++ {
Expect(vals[i+1]).To(Equal(fmt.Sprintf("hello-%d", i)))
}
Expect(vals[0]).To(BeNil())
Expect(vals[N+1]).To(BeNil())
})
})
It("should handle big vals in Get", func() {
bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb
err := client.Set("key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection.
Expect(client.Close()).To(BeNil())
client = redis.NewClient(redisOptions())
perform(C, func(id int) {
for i := 0; i < N; i++ {
got, err := client.Get("key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal))
}
})
})
It("should handle big vals in Set", func() {
bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb
perform(C, func(id int) {
for i := 0; i < N; i++ {
err := client.Set("key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
got, err := client.Get("key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal))
}
})
})
It("should PubSub", func() {
connPool := client.Pool()
connPool.(*pool.ConnPool).DialLimiter = nil
perform(C, func(id int) {
for i := 0; i < N; i++ {
pubsub, err := client.Subscribe(fmt.Sprintf("mychannel%d", id))
Expect(err).NotTo(HaveOccurred())
go func() {
defer GinkgoRecover()
time.Sleep(time.Millisecond)
err := pubsub.Close()
Expect(err).NotTo(HaveOccurred())
}()
_, err = pubsub.ReceiveMessage()
Expect(err.Error()).To(ContainSubstring("closed"))
val := "echo" + strconv.Itoa(i)
echo, err := client.Echo(val).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(val))
}
})
Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(connPool.Len()).To(BeNumerically("<=", 10))
})
It("should select db", func() {
err := client.Set("db", 1, 0).Err()
Expect(err).NotTo(HaveOccurred())
perform(C, func(id int) {
opt := redisOptions()
opt.DB = int64(id)
client := redis.NewClient(opt)
for i := 0; i < N; i++ {
err := client.Set("db", id, 0).Err()
Expect(err).NotTo(HaveOccurred())
n, err := client.Get("db").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(id)))
}
err := client.Close()
Expect(err).NotTo(HaveOccurred())
})
n, err := client.Get("db").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
})
It("should select DB with read timeout", func() {
perform(C, func(id int) {
opt := redisOptions()
opt.DB = int64(id)
opt.ReadTimeout = time.Nanosecond
client := redis.NewClient(opt)
perform(C, func(id int) {
err := client.Ping().Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
})
err := client.Close()
Expect(err).NotTo(HaveOccurred())
})
})
})

View File

@ -1,8 +1,8 @@
package redis_test package redis_test
import ( import (
"bytes"
"net" "net"
"time"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -14,9 +14,8 @@ var _ = Describe("Client", func() {
var client *redis.Client var client *redis.Client
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(&redis.Options{ client = redis.NewClient(redisOptions())
Addr: redisAddr, Expect(client.FlushDb().Err()).To(BeNil())
})
}) })
AfterEach(func() { AfterEach(func() {
@ -24,7 +23,7 @@ var _ = Describe("Client", func() {
}) })
It("should Stringer", func() { It("should Stringer", func() {
Expect(client.String()).To(Equal("Redis<:6380 db:0>")) Expect(client.String()).To(Equal("Redis<:6380 db:15>"))
}) })
It("should ping", func() { It("should ping", func() {
@ -39,6 +38,7 @@ var _ = Describe("Client", func() {
It("should support custom dialers", func() { It("should support custom dialers", func() {
custom := redis.NewClient(&redis.Options{ custom := redis.NewClient(&redis.Options{
Addr: ":1234",
Dialer: func() (net.Conn, error) { Dialer: func() (net.Conn, error) {
return net.Dial("tcp", redisAddr) return net.Dial("tcp", redisAddr)
}, },
@ -107,45 +107,30 @@ var _ = Describe("Client", func() {
Expect(pipeline.Close()).NotTo(HaveOccurred()) Expect(pipeline.Close()).NotTo(HaveOccurred())
}) })
It("should support idle-timeouts", func() { It("should select DB", func() {
idle := redis.NewClient(&redis.Options{ db2 := redis.NewClient(&redis.Options{
Addr: redisAddr, Addr: redisAddr,
IdleTimeout: 100 * time.Microsecond, DB: 2,
}) })
defer idle.Close() Expect(db2.FlushDb().Err()).NotTo(HaveOccurred())
Expect(db2.Get("db").Err()).To(Equal(redis.Nil))
Expect(db2.Set("db", 2, 0).Err()).NotTo(HaveOccurred())
Expect(idle.Ping().Err()).NotTo(HaveOccurred()) n, err := db2.Get("db").Int64()
time.Sleep(time.Millisecond) Expect(err).NotTo(HaveOccurred())
Expect(idle.Ping().Err()).NotTo(HaveOccurred()) Expect(n).To(Equal(int64(2)))
Expect(client.Get("db").Err()).To(Equal(redis.Nil))
Expect(db2.FlushDb().Err()).NotTo(HaveOccurred())
Expect(db2.Close()).NotTo(HaveOccurred())
}) })
It("should support DB selection", func() { It("should process custom commands", func() {
db1 := redis.NewClient(&redis.Options{ cmd := redis.NewCmd("PING")
Addr: redisAddr, client.Process(cmd)
DB: 1, Expect(cmd.Err()).NotTo(HaveOccurred())
}) Expect(cmd.Val()).To(Equal("PONG"))
defer db1.Close()
Expect(db1.Get("key").Err()).To(Equal(redis.Nil))
Expect(db1.Set("key", "value", 0).Err()).NotTo(HaveOccurred())
Expect(client.Get("key").Err()).To(Equal(redis.Nil))
Expect(db1.Get("key").Val()).To(Equal("value"))
Expect(db1.FlushDb().Err()).NotTo(HaveOccurred())
})
It("should support DB selection with read timeout (issue #135)", func() {
for i := 0; i < 100; i++ {
db1 := redis.NewClient(&redis.Options{
Addr: redisAddr,
DB: 1,
ReadTimeout: time.Nanosecond,
})
err := db1.Ping().Err()
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
}
}) })
It("should retry command on network error", func() { It("should retry command on network error", func() {
@ -168,7 +153,7 @@ var _ = Describe("Client", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("should maintain conn.UsedAt", func() { It("should update conn.UsedAt on read/write", func() {
cn, err := client.Pool().Get() cn, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero()) Expect(cn.UsedAt).NotTo(BeZero())
@ -185,4 +170,31 @@ var _ = Describe("Client", func() {
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) Expect(cn.UsedAt.After(createdAt)).To(BeTrue())
}) })
It("should escape special chars", func() {
set := client.Set("key", "hello1\r\nhello2\r\n", 0)
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
get := client.Get("key")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n"))
})
It("should handle big vals", func() {
bigVal := string(bytes.Repeat([]byte{'*'}, 1<<17)) // 128kb
err := client.Set("key", bigVal, 0).Err()
Expect(err).NotTo(HaveOccurred())
// Reconnect to get new connection.
Expect(client.Close()).To(BeNil())
client = redis.NewClient(redisOptions())
got, err := client.Get("key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(len(got)).To(Equal(len(bigVal)))
Expect(got).To(Equal(bigVal))
})
}) })

View File

@ -206,7 +206,7 @@ func (d *sentinelFailover) MasterAddr() (string, error) {
func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) { func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) {
d.discoverSentinels(sentinel) d.discoverSentinels(sentinel)
d.sentinel = sentinel d.sentinel = sentinel
go d.listen() go d.listen(sentinel)
} }
func (d *sentinelFailover) resetSentinel() error { func (d *sentinelFailover) resetSentinel() error {
@ -278,11 +278,11 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
} }
} }
func (d *sentinelFailover) listen() { func (d *sentinelFailover) listen(sentinel *sentinelClient) {
var pubsub *PubSub var pubsub *PubSub
for { for {
if pubsub == nil { if pubsub == nil {
pubsub = d.sentinel.PubSub() pubsub = sentinel.PubSub()
if err := pubsub.Subscribe("+switch-master"); err != nil { if err := pubsub.Subscribe("+switch-master"); err != nil {
Logger.Printf("sentinel: Subscribe failed: %s", err) Logger.Printf("sentinel: Subscribe failed: %s", err)
d.resetSentinel() d.resetSentinel()

View File

@ -15,6 +15,7 @@ var _ = Describe("Sentinel", func() {
MasterName: sentinelName, MasterName: sentinelName,
SentinelAddrs: []string{":" + sentinelPort}, SentinelAddrs: []string{":" + sentinelPort},
}) })
Expect(client.FlushDb().Err()).NotTo(HaveOccurred())
}) })
AfterEach(func() { AfterEach(func() {