Merge pull request #42 from go-redis/fix/rewrite_rate_limiter

Rewrite rate limiter.
This commit is contained in:
Vladimir Mihailenco 2014-10-03 10:05:34 +03:00
commit 48ff0f00a2
3 changed files with 59 additions and 38 deletions

View File

@ -115,12 +115,7 @@ func newConnPool(dial func() (*conn, error), opt *options) *connPool {
} }
func (p *connPool) new() (*conn, error) { func (p *connPool) new() (*conn, error) {
select { if !p.rl.Check() {
case _, ok := <-p.rl.C:
if !ok {
return nil, errClosed
}
default:
return nil, errRateLimited return nil, errRateLimited
} }
return p.dial() return p.dial()
@ -263,7 +258,7 @@ func (p *connPool) Close() error {
return nil return nil
} }
p.closed = true p.closed = true
close(p.rl.C) p.rl.Close()
var retErr error var retErr error
for { for {
e := p.conns.Front() e := p.conns.Front()

View File

@ -1,33 +1,52 @@
package redis package redis
import ( import (
"sync/atomic"
"time" "time"
) )
type rateLimiter struct { type rateLimiter struct {
C chan struct{} v int64
_closed int64
} }
func newRateLimiter(limit time.Duration, chanSize int) *rateLimiter { func newRateLimiter(limit time.Duration, bucketSize int) *rateLimiter {
rl := &rateLimiter{ rl := &rateLimiter{
C: make(chan struct{}, chanSize), v: int64(bucketSize),
} }
for i := 0; i < chanSize; i++ { go rl.loop(limit, int64(bucketSize))
rl.C <- struct{}{}
}
go rl.loop(limit)
return rl return rl
} }
func (rl *rateLimiter) loop(limit time.Duration) { func (rl *rateLimiter) loop(limit time.Duration, bucketSize int64) {
defer func() {
recover()
}()
for { for {
select { if rl.closed() {
case rl.C <- struct{}{}: break
default: }
if v := atomic.LoadInt64(&rl.v); v < bucketSize {
atomic.AddInt64(&rl.v, 1)
} }
time.Sleep(limit) time.Sleep(limit)
} }
} }
func (rl *rateLimiter) Check() bool {
for {
if v := atomic.LoadInt64(&rl.v); v > 0 {
if atomic.CompareAndSwapInt64(&rl.v, v, v-1) {
return true
}
}
return false
}
}
func (rl *rateLimiter) Close() error {
atomic.StoreInt64(&rl._closed, 1)
return nil
}
func (rl *rateLimiter) closed() bool {
return atomic.LoadInt64(&rl._closed) == 1
}

View File

@ -2833,14 +2833,16 @@ func (t *RedisTest) transactionalIncr(c *C) ([]redis.Cmder, error) {
} }
func (t *RedisTest) TestWatchUnwatch(c *C) { func (t *RedisTest) TestWatchUnwatch(c *C) {
const N = 10000 var n = 10000
if testing.Short() {
n = 1000
}
set := t.client.Set("key", "0") set := t.client.Set("key", "0")
c.Assert(set.Err(), IsNil) c.Assert(set.Err(), IsNil)
c.Assert(set.Val(), Equals, "OK")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
for i := 0; i < N; i++ { for i := 0; i < n; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
@ -2858,19 +2860,22 @@ func (t *RedisTest) TestWatchUnwatch(c *C) {
} }
wg.Wait() wg.Wait()
get := t.client.Get("key") val, err := t.client.Get("key").Int64()
c.Assert(get.Err(), IsNil) c.Assert(err, IsNil)
c.Assert(get.Val(), Equals, strconv.FormatInt(N, 10)) c.Assert(val, Equals, int64(n))
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (t *RedisTest) TestRaceEcho(c *C) { func (t *RedisTest) TestRaceEcho(c *C) {
const N = 10000 var n = 10000
if testing.Short() {
n = 1000
}
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(N) wg.Add(n)
for i := 0; i < N; i++ { for i := 0; i < n; i++ {
go func(i int) { go func(i int) {
msg := "echo" + strconv.Itoa(i) msg := "echo" + strconv.Itoa(i)
echo := t.client.Echo(msg) echo := t.client.Echo(msg)
@ -2883,14 +2888,16 @@ func (t *RedisTest) TestRaceEcho(c *C) {
} }
func (t *RedisTest) TestRaceIncr(c *C) { func (t *RedisTest) TestRaceIncr(c *C) {
const N = 10000 var n = 10000
key := "TestIncrFromGoroutines" if testing.Short() {
n = 1000
}
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(N) wg.Add(n)
for i := int64(0); i < N; i++ { for i := 0; i < n; i++ {
go func() { go func() {
incr := t.client.Incr(key) incr := t.client.Incr("TestRaceIncr")
if err := incr.Err(); err != nil { if err := incr.Err(); err != nil {
panic(err) panic(err)
} }
@ -2899,9 +2906,9 @@ func (t *RedisTest) TestRaceIncr(c *C) {
} }
wg.Wait() wg.Wait()
get := t.client.Get(key) val, err := t.client.Get("TestRaceIncr").Result()
c.Assert(get.Err(), IsNil) c.Assert(err, IsNil)
c.Assert(get.Val(), Equals, strconv.Itoa(N)) c.Assert(val, Equals, strconv.Itoa(n))
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------