diff --git a/main_test.go b/main_test.go index ed2f7de9..c4b5a597 100644 --- a/main_test.go +++ b/main_test.go @@ -80,6 +80,7 @@ var _ = BeforeSuite(func() { var _ = AfterSuite(func() { Expect(redisMain.Close()).NotTo(HaveOccurred()) + Expect(ringShard1.Close()).NotTo(HaveOccurred()) Expect(ringShard2.Close()).NotTo(HaveOccurred()) diff --git a/multi.go b/multi.go index edd5ce3a..9d87de9a 100644 --- a/multi.go +++ b/multi.go @@ -3,6 +3,7 @@ package redis import ( "errors" "fmt" + "log" ) var errDiscard = errors.New("redis: Discard can be used only inside Exec") @@ -18,7 +19,10 @@ type Multi struct { func (c *Client) Multi() *Multi { multi := &Multi{ - base: &baseClient{opt: c.opt, connPool: newSingleConnPool(c.connPool, true)}, + base: &baseClient{ + opt: c.opt, + connPool: newSingleConnPool(c.connPool, true), + }, } multi.commandable.process = multi.process return multi @@ -34,7 +38,7 @@ func (c *Multi) process(cmd Cmder) { func (c *Multi) Close() error { if err := c.Unwatch().Err(); err != nil { - return err + log.Printf("redis: Unwatch failed: %s", err) } return c.base.Close() } diff --git a/pool.go b/pool.go index 301b9e00..714cbe5b 100644 --- a/pool.go +++ b/pool.go @@ -258,10 +258,12 @@ func (p *connPool) Remove(cn *conn) error { // Replace existing connection with new one and unblock waiter. newcn, err := p.new() if err != nil { + log.Printf("redis: new failed: %s", err) return p.conns.Remove(cn) } + err = p.conns.Replace(cn, newcn) p.freeConns <- newcn - return p.conns.Replace(cn, newcn) + return err } // Len returns total number of connections. @@ -312,14 +314,12 @@ func (p *connPool) reaper() { //------------------------------------------------------------------------------ type singleConnPool struct { - pool pool - - cnMtx sync.Mutex - cn *conn - + pool pool reusable bool + cn *conn closed bool + mx sync.Mutex } func newSingleConnPool(pool pool, reusable bool) *singleConnPool { @@ -330,20 +330,24 @@ func newSingleConnPool(pool pool, reusable bool) *singleConnPool { } func (p *singleConnPool) SetConn(cn *conn) { - p.cnMtx.Lock() + p.mx.Lock() + if p.cn != nil { + panic("p.cn != nil") + } p.cn = cn - p.cnMtx.Unlock() + p.mx.Unlock() } func (p *singleConnPool) First() *conn { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() - return p.cn + p.mx.Lock() + cn := p.cn + p.mx.Unlock() + return cn } func (p *singleConnPool) Get() (*conn, error) { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() + defer p.mx.Unlock() + p.mx.Lock() if p.closed { return nil, errClosed @@ -362,8 +366,8 @@ func (p *singleConnPool) Get() (*conn, error) { } func (p *singleConnPool) Put(cn *conn) error { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() + defer p.mx.Unlock() + p.mx.Lock() if p.cn != cn { panic("p.cn != cn") } @@ -374,8 +378,8 @@ func (p *singleConnPool) Put(cn *conn) error { } func (p *singleConnPool) Remove(cn *conn) error { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() + defer p.mx.Unlock() + p.mx.Lock() if p.cn == nil { panic("p.cn == nil") } @@ -395,8 +399,8 @@ func (p *singleConnPool) remove() error { } func (p *singleConnPool) Len() int { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() + defer p.mx.Unlock() + p.mx.Lock() if p.cn == nil { return 0 } @@ -404,19 +408,19 @@ func (p *singleConnPool) Len() int { } func (p *singleConnPool) FreeLen() int { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() + defer p.mx.Unlock() + p.mx.Lock() if p.cn == nil { - return 0 + return 1 } - return 1 + return 0 } func (p *singleConnPool) Close() error { - defer p.cnMtx.Unlock() - p.cnMtx.Lock() + defer p.mx.Unlock() + p.mx.Lock() if p.closed { - return nil + return errClosed } p.closed = true var err error diff --git a/redis_test.go b/redis_test.go index 4ad44866..8a8663e1 100644 --- a/redis_test.go +++ b/redis_test.go @@ -88,6 +88,24 @@ var _ = Describe("Client", func() { Expect(client.Ping().Err()).NotTo(HaveOccurred()) }) + It("should close pubsub when client is closed", func() { + pubsub := client.PubSub() + Expect(client.Close()).NotTo(HaveOccurred()) + Expect(pubsub.Close()).NotTo(HaveOccurred()) + }) + + It("should close multi when client is closed", func() { + multi := client.Multi() + Expect(client.Close()).NotTo(HaveOccurred()) + Expect(multi.Close()).NotTo(HaveOccurred()) + }) + + It("should close pipeline when client is closed", func() { + pipeline := client.Pipeline() + Expect(client.Close()).NotTo(HaveOccurred()) + Expect(pipeline.Close()).NotTo(HaveOccurred()) + }) + It("should support idle-timeouts", func() { idle := redis.NewClient(&redis.Options{ Addr: redisAddr,