diff --git a/pool.go b/pool.go index 8a99e3ad..bca4d196 100644 --- a/pool.go +++ b/pool.go @@ -115,12 +115,7 @@ func newConnPool(dial func() (*conn, error), opt *options) *connPool { } func (p *connPool) new() (*conn, error) { - select { - case _, ok := <-p.rl.C: - if !ok { - return nil, errClosed - } - default: + if !p.rl.Check() { return nil, errRateLimited } return p.dial() @@ -263,7 +258,7 @@ func (p *connPool) Close() error { return nil } p.closed = true - close(p.rl.C) + p.rl.Close() var retErr error for { e := p.conns.Front() diff --git a/rate_limit.go b/rate_limit.go index 2534ddc4..a046648c 100644 --- a/rate_limit.go +++ b/rate_limit.go @@ -1,33 +1,52 @@ package redis import ( + "sync/atomic" "time" ) type rateLimiter struct { - C chan struct{} + v int64 + + _closed int64 } -func newRateLimiter(limit time.Duration, chanSize int) *rateLimiter { +func newRateLimiter(limit time.Duration, bucketSize int) *rateLimiter { rl := &rateLimiter{ - C: make(chan struct{}, chanSize), + v: int64(bucketSize), } - for i := 0; i < chanSize; i++ { - rl.C <- struct{}{} - } - go rl.loop(limit) + go rl.loop(limit, int64(bucketSize)) return rl } -func (rl *rateLimiter) loop(limit time.Duration) { - defer func() { - recover() - }() +func (rl *rateLimiter) loop(limit time.Duration, bucketSize int64) { for { - select { - case rl.C <- struct{}{}: - default: + if rl.closed() { + break + } + if v := atomic.LoadInt64(&rl.v); v < bucketSize { + atomic.AddInt64(&rl.v, 1) } time.Sleep(limit) } } + +func (rl *rateLimiter) Check() bool { + for { + if v := atomic.LoadInt64(&rl.v); v > 0 { + if atomic.CompareAndSwapInt64(&rl.v, v, v-1) { + return true + } + } + return false + } +} + +func (rl *rateLimiter) Close() error { + atomic.StoreInt64(&rl._closed, 1) + return nil +} + +func (rl *rateLimiter) closed() bool { + return atomic.LoadInt64(&rl._closed) == 1 +} diff --git a/redis_test.go b/redis_test.go index b1d58037..877be34e 100644 --- a/redis_test.go +++ b/redis_test.go @@ -2833,14 +2833,16 @@ func (t *RedisTest) transactionalIncr(c *C) ([]redis.Cmder, error) { } func (t *RedisTest) TestWatchUnwatch(c *C) { - const N = 10000 + var n = 10000 + if testing.Short() { + n = 1000 + } set := t.client.Set("key", "0") c.Assert(set.Err(), IsNil) - c.Assert(set.Val(), Equals, "OK") wg := &sync.WaitGroup{} - for i := 0; i < N; i++ { + for i := 0; i < n; i++ { wg.Add(1) go func() { defer wg.Done() @@ -2858,19 +2860,22 @@ func (t *RedisTest) TestWatchUnwatch(c *C) { } wg.Wait() - get := t.client.Get("key") - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, strconv.FormatInt(N, 10)) + val, err := t.client.Get("key").Int64() + c.Assert(err, IsNil) + c.Assert(val, Equals, int64(n)) } //------------------------------------------------------------------------------ func (t *RedisTest) TestRaceEcho(c *C) { - const N = 10000 + var n = 10000 + if testing.Short() { + n = 1000 + } wg := &sync.WaitGroup{} - wg.Add(N) - for i := 0; i < N; i++ { + wg.Add(n) + for i := 0; i < n; i++ { go func(i int) { msg := "echo" + strconv.Itoa(i) echo := t.client.Echo(msg) @@ -2883,14 +2888,16 @@ func (t *RedisTest) TestRaceEcho(c *C) { } func (t *RedisTest) TestRaceIncr(c *C) { - const N = 10000 - key := "TestIncrFromGoroutines" + var n = 10000 + if testing.Short() { + n = 1000 + } wg := &sync.WaitGroup{} - wg.Add(N) - for i := int64(0); i < N; i++ { + wg.Add(n) + for i := 0; i < n; i++ { go func() { - incr := t.client.Incr(key) + incr := t.client.Incr("TestRaceIncr") if err := incr.Err(); err != nil { panic(err) } @@ -2899,9 +2906,9 @@ func (t *RedisTest) TestRaceIncr(c *C) { } wg.Wait() - get := t.client.Get(key) - c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, strconv.Itoa(N)) + val, err := t.client.Get("TestRaceIncr").Result() + c.Assert(err, IsNil) + c.Assert(val, Equals, strconv.Itoa(n)) } //------------------------------------------------------------------------------