Merge pull request #611 from go-redis/fix/pubsub-deadlock

Fix PubSub.Subscribe deadlock
This commit is contained in:
Vladimir Mihailenco 2017-08-01 14:27:30 +03:00 committed by GitHub
commit a8ee44122a
2 changed files with 29 additions and 13 deletions

View File

@ -88,19 +88,19 @@ func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string)
return writeCmd(cn, cmd) return writeCmd(cn, cmd)
} }
func (c *PubSub) putConn(cn *pool.Conn, err error) { func (c *PubSub) releaseConn(cn *pool.Conn, err error) {
if !internal.IsBadConn(err, true) {
return
}
c.mu.Lock() c.mu.Lock()
if c.cn == cn { c._releaseConn(cn, err)
_ = c.releaseConn()
}
c.mu.Unlock() c.mu.Unlock()
} }
func (c *PubSub) releaseConn() error { func (c *PubSub) _releaseConn(cn *pool.Conn, err error) {
if internal.IsBadConn(err, true) && c.cn == cn {
_ = c.closeTheCn()
}
}
func (c *PubSub) closeTheCn() error {
err := c.closeConn(c.cn) err := c.closeConn(c.cn)
c.cn = nil c.cn = nil
return err return err
@ -116,7 +116,7 @@ func (c *PubSub) Close() error {
c.closed = true c.closed = true
if c.cn != nil { if c.cn != nil {
return c.releaseConn() return c.closeTheCn()
} }
return nil return nil
} }
@ -168,7 +168,7 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
} }
err = c._subscribe(cn, redisCmd, channels...) err = c._subscribe(cn, redisCmd, channels...)
c.putConn(cn, err) c._releaseConn(cn, err)
return err return err
} }
@ -186,7 +186,7 @@ func (c *PubSub) Ping(payload ...string) error {
cn.SetWriteTimeout(c.opt.WriteTimeout) cn.SetWriteTimeout(c.opt.WriteTimeout)
err = writeCmd(cn, cmd) err = writeCmd(cn, cmd)
c.putConn(cn, err) c.releaseConn(cn, err)
return err return err
} }
@ -279,7 +279,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
cn.SetReadTimeout(timeout) cn.SetReadTimeout(timeout)
err = c.cmd.readReply(cn) err = c.cmd.readReply(cn)
c.putConn(cn, err) c.releaseConn(cn, err)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -294,6 +294,22 @@ var _ = Describe("PubSub", func() {
Expect(stats.Hits).To(Equal(uint32(1))) Expect(stats.Hits).To(Equal(uint32(1)))
}) })
It("returns an error when subscribe fails", func() {
pubsub := client.Subscribe()
defer pubsub.Close()
pubsub.SetNetConn(&badConn{
readErr: io.EOF,
writeErr: io.EOF,
})
err := pubsub.Subscribe("mychannel")
Expect(err).To(MatchError("EOF"))
err = pubsub.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred())
})
expectReceiveMessageOnError := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
pubsub.SetNetConn(&badConn{ pubsub.SetNetConn(&badConn{
readErr: io.EOF, readErr: io.EOF,