diff --git a/pool.go b/pool.go index a6ab257..03bb1a0 100644 --- a/pool.go +++ b/pool.go @@ -415,12 +415,12 @@ func (p *stickyConnPool) put() (err error) { func (p *stickyConnPool) Put(cn *conn) error { defer p.mx.Unlock() p.mx.Lock() - if p.cn != cn { - panic("p.cn != cn") - } if p.closed { return errClosed } + if p.cn != cn { + panic("p.cn != cn") + } return nil } @@ -433,15 +433,15 @@ func (p *stickyConnPool) remove() (err error) { func (p *stickyConnPool) Remove(cn *conn) error { defer p.mx.Unlock() p.mx.Lock() + if p.closed { + return errClosed + } if p.cn == nil { panic("p.cn == nil") } if cn != nil && p.cn != cn { panic("p.cn != cn") } - if p.closed { - return errClosed - } if cn == nil { return p.remove() } else { diff --git a/pubsub_test.go b/pubsub_test.go index 1506a8e..dd24bc6 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -235,8 +235,11 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) defer pubsub.Close() + var wg sync.WaitGroup + wg.Add(1) go func() { defer GinkgoRecover() + defer wg.Done() time.Sleep(readTimeout + 100*time.Millisecond) n, err := client.Publish("mychannel", "hello").Result() @@ -248,6 +251,8 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal("hello")) + + wg.Wait() }) It("should reconnect on ReceiveMessage error", func() { @@ -281,4 +286,24 @@ var _ = Describe("PubSub", func() { wg.Wait() }) + It("should not panic on Close", func() { + pubsub, err := client.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + defer pubsub.Close() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer GinkgoRecover() + + wg.Done() + _, err := pubsub.ReceiveMessage() + Expect(err).To(MatchError("redis: client is closed")) + }() + wg.Wait() + + err = pubsub.Close() + Expect(err).NotTo(HaveOccurred()) + }) + })