From 3f491f8a8cdc95201bbb6e4125b5e604a5428c32 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 7 Nov 2013 16:20:15 +0200 Subject: [PATCH] Close all connections. --- v2/export_test.go | 9 ++ v2/pool.go | 164 +++++++++++++++------ v2/pubsub.go | 1 + v2/redis.go | 7 +- v2/redis_test.go | 353 ++++++++++++++++++++++------------------------ 5 files changed, 294 insertions(+), 240 deletions(-) create mode 100644 v2/export_test.go diff --git a/v2/export_test.go b/v2/export_test.go new file mode 100644 index 0000000..4ebca2e --- /dev/null +++ b/v2/export_test.go @@ -0,0 +1,9 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package redis + +func (c *baseClient) Pool() pool { + return c.connPool +} diff --git a/v2/pool.go b/v2/pool.go index d87cd05..ecd494a 100644 --- a/v2/pool.go +++ b/v2/pool.go @@ -2,18 +2,25 @@ package redis import ( "container/list" + "errors" "net" "sync" "time" + "github.com/golang/glog" "github.com/vmihailenco/bufio" ) +var ( + errPoolClosed = errors.New("attempt to use closed connection pool") +) + type pool interface { Get() (*conn, bool, error) Put(*conn) error Remove(*conn) error Len() int + Size() int Close() error } @@ -22,17 +29,27 @@ type pool interface { type conn struct { cn net.Conn rd reader + inUse bool usedAt time.Time readTimeout, writeTimeout time.Duration + + elem *list.Element } -func newConn(netcn net.Conn) *conn { - cn := &conn{ - cn: netcn, +func newConnFunc(dial func() (net.Conn, error)) func() (*conn, error) { + return func() (*conn, error) { + netcn, err := dial() + if err != nil { + return nil, err + } + + cn := &conn{ + cn: netcn, + } + cn.rd = bufio.NewReader(cn) + return cn, nil } - cn.rd = bufio.NewReader(cn) - return cn } func (cn *conn) Read(b []byte) (int, error) { @@ -60,22 +77,25 @@ func (cn *conn) Close() error { //------------------------------------------------------------------------------ type connPool struct { - dial func() (net.Conn, error) + New func() (*conn, error) cond *sync.Cond conns *list.List - size, maxSize int - idleTimeout time.Duration + len int + maxSize int + idleTimeout time.Duration + + closed bool } func newConnPool( - dial func() (net.Conn, error), + dial func() (*conn, error), maxSize int, idleTimeout time.Duration, ) *connPool { return &connPool{ - dial: dial, + New: dial, cond: sync.NewCond(&sync.Mutex{}), conns: list.New(), @@ -86,87 +106,129 @@ func newConnPool( } func (p *connPool) Get() (*conn, bool, error) { - defer p.cond.L.Unlock() p.cond.L.Lock() - for p.conns.Len() == 0 && p.size >= p.maxSize { - p.cond.Wait() + if p.closed { + p.cond.L.Unlock() + return nil, false, errPoolClosed } if p.idleTimeout > 0 { for e := p.conns.Front(); e != nil; e = e.Next() { cn := e.Value.(*conn) + if cn.inUse { + break + } if time.Since(cn.usedAt) > p.idleTimeout { - p.conns.Remove(e) + if err := p.Remove(cn); err != nil { + glog.Errorf("Remove failed: %s", err) + } } } } - if p.conns.Len() == 0 { - rw, err := p.dial() + for p.conns.Len() >= p.maxSize && p.len == 0 { + p.cond.Wait() + } + + if p.len > 0 { + elem := p.conns.Front() + cn := elem.Value.(*conn) + if cn.inUse { + panic("pool: precondition failed") + } + cn.inUse = true + p.conns.MoveToBack(elem) + p.len-- + + p.cond.L.Unlock() + return cn, false, nil + } + + if p.conns.Len() < p.maxSize { + cn, err := p.New() if err != nil { + p.cond.L.Unlock() return nil, false, err } - p.size++ - return newConn(rw), true, nil + cn.inUse = true + cn.elem = p.conns.PushBack(cn) + + p.cond.L.Unlock() + return cn, true, nil } - elem := p.conns.Front() - p.conns.Remove(elem) - return elem.Value.(*conn), false, nil + panic("not reached") } func (p *connPool) Put(cn *conn) error { if cn.rd.Buffered() != 0 { panic("redis: attempt to put connection with buffered data") } - p.cond.L.Lock() + if p.closed { + p.cond.L.Unlock() + return errPoolClosed + } + cn.inUse = false cn.usedAt = time.Now() - p.conns.PushFront(cn) + p.conns.MoveToFront(cn.elem) + p.len++ p.cond.Signal() p.cond.L.Unlock() return nil } -func (p *connPool) Remove(cn *conn) error { - var err error +func (p *connPool) Remove(cn *conn) (err error) { + p.cond.L.Lock() + if p.closed { + // Noop, connection is already closed. + p.cond.L.Unlock() + return nil + } if cn != nil { err = cn.Close() } - p.cond.L.Lock() - p.size-- + p.conns.Remove(cn.elem) + cn.elem = nil p.cond.Signal() p.cond.L.Unlock() return err } +// Returns number of idle connections. func (p *connPool) Len() int { + defer p.cond.L.Unlock() + p.cond.L.Lock() + return p.len +} + +// Returns size of the pool. +func (p *connPool) Size() int { defer p.cond.L.Unlock() p.cond.L.Lock() return p.conns.Len() } -func (p *connPool) Size() int { - defer p.cond.L.Unlock() - p.cond.L.Lock() - return p.size -} - func (p *connPool) Close() error { defer p.cond.L.Unlock() p.cond.L.Lock() - - for e := p.conns.Front(); e != nil; e = e.Next() { - if err := e.Value.(*conn).Close(); err != nil { - return err - } + if p.closed { + return nil } - p.conns.Init() - p.size = 0 - - return nil + p.closed = true + var retErr error + for e := p.conns.Front(); e != nil; e = e.Next() { + cn := e.Value.(*conn) + if err := cn.Close(); err != nil { + glog.Errorf("cn.Close failed: %s", err) + retErr = err + } + cn.elem = nil + } + p.conns = nil + return retErr } //------------------------------------------------------------------------------ @@ -195,34 +257,33 @@ func (p *singleConnPool) Get() (*conn, bool, error) { } p.l.RUnlock() - defer p.l.Unlock() p.l.Lock() - cn, isNew, err := p.pool.Get() if err != nil { + p.l.Unlock() return nil, false, err } p.cn = cn - + p.l.Unlock() return cn, isNew, nil } func (p *singleConnPool) Put(cn *conn) error { - defer p.l.Unlock() p.l.Lock() if p.cn != cn { panic("p.cn != cn") } + p.l.Unlock() return nil } func (p *singleConnPool) Remove(cn *conn) error { - defer p.l.Unlock() p.l.Lock() if p.cn != cn { panic("p.cn != cn") } p.cn = nil + p.l.Unlock() return nil } @@ -235,6 +296,15 @@ func (p *singleConnPool) Len() int { return 1 } +func (p *singleConnPool) Size() int { + defer p.l.Unlock() + p.l.Lock() + if p.cn == nil { + return 0 + } + return 1 +} + func (p *singleConnPool) Close() error { defer p.l.Unlock() p.l.Lock() diff --git a/v2/pubsub.go b/v2/pubsub.go index 07bf51a..8d58854 100644 --- a/v2/pubsub.go +++ b/v2/pubsub.go @@ -5,6 +5,7 @@ import ( "time" ) +// Not thread-safe. type PubSub struct { *baseClient } diff --git a/v2/redis.go b/v2/redis.go index 6db08aa..ae75c42 100644 --- a/v2/redis.go +++ b/v2/redis.go @@ -113,7 +113,7 @@ func (c *baseClient) run(cmd Cmder) { } if err := c.writeCmd(cn, cmd); err != nil { - c.removeConn(cn) + c.freeConn(cn, err) cmd.setErr(err) return } @@ -173,10 +173,7 @@ func newClient(opt *Options, dial func() (net.Conn, error)) *Client { baseClient: &baseClient{ opt: opt, - connPool: newConnPool( - dial, opt.getPoolSize(), - opt.IdleTimeout, - ), + connPool: newConnPool(newConnFunc(dial), opt.getPoolSize(), opt.IdleTimeout), }, } } diff --git a/v2/redis_test.go b/v2/redis_test.go index 981cad5..68613f8 100644 --- a/v2/redis_test.go +++ b/v2/redis_test.go @@ -83,147 +83,141 @@ func (t *RedisConnectorTest) TestUnixConnector(c *C) { //------------------------------------------------------------------------------ -// type RedisConnPoolTest struct { -// dialedConns, closedConns int64 +type RedisConnPoolTest struct { + client *redis.Client +} -// client *redis.Client -// } +var _ = Suite(&RedisConnPoolTest{}) -// var _ = Suite(&RedisConnPoolTest{}) +func (t *RedisConnPoolTest) SetUpTest(c *C) { + t.client = redis.NewTCPClient(&redis.Options{ + Addr: redisAddr, + }) +} -// func (t *RedisConnPoolTest) SetUpTest(c *C) { -// if t.client == nil { -// dial := func() (net.Conn, error) { -// t.dialedConns++ -// return net.Dial("tcp", redisAddr) -// } -// close := func(conn net.Conn) error { -// t.closedConns++ -// return nil -// } +func (t *RedisConnPoolTest) TearDownTest(c *C) { + c.Assert(t.client.FlushDb().Err(), IsNil) + c.Assert(t.client.Close(), IsNil) +} -// t.client = (&redis.ClientFactory{ -// Dial: dial, -// Close: close, -// }).New() -// } -// } +func (t *RedisConnPoolTest) TestConnPoolMaxSize(c *C) { + wg := &sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + ping := t.client.Ping() + c.Assert(ping.Err(), IsNil) + c.Assert(ping.Val(), Equals, "PONG") + wg.Done() + }() + } + wg.Wait() -// func (t *RedisConnPoolTest) TearDownTest(c *C) { -// t.resetRedis(c) -// t.resetClient(c) -// } + c.Assert(t.client.Pool().Size(), Equals, 10) + c.Assert(t.client.Pool().Len(), Equals, 10) +} -// func (t *RedisConnPoolTest) resetRedis(c *C) { -// // This is much faster than Flushall. -// c.Assert(t.client.Select(1).Err(), IsNil) -// c.Assert(t.client.FlushDb().Err(), IsNil) -// c.Assert(t.client.Select(0).Err(), IsNil) -// c.Assert(t.client.FlushDb().Err(), IsNil) -// } +func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnPipelineClient(c *C) { + const N = 1000 -// func (t *RedisConnPoolTest) resetClient(c *C) { -// t.client.Close() -// c.Check(t.closedConns, Equals, t.dialedConns) -// t.dialedConns, t.closedConns = 0, 0 -// } + wg := &sync.WaitGroup{} + wg.Add(N) + for i := 0; i < N; i++ { + go func() { + pipeline := t.client.Pipeline() + ping := pipeline.Ping() + cmds, err := pipeline.Exec() + c.Assert(err, IsNil) + c.Assert(cmds, HasLen, 1) + c.Assert(ping.Err(), IsNil) + c.Assert(ping.Val(), Equals, "PONG") -// func (t *RedisConnPoolTest) TestConnPoolMaxSize(c *C) { -// wg := &sync.WaitGroup{} -// for i := 0; i < 1000; i++ { -// wg.Add(1) -// go func() { -// ping := t.client.Ping() -// c.Assert(ping.Err(), IsNil) -// c.Assert(ping.Val(), Equals, "PONG") -// wg.Done() -// }() -// } -// wg.Wait() + c.Assert(pipeline.Close(), IsNil) -// c.Assert(t.client.Close(), IsNil) -// c.Assert(t.dialedConns, Equals, int64(10)) -// c.Assert(t.closedConns, Equals, int64(10)) -// } + wg.Done() + }() + } + wg.Wait() -// func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnPipelineClient(c *C) { -// wg := &sync.WaitGroup{} -// for i := 0; i < 1000; i++ { -// wg.Add(1) -// go func() { -// pipeline, err := t.client.PipelineClient() -// c.Assert(err, IsNil) + c.Assert(t.client.Pool().Size(), Equals, 10) + c.Assert(t.client.Pool().Len(), Equals, 10) +} -// ping := pipeline.Ping() -// cmds, err := pipeline.RunQueued() -// c.Assert(err, IsNil) -// c.Assert(cmds, HasLen, 1) -// c.Assert(ping.Err(), IsNil) -// c.Assert(ping.Val(), Equals, "PONG") +func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnMultiClient(c *C) { + const N = 1000 -// c.Assert(pipeline.Close(), IsNil) + wg := &sync.WaitGroup{} + wg.Add(N) + for i := 0; i < N; i++ { + go func() { + multi := t.client.Multi() + var ping *redis.StatusCmd + cmds, err := multi.Exec(func() { + ping = multi.Ping() + }) + c.Assert(err, IsNil) + c.Assert(cmds, HasLen, 1) + c.Assert(ping.Err(), IsNil) + c.Assert(ping.Val(), Equals, "PONG") -// wg.Done() -// }() -// } -// wg.Wait() + c.Assert(multi.Close(), IsNil) -// c.Assert(t.client.Close(), IsNil) -// c.Assert(t.dialedConns, Equals, int64(10)) -// c.Assert(t.closedConns, Equals, int64(10)) -// } + wg.Done() + }() + } + wg.Wait() -// func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnMultiClient(c *C) { -// wg := &sync.WaitGroup{} -// for i := 0; i < 1000; i++ { -// wg.Add(1) -// go func() { -// multi, err := t.client.MultiClient() -// c.Assert(err, IsNil) + c.Assert(t.client.Pool().Size(), Equals, 10) + c.Assert(t.client.Pool().Len(), Equals, 10) +} -// var ping *redis.StatusCmd -// cmds, err := multi.Exec(func() { -// ping = multi.Ping() -// }) -// c.Assert(err, IsNil) -// c.Assert(cmds, HasLen, 1) -// c.Assert(ping.Err(), IsNil) -// c.Assert(ping.Val(), Equals, "PONG") +func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnPubSub(c *C) { + const N = 1000 -// c.Assert(multi.Close(), IsNil) + wg := &sync.WaitGroup{} + wg.Add(N) + for i := 0; i < N; i++ { + go func() { + pubsub := t.client.PubSub() + c.Assert(pubsub.Subscribe(), IsNil) + c.Assert(pubsub.Close(), IsNil) + wg.Done() + }() + } + wg.Wait() -// wg.Done() -// }() -// } -// wg.Wait() + c.Assert(t.client.Pool().Size(), Equals, 0) + c.Assert(t.client.Pool().Len(), Equals, 0) +} -// c.Assert(t.client.Close(), IsNil) -// c.Assert(t.dialedConns, Equals, int64(10)) -// c.Assert(t.closedConns, Equals, int64(10)) -// } +func (t *RedisConnPoolTest) TestConnPoolRemovesBrokenConn(c *C) { + cn, _, err := t.client.Pool().Get() + c.Assert(err, IsNil) + c.Assert(cn.Close(), IsNil) + c.Assert(t.client.Pool().Put(cn), IsNil) -// func (t *RedisConnPoolTest) TestConnPoolMaxSizeOnPubSub(c *C) { -// wg := &sync.WaitGroup{} -// for i := 0; i < 1000; i++ { -// wg.Add(1) -// go func() { -// pubsub, err := t.client.PubSub() -// c.Assert(err, IsNil) + ping := t.client.Ping() + c.Assert(ping.Err().Error(), Equals, "use of closed network connection") + c.Assert(ping.Val(), Equals, "") -// _, err = pubsub.Subscribe() -// c.Assert(err, IsNil) + ping = t.client.Ping() + c.Assert(ping.Err(), IsNil) + c.Assert(ping.Val(), Equals, "PONG") -// c.Assert(pubsub.Close(), IsNil) + c.Assert(t.client.Pool().Size(), Equals, 1) + c.Assert(t.client.Pool().Len(), Equals, 1) +} -// wg.Done() -// }() -// } -// wg.Wait() +func (t *RedisConnPoolTest) TestConnPoolReusesConn(c *C) { + for i := 0; i < 1000; i++ { + ping := t.client.Ping() + c.Assert(ping.Err(), IsNil) + c.Assert(ping.Val(), Equals, "PONG") + } -// c.Assert(t.client.Close(), IsNil) -// c.Assert(t.dialedConns, Equals, int64(1000)) -// c.Assert(t.closedConns, Equals, int64(1000)) -// } + c.Assert(t.client.Pool().Size(), Equals, 1) + c.Assert(t.client.Pool().Len(), Equals, 1) +} //------------------------------------------------------------------------------ @@ -235,12 +229,14 @@ var _ = Suite(&RedisTest{}) func Test(t *testing.T) { TestingT(t) } +func (t *RedisTest) SetUpSuite(c *C) { + t.client = redis.NewTCPClient(&redis.Options{ + Addr: ":6379", + }) +} + func (t *RedisTest) SetUpTest(c *C) { - if t.client == nil { - t.client = redis.NewTCPClient(&redis.Options{ - Addr: ":6379", - }) - } + t.resetRedis(c) } func (t *RedisTest) TearDownTest(c *C) { @@ -336,33 +332,6 @@ func (t *RedisTest) TestManyKeys2(c *C) { //------------------------------------------------------------------------------ -func (t *RedisTest) TestConnPoolRemovesBrokenConn(c *C) { - c.Skip("fix me") - - conn, err := net.Dial("tcp", redisAddr) - c.Assert(err, IsNil) - c.Assert(conn.Close(), IsNil) - - client := redis.NewTCPClient(&redis.Options{ - Addr: redisAddr, - }) - defer func() { - c.Assert(client.Close(), IsNil) - }() - - // c.Assert(client.ConnPool.Add(redis.NewConn(conn)), IsNil) - - ping := client.Ping() - c.Assert(ping.Err().Error(), Equals, "use of closed network connection") - c.Assert(ping.Val(), Equals, "") - - ping = client.Ping() - c.Assert(ping.Err(), IsNil) - c.Assert(ping.Val(), Equals, "PONG") -} - -//------------------------------------------------------------------------------ - func (t *RedisTest) TestAuth(c *C) { auth := t.client.Auth("password") c.Assert(auth.Err(), ErrorMatches, "ERR Client sent AUTH, but no password is set") @@ -2446,15 +2415,14 @@ func (t *RedisTest) TestPipeline(c *C) { func (t *RedisTest) TestPipelineDiscardQueued(c *C) { pipeline := t.client.Pipeline() - defer func() { - c.Assert(pipeline.Close(), IsNil) - }() pipeline.Get("key") pipeline.Discard() cmds, err := pipeline.Exec() c.Assert(err, IsNil) c.Assert(cmds, HasLen, 0) + + c.Assert(pipeline.Close(), IsNil) } func (t *RedisTest) TestPipelineFunc(c *C) { @@ -2489,19 +2457,18 @@ func (t *RedisTest) TestPipelineRunQueuedOnEmptyQueue(c *C) { c.Assert(cmds, HasLen, 0) } -func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) { +// TODO: make thread safe? +func (t *RedisTest) TestPipelineIncr(c *C) { + const N = 20000 + key := "TestPipelineIncr" + pipeline := t.client.Pipeline() - defer func() { - c.Assert(pipeline.Close(), IsNil) - }() wg := &sync.WaitGroup{} - for i := int64(0); i < 20000; i++ { - wg.Add(1) - go func() { - pipeline.Incr("TestIncrPipeliningFromGoroutinesKey") - wg.Done() - }() + wg.Add(N) + for i := 0; i < N; i++ { + pipeline.Incr(key) + wg.Done() } wg.Wait() @@ -2514,23 +2481,24 @@ func (t *RedisTest) TestPipelineIncrFromGoroutines(c *C) { } } - get := t.client.Get("TestIncrPipeliningFromGoroutinesKey") + get := t.client.Get(key) c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "20000") + c.Assert(get.Val(), Equals, strconv.Itoa(N)) + + c.Assert(pipeline.Close(), IsNil) } -func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) { - pipeline := t.client.Pipeline() - defer func() { - c.Assert(pipeline.Close(), IsNil) - }() +func (t *RedisTest) TestPipelineEcho(c *C) { + const N = 1000 wg := &sync.WaitGroup{} - for i := int64(0); i < 1000; i += 2 { - wg.Add(1) - go func() { - msg1 := "echo" + strconv.FormatInt(i, 10) - msg2 := "echo" + strconv.FormatInt(i+1, 10) + wg.Add(N) + for i := 0; i < N; i++ { + go func(i int) { + pipeline := t.client.Pipeline() + + msg1 := "echo" + strconv.Itoa(i) + msg2 := "echo" + strconv.Itoa(i+1) echo1 := pipeline.Echo(msg1) echo2 := pipeline.Echo(msg2) @@ -2545,8 +2513,10 @@ func (t *RedisTest) TestPipelineEchoFromGoroutines(c *C) { c.Assert(echo2.Err(), IsNil) c.Assert(echo2.Val(), Equals, msg2) + c.Assert(pipeline.Close(), IsNil) + wg.Done() - }() + }(i) } wg.Wait() } @@ -2703,36 +2673,43 @@ func (t *RedisTest) TestWatchUnwatch(c *C) { //------------------------------------------------------------------------------ -func (t *RedisTest) TestSyncEchoFromGoroutines(c *C) { +func (t *RedisTest) TestRaceEcho(c *C) { + const N = 10000 + wg := &sync.WaitGroup{} - for i := int64(0); i < 1000; i++ { - wg.Add(1) - go func() { - msg := "echo" + strconv.FormatInt(i, 10) + wg.Add(N) + for i := 0; i < N; i++ { + go func(i int) { + msg := "echo" + strconv.Itoa(i) echo := t.client.Echo(msg) c.Assert(echo.Err(), IsNil) c.Assert(echo.Val(), Equals, msg) wg.Done() - }() + }(i) } wg.Wait() } -func (t *RedisTest) TestIncrFromGoroutines(c *C) { +func (t *RedisTest) TestRaceIncr(c *C) { + const N = 10000 + key := "TestIncrFromGoroutines" + wg := &sync.WaitGroup{} - for i := int64(0); i < 20000; i++ { - wg.Add(1) + wg.Add(N) + for i := int64(0); i < N; i++ { go func() { - incr := t.client.Incr("TestIncrFromGoroutinesKey") - c.Assert(incr.Err(), IsNil) + incr := t.client.Incr(key) + if err := incr.Err(); err != nil { + panic(err) + } wg.Done() }() } wg.Wait() - get := t.client.Get("TestIncrFromGoroutinesKey") + get := t.client.Get(key) c.Assert(get.Err(), IsNil) - c.Assert(get.Val(), Equals, "20000") + c.Assert(get.Val(), Equals, strconv.Itoa(N)) } //------------------------------------------------------------------------------