diff --git a/multi.go b/multi.go index 0360a76..d76af2f 100644 --- a/multi.go +++ b/multi.go @@ -24,7 +24,9 @@ func (c *Client) Multi() *Multi { } func (c *Multi) Close() error { - c.Unwatch() + if err := c.Unwatch().Err(); err != nil { + return err + } return c.Client.Close() } diff --git a/pool.go b/pool.go index 8f596a4..452562d 100644 --- a/pool.go +++ b/pool.go @@ -12,7 +12,8 @@ import ( ) var ( - errClosed = errors.New("redis: client is closed") + errClosed = errors.New("redis: client is closed") + errRateLimited = errors.New("redis: you open connections too fast") ) var ( @@ -83,7 +84,8 @@ func (cn *conn) Close() error { //------------------------------------------------------------------------------ type connPool struct { - New func() (*conn, error) + dial func() (*conn, error) + rl *rateLimiter cond *sync.Cond conns *list.List @@ -101,7 +103,8 @@ func newConnPool( idleTimeout time.Duration, ) *connPool { return &connPool{ - New: dial, + dial: dial, + rl: newRateLimiter(time.Second, 2*maxSize), cond: sync.NewCond(&sync.Mutex{}), conns: list.New(), @@ -111,6 +114,15 @@ func newConnPool( } } +func (p *connPool) new() (*conn, error) { + select { + case <-p.rl.C: + default: + return nil, errRateLimited + } + return p.dial() +} + func (p *connPool) Get() (*conn, bool, error) { p.cond.L.Lock() @@ -152,7 +164,7 @@ func (p *connPool) Get() (*conn, bool, error) { } if p.conns.Len() < p.maxSize { - cn, err := p.New() + cn, err := p.new() if err != nil { p.cond.L.Unlock() return nil, false, err diff --git a/rate_limit.go b/rate_limit.go new file mode 100644 index 0000000..bd84a24 --- /dev/null +++ b/rate_limit.go @@ -0,0 +1,30 @@ +package redis + +import ( + "time" +) + +type rateLimiter struct { + C chan struct{} +} + +func newRateLimiter(limit time.Duration, chanSize int) *rateLimiter { + rl := &rateLimiter{ + C: make(chan struct{}, chanSize), + } + for i := 0; i < chanSize; i++ { + rl.C <- struct{}{} + } + go rl.loop(limit) + return rl +} + +func (rl *rateLimiter) loop(limit time.Duration) { + for { + select { + case rl.C <- struct{}{}: + default: + } + time.Sleep(limit) + } +} diff --git a/redis_test.go b/redis_test.go index 48386d8..fd68b00 100644 --- a/redis_test.go +++ b/redis_test.go @@ -239,16 +239,16 @@ func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnMultiClient(c *C) { } func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnPubSub(c *C) { - const N = 1000 + const N = 10 wg := &sync.WaitGroup{} wg.Add(N) for i := 0; i < N; i++ { go func() { + defer wg.Done() pubsub := t.client.PubSub() c.Assert(pubsub.Subscribe(), IsNil) c.Assert(pubsub.Close(), IsNil) - wg.Done() }() } wg.Wait() @@ -296,21 +296,11 @@ var _ = Suite(&RedisTest{}) func Test(t *testing.T) { TestingT(t) } -func (t *RedisTest) SetUpSuite(c *C) { +func (t *RedisTest) SetUpTest(c *C) { t.client = redis.NewTCPClient(&redis.Options{ Addr: ":6379", }) -} -func (t *RedisTest) TearDownSuite(c *C) { - c.Assert(t.client.Close(), IsNil) -} - -func (t *RedisTest) SetUpTest(c *C) { - t.resetRedis(c) -} - -func (t *RedisTest) resetRedis(c *C) { // This is much faster than Flushall. c.Assert(t.client.Select(1).Err(), IsNil) c.Assert(t.client.FlushDb().Err(), IsNil) @@ -318,6 +308,10 @@ func (t *RedisTest) resetRedis(c *C) { c.Assert(t.client.FlushDb().Err(), IsNil) } +func (t *RedisTest) TearDownTest(c *C) { + c.Assert(t.client.Close(), IsNil) +} + //------------------------------------------------------------------------------ func (t *RedisTest) TestCmdStringMethod(c *C) { @@ -2787,14 +2781,17 @@ func (t *RedisTest) transactionalIncr(c *C) ([]redis.Cmder, error) { } func (t *RedisTest) TestWatchUnwatch(c *C) { + const N = 10000 + set := t.client.Set("key", "0") c.Assert(set.Err(), IsNil) c.Assert(set.Val(), Equals, "OK") wg := &sync.WaitGroup{} - for i := 0; i < 1000; i++ { + for i := 0; i < N; i++ { wg.Add(1) go func() { + defer wg.Done() for { cmds, err := t.transactionalIncr(c) if err == redis.TxFailedErr { @@ -2805,14 +2802,13 @@ func (t *RedisTest) TestWatchUnwatch(c *C) { c.Assert(cmds[0].Err(), IsNil) break } - wg.Done() }() } wg.Wait() get := t.client.Get("key") c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "1000") + c.Assert(get.Val(), Equals, strconv.FormatInt(N, 10)) } //------------------------------------------------------------------------------