forked from mirror/redis
Rewrite rate limiter.
This commit is contained in:
parent
97695ed316
commit
551257a988
9
pool.go
9
pool.go
|
@ -115,12 +115,7 @@ func newConnPool(dial func() (*conn, error), opt *options) *connPool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *connPool) new() (*conn, error) {
|
func (p *connPool) new() (*conn, error) {
|
||||||
select {
|
if !p.rl.Check() {
|
||||||
case _, ok := <-p.rl.C:
|
|
||||||
if !ok {
|
|
||||||
return nil, errClosed
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, errRateLimited
|
return nil, errRateLimited
|
||||||
}
|
}
|
||||||
return p.dial()
|
return p.dial()
|
||||||
|
@ -263,7 +258,7 @@ func (p *connPool) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
p.closed = true
|
p.closed = true
|
||||||
close(p.rl.C)
|
p.rl.Close()
|
||||||
var retErr error
|
var retErr error
|
||||||
for {
|
for {
|
||||||
e := p.conns.Front()
|
e := p.conns.Front()
|
||||||
|
|
|
@ -1,33 +1,52 @@
|
||||||
package redis
|
package redis
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type rateLimiter struct {
|
type rateLimiter struct {
|
||||||
C chan struct{}
|
v int64
|
||||||
|
|
||||||
|
_closed int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRateLimiter(limit time.Duration, chanSize int) *rateLimiter {
|
func newRateLimiter(limit time.Duration, bucketSize int) *rateLimiter {
|
||||||
rl := &rateLimiter{
|
rl := &rateLimiter{
|
||||||
C: make(chan struct{}, chanSize),
|
v: int64(bucketSize),
|
||||||
}
|
}
|
||||||
for i := 0; i < chanSize; i++ {
|
go rl.loop(limit, int64(bucketSize))
|
||||||
rl.C <- struct{}{}
|
|
||||||
}
|
|
||||||
go rl.loop(limit)
|
|
||||||
return rl
|
return rl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *rateLimiter) loop(limit time.Duration) {
|
func (rl *rateLimiter) loop(limit time.Duration, bucketSize int64) {
|
||||||
defer func() {
|
|
||||||
recover()
|
|
||||||
}()
|
|
||||||
for {
|
for {
|
||||||
select {
|
if rl.closed() {
|
||||||
case rl.C <- struct{}{}:
|
break
|
||||||
default:
|
}
|
||||||
|
if v := atomic.LoadInt64(&rl.v); v < bucketSize {
|
||||||
|
atomic.AddInt64(&rl.v, 1)
|
||||||
}
|
}
|
||||||
time.Sleep(limit)
|
time.Sleep(limit)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rl *rateLimiter) Check() bool {
|
||||||
|
for {
|
||||||
|
if v := atomic.LoadInt64(&rl.v); v > 0 {
|
||||||
|
if atomic.CompareAndSwapInt64(&rl.v, v, v-1) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *rateLimiter) Close() error {
|
||||||
|
atomic.StoreInt64(&rl._closed, 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *rateLimiter) closed() bool {
|
||||||
|
return atomic.LoadInt64(&rl._closed) == 1
|
||||||
|
}
|
||||||
|
|
|
@ -2833,14 +2833,16 @@ func (t *RedisTest) transactionalIncr(c *C) ([]redis.Cmder, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *RedisTest) TestWatchUnwatch(c *C) {
|
func (t *RedisTest) TestWatchUnwatch(c *C) {
|
||||||
const N = 10000
|
var n = 10000
|
||||||
|
if testing.Short() {
|
||||||
|
n = 1000
|
||||||
|
}
|
||||||
|
|
||||||
set := t.client.Set("key", "0")
|
set := t.client.Set("key", "0")
|
||||||
c.Assert(set.Err(), IsNil)
|
c.Assert(set.Err(), IsNil)
|
||||||
c.Assert(set.Val(), Equals, "OK")
|
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
for i := 0; i < N; i++ {
|
for i := 0; i < n; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
@ -2858,19 +2860,22 @@ func (t *RedisTest) TestWatchUnwatch(c *C) {
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
get := t.client.Get("key")
|
val, err := t.client.Get("key").Int64()
|
||||||
c.Assert(get.Err(), IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(get.Val(), Equals, strconv.FormatInt(N, 10))
|
c.Assert(val, Equals, int64(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
func (t *RedisTest) TestRaceEcho(c *C) {
|
func (t *RedisTest) TestRaceEcho(c *C) {
|
||||||
const N = 10000
|
var n = 10000
|
||||||
|
if testing.Short() {
|
||||||
|
n = 1000
|
||||||
|
}
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(N)
|
wg.Add(n)
|
||||||
for i := 0; i < N; i++ {
|
for i := 0; i < n; i++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
msg := "echo" + strconv.Itoa(i)
|
msg := "echo" + strconv.Itoa(i)
|
||||||
echo := t.client.Echo(msg)
|
echo := t.client.Echo(msg)
|
||||||
|
@ -2883,14 +2888,16 @@ func (t *RedisTest) TestRaceEcho(c *C) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *RedisTest) TestRaceIncr(c *C) {
|
func (t *RedisTest) TestRaceIncr(c *C) {
|
||||||
const N = 10000
|
var n = 10000
|
||||||
key := "TestIncrFromGoroutines"
|
if testing.Short() {
|
||||||
|
n = 1000
|
||||||
|
}
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(N)
|
wg.Add(n)
|
||||||
for i := int64(0); i < N; i++ {
|
for i := 0; i < n; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
incr := t.client.Incr(key)
|
incr := t.client.Incr("TestRaceIncr")
|
||||||
if err := incr.Err(); err != nil {
|
if err := incr.Err(); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -2899,9 +2906,9 @@ func (t *RedisTest) TestRaceIncr(c *C) {
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
get := t.client.Get(key)
|
val, err := t.client.Get("TestRaceIncr").Result()
|
||||||
c.Assert(get.Err(), IsNil)
|
c.Assert(err, IsNil)
|
||||||
c.Assert(get.Val(), Equals, strconv.Itoa(N))
|
c.Assert(val, Equals, strconv.Itoa(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
Loading…
Reference in New Issue