From 0d94a7bc885f25fba5a2fb2e8611477ceeb476e8 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Thu, 29 Jun 2017 17:05:08 +0300 Subject: [PATCH] Fix race in PubSub --- internal/pool/pool_test.go | 26 ----------------------- main_test.go | 3 +-- pubsub.go | 42 +++++++++++++++++++++----------------- redis.go | 5 +---- 4 files changed, 25 insertions(+), 51 deletions(-) diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index f86327a..68c9a1b 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -238,30 +238,4 @@ var _ = Describe("race", func() { } }) }) - - It("does not happen on Get and PopFree", func() { - connPool = pool.NewConnPool( - &pool.Options{ - Dialer: dummyDialer, - PoolSize: 10, - PoolTimeout: time.Minute, - IdleTimeout: time.Second, - IdleCheckFrequency: time.Millisecond, - }) - - perform(C, func(id int) { - for i := 0; i < N; i++ { - cn, _, err := connPool.Get() - Expect(err).NotTo(HaveOccurred()) - if err == nil { - Expect(connPool.Put(cn)).NotTo(HaveOccurred()) - } - - cn = connPool.PopFree() - if cn != nil { - Expect(connPool.Put(cn)).NotTo(HaveOccurred()) - } - } - }) - }) }) diff --git a/main_test.go b/main_test.go index 64f25d9..30f09c6 100644 --- a/main_test.go +++ b/main_test.go @@ -3,7 +3,6 @@ package redis_test import ( "errors" "fmt" - "log" "net" "os" "os/exec" @@ -52,7 +51,7 @@ var cluster = &clusterScenario{ } func init() { - redis.SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)) + //redis.SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)) } var _ = BeforeSuite(func() { diff --git a/pubsub.go b/pubsub.go index 7eba98b..4872b4e 100644 --- a/pubsub.go +++ b/pubsub.go @@ -28,37 +28,41 @@ type PubSub struct { cmd *Cmd } -func (c *PubSub) conn() (*pool.Conn, bool, error) { +func (c *PubSub) conn() (*pool.Conn, error) { c.mu.Lock() - defer c.mu.Unlock() + cn, err := c._conn() + c.mu.Unlock() + return cn, err +} +func (c *PubSub) _conn() (*pool.Conn, error) { if c.closed { - return nil, false, pool.ErrClosed + return nil, pool.ErrClosed } if c.cn != nil { - return c.cn, false, nil + return c.cn, nil } cn, err := c.base.connPool.NewConn() if err != nil { - return nil, false, err + return nil, err } if !cn.Inited { if err := c.base.initConn(cn); err != nil { _ = c.base.connPool.CloseConn(cn) - return nil, false, err + return nil, err } } if err := c.resubscribe(cn); err != nil { _ = c.base.connPool.CloseConn(cn) - return nil, false, err + return nil, err } c.cn = cn - return cn, true, nil + return cn, nil } func (c *PubSub) resubscribe(cn *pool.Conn) error { @@ -125,48 +129,48 @@ func (c *PubSub) Close() error { // empty subscription if there are no channels. func (c *PubSub) Subscribe(channels ...string) error { c.mu.Lock() + err := c.subscribe("subscribe", channels...) c.channels = appendIfNotExists(c.channels, channels...) c.mu.Unlock() - return c.subscribe("subscribe", channels...) + return err } // Subscribes the client to the given patterns. It returns // empty subscription if there are no patterns. func (c *PubSub) PSubscribe(patterns ...string) error { c.mu.Lock() + err := c.subscribe("psubscribe", patterns...) c.patterns = appendIfNotExists(c.patterns, patterns...) c.mu.Unlock() - return c.subscribe("psubscribe", patterns...) + return err } // Unsubscribes the client from the given channels, or from all of // them if none is given. func (c *PubSub) Unsubscribe(channels ...string) error { c.mu.Lock() + err := c.subscribe("unsubscribe", channels...) c.channels = remove(c.channels, channels...) c.mu.Unlock() - return c.subscribe("unsubscribe", channels...) + return err } // Unsubscribes the client from the given patterns, or from all of // them if none is given. func (c *PubSub) PUnsubscribe(patterns ...string) error { c.mu.Lock() + err := c.subscribe("punsubscribe", patterns...) c.patterns = remove(c.patterns, patterns...) c.mu.Unlock() - return c.subscribe("punsubscribe", patterns...) + return err } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { - cn, isNew, err := c.conn() + cn, err := c._conn() if err != nil { return err } - if isNew { - return nil - } - err = c._subscribe(cn, redisCmd, channels...) c.putConn(cn, err) return err @@ -179,7 +183,7 @@ func (c *PubSub) Ping(payload ...string) error { } cmd := NewCmd(args...) - cn, _, err := c.conn() + cn, err := c.conn() if err != nil { return err } @@ -272,7 +276,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { c.cmd = NewCmd() } - cn, _, err := c.conn() + cn, err := c.conn() if err != nil { return nil, err } diff --git a/redis.go b/redis.go index 1fc25e1..9812daf 100644 --- a/redis.go +++ b/redis.go @@ -387,10 +387,7 @@ func (c *Client) pubSub() *PubSub { func (c *Client) Subscribe(channels ...string) *PubSub { pubsub := c.pubSub() if len(channels) > 0 { - err := pubsub.Subscribe(channels...) - if err != nil { - panic(err) - } + _ = pubsub.Subscribe(channels...) } return pubsub }