forked from mirror/redis
Improved rate-limiter, use ratelimit package
This commit is contained in:
parent
2dc0bd1c0a
commit
e3ba7e7bf6
8
pool.go
8
pool.go
|
@ -9,6 +9,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gopkg.in/bsm/ratelimit.v1"
|
||||
"gopkg.in/bufio.v1"
|
||||
)
|
||||
|
||||
|
@ -102,7 +103,7 @@ func (cn *conn) isIdle(timeout time.Duration) bool {
|
|||
|
||||
type connPool struct {
|
||||
dial func() (*conn, error)
|
||||
rl *rateLimiter
|
||||
rl *ratelimit.RateLimiter
|
||||
|
||||
opt *options
|
||||
conns chan *conn
|
||||
|
@ -116,7 +117,7 @@ type connPool struct {
|
|||
func newConnPool(dial func() (*conn, error), opt *options) *connPool {
|
||||
return &connPool{
|
||||
dial: dial,
|
||||
rl: newRateLimiter(time.Second, 2*opt.PoolSize),
|
||||
rl: ratelimit.New(2*opt.PoolSize, time.Second),
|
||||
|
||||
opt: opt,
|
||||
conns: make(chan *conn, opt.PoolSize),
|
||||
|
@ -160,7 +161,7 @@ func (p *connPool) wait() (*conn, error) {
|
|||
|
||||
// Establish a new connection
|
||||
func (p *connPool) new() (*conn, error) {
|
||||
if !p.rl.Check() {
|
||||
if p.rl.Limit() {
|
||||
err := fmt.Errorf(
|
||||
"redis: you open connections too fast (last error: %v)",
|
||||
p.lastDialErr,
|
||||
|
@ -257,7 +258,6 @@ func (p *connPool) Close() (err error) {
|
|||
if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
p.rl.Close()
|
||||
|
||||
for {
|
||||
if p.Size() < 1 {
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rateLimiter struct {
|
||||
v int64
|
||||
|
||||
_closed int64
|
||||
}
|
||||
|
||||
func newRateLimiter(limit time.Duration, bucketSize int) *rateLimiter {
|
||||
rl := &rateLimiter{
|
||||
v: int64(bucketSize),
|
||||
}
|
||||
go rl.loop(limit, int64(bucketSize))
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) loop(limit time.Duration, bucketSize int64) {
|
||||
for {
|
||||
if rl.closed() {
|
||||
break
|
||||
}
|
||||
if v := atomic.LoadInt64(&rl.v); v < bucketSize {
|
||||
atomic.AddInt64(&rl.v, 1)
|
||||
}
|
||||
time.Sleep(limit)
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) Check() bool {
|
||||
for {
|
||||
if v := atomic.LoadInt64(&rl.v); v > 0 {
|
||||
if atomic.CompareAndSwapInt64(&rl.v, v, v-1) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) Close() error {
|
||||
atomic.StoreInt64(&rl._closed, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) closed() bool {
|
||||
return atomic.LoadInt64(&rl._closed) == 1
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("RateLimiter", func() {
|
||||
var n = 100000
|
||||
if testing.Short() {
|
||||
n = 1000
|
||||
}
|
||||
|
||||
It("should rate limit", func() {
|
||||
rl := newRateLimiter(time.Minute, n)
|
||||
for i := 0; i <= n; i++ {
|
||||
Expect(rl.Check()).To(BeTrue())
|
||||
}
|
||||
Expect(rl.Check()).To(BeFalse())
|
||||
})
|
||||
|
||||
})
|
Loading…
Reference in New Issue