diff --git a/pubsub.go b/pubsub.go index 74ac51c1..4a5c65f5 100644 --- a/pubsub.go +++ b/pubsub.go @@ -88,19 +88,19 @@ func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) return writeCmd(cn, cmd) } -func (c *PubSub) putConn(cn *pool.Conn, err error) { - if !internal.IsBadConn(err, true) { - return - } - +func (c *PubSub) releaseConn(cn *pool.Conn, err error) { c.mu.Lock() - if c.cn == cn { - _ = c.releaseConn() - } + c._releaseConn(cn, err) c.mu.Unlock() } -func (c *PubSub) releaseConn() error { +func (c *PubSub) _releaseConn(cn *pool.Conn, err error) { + if internal.IsBadConn(err, true) && c.cn == cn { + _ = c.closeTheCn() + } +} + +func (c *PubSub) closeTheCn() error { err := c.closeConn(c.cn) c.cn = nil return err @@ -116,7 +116,7 @@ func (c *PubSub) Close() error { c.closed = true if c.cn != nil { - return c.releaseConn() + return c.closeTheCn() } return nil } @@ -168,7 +168,7 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { } err = c._subscribe(cn, redisCmd, channels...) - c.putConn(cn, err) + c._releaseConn(cn, err) return err } @@ -186,7 +186,7 @@ func (c *PubSub) Ping(payload ...string) error { cn.SetWriteTimeout(c.opt.WriteTimeout) err = writeCmd(cn, cmd) - c.putConn(cn, err) + c.releaseConn(cn, err) return err } @@ -279,7 +279,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { cn.SetReadTimeout(timeout) err = c.cmd.readReply(cn) - c.putConn(cn, err) + c.releaseConn(cn, err) if err != nil { return nil, err } diff --git a/pubsub_test.go b/pubsub_test.go index 3cb9627b..d44b1dd8 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -294,6 +294,22 @@ var _ = Describe("PubSub", func() { Expect(stats.Hits).To(Equal(uint32(1))) }) + It("returns an error when subscribe fails", func() { + pubsub := client.Subscribe() + defer pubsub.Close() + + pubsub.SetNetConn(&badConn{ + readErr: io.EOF, + writeErr: io.EOF, + }) + + err := pubsub.Subscribe("mychannel") + Expect(err).To(MatchError("EOF")) + + err = pubsub.Subscribe("mychannel") + Expect(err).NotTo(HaveOccurred()) + }) + expectReceiveMessageOnError := func(pubsub *redis.PubSub) { pubsub.SetNetConn(&badConn{ readErr: io.EOF,