Merge pull request #42 from go-redis/fix/rewrite_rate_limiter

Rewrite rate limiter.
This commit is contained in:
Vladimir Mihailenco 2014-10-03 10:05:34 +03:00
commit 48ff0f00a2
3 changed files with 59 additions and 38 deletions

View File

@ -115,12 +115,7 @@ func newConnPool(dial func() (*conn, error), opt *options) *connPool {
}
func (p *connPool) new() (*conn, error) {
select {
case _, ok := <-p.rl.C:
if !ok {
return nil, errClosed
}
default:
if !p.rl.Check() {
return nil, errRateLimited
}
return p.dial()
@ -263,7 +258,7 @@ func (p *connPool) Close() error {
return nil
}
p.closed = true
close(p.rl.C)
p.rl.Close()
var retErr error
for {
e := p.conns.Front()

View File

@ -1,33 +1,52 @@
package redis
import (
"sync/atomic"
"time"
)
type rateLimiter struct {
C chan struct{}
v int64
_closed int64
}
func newRateLimiter(limit time.Duration, chanSize int) *rateLimiter {
func newRateLimiter(limit time.Duration, bucketSize int) *rateLimiter {
rl := &rateLimiter{
C: make(chan struct{}, chanSize),
v: int64(bucketSize),
}
for i := 0; i < chanSize; i++ {
rl.C <- struct{}{}
}
go rl.loop(limit)
go rl.loop(limit, int64(bucketSize))
return rl
}
func (rl *rateLimiter) loop(limit time.Duration) {
defer func() {
recover()
}()
func (rl *rateLimiter) loop(limit time.Duration, bucketSize int64) {
for {
select {
case rl.C <- struct{}{}:
default:
if rl.closed() {
break
}
if v := atomic.LoadInt64(&rl.v); v < bucketSize {
atomic.AddInt64(&rl.v, 1)
}
time.Sleep(limit)
}
}
func (rl *rateLimiter) Check() bool {
for {
if v := atomic.LoadInt64(&rl.v); v > 0 {
if atomic.CompareAndSwapInt64(&rl.v, v, v-1) {
return true
}
}
return false
}
}
func (rl *rateLimiter) Close() error {
atomic.StoreInt64(&rl._closed, 1)
return nil
}
func (rl *rateLimiter) closed() bool {
return atomic.LoadInt64(&rl._closed) == 1
}

View File

@ -2833,14 +2833,16 @@ func (t *RedisTest) transactionalIncr(c *C) ([]redis.Cmder, error) {
}
func (t *RedisTest) TestWatchUnwatch(c *C) {
const N = 10000
var n = 10000
if testing.Short() {
n = 1000
}
set := t.client.Set("key", "0")
c.Assert(set.Err(), IsNil)
c.Assert(set.Val(), Equals, "OK")
wg := &sync.WaitGroup{}
for i := 0; i < N; i++ {
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
@ -2858,19 +2860,22 @@ func (t *RedisTest) TestWatchUnwatch(c *C) {
}
wg.Wait()
get := t.client.Get("key")
c.Assert(get.Err(), IsNil)
c.Assert(get.Val(), Equals, strconv.FormatInt(N, 10))
val, err := t.client.Get("key").Int64()
c.Assert(err, IsNil)
c.Assert(val, Equals, int64(n))
}
//------------------------------------------------------------------------------
func (t *RedisTest) TestRaceEcho(c *C) {
const N = 10000
var n = 10000
if testing.Short() {
n = 1000
}
wg := &sync.WaitGroup{}
wg.Add(N)
for i := 0; i < N; i++ {
wg.Add(n)
for i := 0; i < n; i++ {
go func(i int) {
msg := "echo" + strconv.Itoa(i)
echo := t.client.Echo(msg)
@ -2883,14 +2888,16 @@ func (t *RedisTest) TestRaceEcho(c *C) {
}
func (t *RedisTest) TestRaceIncr(c *C) {
const N = 10000
key := "TestIncrFromGoroutines"
var n = 10000
if testing.Short() {
n = 1000
}
wg := &sync.WaitGroup{}
wg.Add(N)
for i := int64(0); i < N; i++ {
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
incr := t.client.Incr(key)
incr := t.client.Incr("TestRaceIncr")
if err := incr.Err(); err != nil {
panic(err)
}
@ -2899,9 +2906,9 @@ func (t *RedisTest) TestRaceIncr(c *C) {
}
wg.Wait()
get := t.client.Get(key)
c.Assert(get.Err(), IsNil)
c.Assert(get.Val(), Equals, strconv.Itoa(N))
val, err := t.client.Get("TestRaceIncr").Result()
c.Assert(err, IsNil)
c.Assert(val, Equals, strconv.Itoa(n))
}
//------------------------------------------------------------------------------