mirror of https://github.com/go-redis/redis.git
Merge pull request #42 from go-redis/fix/rewrite_rate_limiter
Rewrite rate limiter.
This commit is contained in:
commit
48ff0f00a2
9
pool.go
9
pool.go
|
@ -115,12 +115,7 @@ func newConnPool(dial func() (*conn, error), opt *options) *connPool {
|
|||
}
|
||||
|
||||
func (p *connPool) new() (*conn, error) {
|
||||
select {
|
||||
case _, ok := <-p.rl.C:
|
||||
if !ok {
|
||||
return nil, errClosed
|
||||
}
|
||||
default:
|
||||
if !p.rl.Check() {
|
||||
return nil, errRateLimited
|
||||
}
|
||||
return p.dial()
|
||||
|
@ -263,7 +258,7 @@ func (p *connPool) Close() error {
|
|||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
close(p.rl.C)
|
||||
p.rl.Close()
|
||||
var retErr error
|
||||
for {
|
||||
e := p.conns.Front()
|
||||
|
|
|
@ -1,33 +1,52 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rateLimiter struct {
|
||||
C chan struct{}
|
||||
v int64
|
||||
|
||||
_closed int64
|
||||
}
|
||||
|
||||
func newRateLimiter(limit time.Duration, chanSize int) *rateLimiter {
|
||||
func newRateLimiter(limit time.Duration, bucketSize int) *rateLimiter {
|
||||
rl := &rateLimiter{
|
||||
C: make(chan struct{}, chanSize),
|
||||
v: int64(bucketSize),
|
||||
}
|
||||
for i := 0; i < chanSize; i++ {
|
||||
rl.C <- struct{}{}
|
||||
}
|
||||
go rl.loop(limit)
|
||||
go rl.loop(limit, int64(bucketSize))
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) loop(limit time.Duration) {
|
||||
defer func() {
|
||||
recover()
|
||||
}()
|
||||
func (rl *rateLimiter) loop(limit time.Duration, bucketSize int64) {
|
||||
for {
|
||||
select {
|
||||
case rl.C <- struct{}{}:
|
||||
default:
|
||||
if rl.closed() {
|
||||
break
|
||||
}
|
||||
if v := atomic.LoadInt64(&rl.v); v < bucketSize {
|
||||
atomic.AddInt64(&rl.v, 1)
|
||||
}
|
||||
time.Sleep(limit)
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) Check() bool {
|
||||
for {
|
||||
if v := atomic.LoadInt64(&rl.v); v > 0 {
|
||||
if atomic.CompareAndSwapInt64(&rl.v, v, v-1) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) Close() error {
|
||||
atomic.StoreInt64(&rl._closed, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) closed() bool {
|
||||
return atomic.LoadInt64(&rl._closed) == 1
|
||||
}
|
||||
|
|
|
@ -2833,14 +2833,16 @@ func (t *RedisTest) transactionalIncr(c *C) ([]redis.Cmder, error) {
|
|||
}
|
||||
|
||||
func (t *RedisTest) TestWatchUnwatch(c *C) {
|
||||
const N = 10000
|
||||
var n = 10000
|
||||
if testing.Short() {
|
||||
n = 1000
|
||||
}
|
||||
|
||||
set := t.client.Set("key", "0")
|
||||
c.Assert(set.Err(), IsNil)
|
||||
c.Assert(set.Val(), Equals, "OK")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < N; i++ {
|
||||
for i := 0; i < n; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
@ -2858,19 +2860,22 @@ func (t *RedisTest) TestWatchUnwatch(c *C) {
|
|||
}
|
||||
wg.Wait()
|
||||
|
||||
get := t.client.Get("key")
|
||||
c.Assert(get.Err(), IsNil)
|
||||
c.Assert(get.Val(), Equals, strconv.FormatInt(N, 10))
|
||||
val, err := t.client.Get("key").Int64()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(val, Equals, int64(n))
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
func (t *RedisTest) TestRaceEcho(c *C) {
|
||||
const N = 10000
|
||||
var n = 10000
|
||||
if testing.Short() {
|
||||
n = 1000
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(N)
|
||||
for i := 0; i < N; i++ {
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func(i int) {
|
||||
msg := "echo" + strconv.Itoa(i)
|
||||
echo := t.client.Echo(msg)
|
||||
|
@ -2883,14 +2888,16 @@ func (t *RedisTest) TestRaceEcho(c *C) {
|
|||
}
|
||||
|
||||
func (t *RedisTest) TestRaceIncr(c *C) {
|
||||
const N = 10000
|
||||
key := "TestIncrFromGoroutines"
|
||||
var n = 10000
|
||||
if testing.Short() {
|
||||
n = 1000
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(N)
|
||||
for i := int64(0); i < N; i++ {
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
incr := t.client.Incr(key)
|
||||
incr := t.client.Incr("TestRaceIncr")
|
||||
if err := incr.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -2899,9 +2906,9 @@ func (t *RedisTest) TestRaceIncr(c *C) {
|
|||
}
|
||||
wg.Wait()
|
||||
|
||||
get := t.client.Get(key)
|
||||
c.Assert(get.Err(), IsNil)
|
||||
c.Assert(get.Val(), Equals, strconv.Itoa(N))
|
||||
val, err := t.client.Get("TestRaceIncr").Result()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(val, Equals, strconv.Itoa(n))
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
|
Loading…
Reference in New Issue