Merge pull request #268 from go-redis/fix/close-connection-on-network-timeout

Close connection on network timeout.
This commit is contained in:
Vladimir Mihailenco 2016-03-09 15:28:25 +02:00
commit eb78eedafe
14 changed files with 161 additions and 110 deletions

View File

@ -79,7 +79,7 @@ func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) {
if err != nil { if err != nil {
retErr = err retErr = err
} }
client.putConn(cn, err) client.putConn(cn, err, false)
} }
cmdsMap = failedCmds cmdsMap = failedCmds

View File

@ -32,7 +32,6 @@ type Cmder interface {
setErr(error) setErr(error)
reset() reset()
writeTimeout() *time.Duration
readTimeout() *time.Duration readTimeout() *time.Duration
clusterKey() string clusterKey() string
@ -82,7 +81,7 @@ type baseCmd struct {
_clusterKeyPos int _clusterKeyPos int
_writeTimeout, _readTimeout *time.Duration _readTimeout *time.Duration
} }
func (cmd *baseCmd) Err() error { func (cmd *baseCmd) Err() error {
@ -104,10 +103,6 @@ func (cmd *baseCmd) setReadTimeout(d time.Duration) {
cmd._readTimeout = &d cmd._readTimeout = &d
} }
func (cmd *baseCmd) writeTimeout() *time.Duration {
return cmd._writeTimeout
}
func (cmd *baseCmd) clusterKey() string { func (cmd *baseCmd) clusterKey() string {
if cmd._clusterKeyPos > 0 && cmd._clusterKeyPos < len(cmd._args) { if cmd._clusterKeyPos > 0 && cmd._clusterKeyPos < len(cmd._args) {
return fmt.Sprint(cmd._args[cmd._clusterKeyPos]) return fmt.Sprint(cmd._args[cmd._clusterKeyPos])
@ -115,10 +110,6 @@ func (cmd *baseCmd) clusterKey() string {
return "" return ""
} }
func (cmd *baseCmd) setWriteTimeout(d time.Duration) {
cmd._writeTimeout = &d
}
func (cmd *baseCmd) setErr(e error) { func (cmd *baseCmd) setErr(e error) {
cmd.err = e cmd.err = e
} }

View File

@ -1303,6 +1303,9 @@ var _ = Describe("Commands", func() {
bLPop := client.BLPop(time.Second, "list1") bLPop := client.BLPop(time.Second, "list1")
Expect(bLPop.Val()).To(BeNil()) Expect(bLPop.Val()).To(BeNil())
Expect(bLPop.Err()).To(Equal(redis.Nil)) Expect(bLPop.Err()).To(Equal(redis.Nil))
stats := client.Pool().Stats()
Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(1)))
}) })
It("should BRPop", func() { It("should BRPop", func() {

View File

@ -33,15 +33,17 @@ func isNetworkError(err error) bool {
return ok return ok
} }
func isBadConn(err error) bool { func isBadConn(err error, allowTimeout bool) bool {
if err == nil { if err == nil {
return false return false
} }
if _, ok := err.(redisError); ok { if _, ok := err.(redisError); ok {
return false return false
} }
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if allowTimeout {
return false if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return false
}
} }
return true return true
} }

View File

@ -98,9 +98,10 @@ func TestGinkgoSuite(t *testing.T) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func eventually(fn func() error, timeout time.Duration) (err error) { func eventually(fn func() error, timeout time.Duration) error {
done := make(chan struct{}) done := make(chan struct{})
var exit int32 var exit int32
var err error
go func() { go func() {
for atomic.LoadInt32(&exit) == 0 { for atomic.LoadInt32(&exit) == 0 {
err = fn() err = fn()

View File

@ -133,7 +133,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) {
} }
err = c.execCmds(cn, cmds) err = c.execCmds(cn, cmds)
c.base.putConn(cn, err) c.base.putConn(cn, err, false)
return retCmds, err return retCmds, err
} }

View File

@ -98,7 +98,7 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) {
resetCmds(failedCmds) resetCmds(failedCmds)
} }
failedCmds, err = execCmds(cn, failedCmds) failedCmds, err = execCmds(cn, failedCmds)
pipe.client.putConn(cn, err) pipe.client.putConn(cn, err, false)
if err != nil && retErr == nil { if err != nil && retErr == nil {
retErr = err retErr = err
} }

View File

@ -18,7 +18,8 @@ var (
// PoolStats contains pool state information and accumulated stats. // PoolStats contains pool state information and accumulated stats.
type PoolStats struct { type PoolStats struct {
Requests uint32 // number of times a connection was requested by the pool Requests uint32 // number of times a connection was requested by the pool
Waits uint32 // number of times our pool had to wait for a connection Hits uint32 // number of times a free connection was found in the pool
Waits uint32 // number of times the pool had to wait for a connection
Timeouts uint32 // number of times a wait timeout occurred Timeouts uint32 // number of times a wait timeout occurred
TotalConns uint32 // the number of total connections in the pool TotalConns uint32 // the number of total connections in the pool
@ -241,6 +242,7 @@ func (p *connPool) Get() (cn *conn, isNew bool, err error) {
// Fetch first non-idle connection, if available. // Fetch first non-idle connection, if available.
if cn = p.First(); cn != nil { if cn = p.First(); cn != nil {
atomic.AddUint32(&p.stats.Hits, 1)
return return
} }

View File

@ -123,6 +123,12 @@ var _ = Describe("pool", func() {
pool := client.Pool() pool := client.Pool()
Expect(pool.Len()).To(Equal(1)) Expect(pool.Len()).To(Equal(1))
Expect(pool.FreeLen()).To(Equal(1)) Expect(pool.FreeLen()).To(Equal(1))
stats := pool.Stats()
Expect(stats.Requests).To(Equal(uint32(3)))
Expect(stats.Hits).To(Equal(uint32(2)))
Expect(stats.Waits).To(Equal(uint32(0)))
Expect(stats.Timeouts).To(Equal(uint32(0)))
}) })
It("should reuse connections", func() { It("should reuse connections", func() {
@ -135,6 +141,12 @@ var _ = Describe("pool", func() {
pool := client.Pool() pool := client.Pool()
Expect(pool.Len()).To(Equal(1)) Expect(pool.Len()).To(Equal(1))
Expect(pool.FreeLen()).To(Equal(1)) Expect(pool.FreeLen()).To(Equal(1))
stats := pool.Stats()
Expect(stats.Requests).To(Equal(uint32(100)))
Expect(stats.Hits).To(Equal(uint32(99)))
Expect(stats.Waits).To(Equal(uint32(0)))
Expect(stats.Timeouts).To(Equal(uint32(0)))
}) })
It("should unblock client when connection is removed", func() { It("should unblock client when connection is removed", func() {

View File

@ -245,10 +245,11 @@ func (c *PubSub) Receive() (interface{}, error) {
return c.ReceiveTimeout(0) return c.ReceiveTimeout(0)
} }
// ReceiveMessage returns a message or error. It automatically // ReceiveMessage returns a Message or error ignoring Subscription or Pong
// reconnects to Redis in case of network errors. // messages. It automatically reconnects to Redis Server and resubscribes
// to channels in case of network errors.
func (c *PubSub) ReceiveMessage() (*Message, error) { func (c *PubSub) ReceiveMessage() (*Message, error) {
var errNum int var errNum uint
for { for {
msgi, err := c.ReceiveTimeout(5 * time.Second) msgi, err := c.ReceiveTimeout(5 * time.Second)
if err != nil { if err != nil {
@ -260,10 +261,9 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
if errNum < 3 { if errNum < 3 {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
err := c.Ping("") err := c.Ping("")
if err == nil { if err != nil {
continue Logger.Printf("PubSub.Ping failed: %s", err)
} }
Logger.Printf("PubSub.Ping failed: %s", err)
} }
} else { } else {
// 3 consecutive errors - connection is bad // 3 consecutive errors - connection is bad
@ -297,7 +297,7 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
} }
func (c *PubSub) putConn(cn *conn, err error) { func (c *PubSub) putConn(cn *conn, err error) {
if !c.base.putConn(cn, err) { if !c.base.putConn(cn, err, true) {
c.nsub = 0 c.nsub = 0
} }
} }

View File

@ -33,12 +33,6 @@ var _ = Describe("PubSub", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer pubsub.Close() defer pubsub.Close()
n, err := client.Publish("mychannel1", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
Expect(pubsub.PUnsubscribe("mychannel*")).NotTo(HaveOccurred())
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -48,6 +42,18 @@ var _ = Describe("PubSub", func() {
Expect(subscr.Count).To(Equal(1)) Expect(subscr.Count).To(Equal(1))
} }
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).To(BeNil())
}
n, err := client.Publish("mychannel1", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
Expect(pubsub.PUnsubscribe("mychannel*")).NotTo(HaveOccurred())
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -66,11 +72,8 @@ var _ = Describe("PubSub", func() {
Expect(subscr.Count).To(Equal(0)) Expect(subscr.Count).To(Equal(0))
} }
{ stats := client.Pool().Stats()
msgi, err := pubsub.ReceiveTimeout(time.Second) Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2)))
Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).NotTo(HaveOccurred())
}
}) })
It("should pub/sub channels", func() { It("should pub/sub channels", func() {
@ -128,16 +131,6 @@ var _ = Describe("PubSub", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer pubsub.Close() defer pubsub.Close()
n, err := client.Publish("mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
n, err = client.Publish("mychannel2", "hello2").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
Expect(pubsub.Unsubscribe("mychannel", "mychannel2")).NotTo(HaveOccurred())
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -156,6 +149,22 @@ var _ = Describe("PubSub", func() {
Expect(subscr.Count).To(Equal(2)) Expect(subscr.Count).To(Equal(2))
} }
{
msgi, err := pubsub.ReceiveTimeout(time.Second)
Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).NotTo(HaveOccurred())
}
n, err := client.Publish("mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
n, err = client.Publish("mychannel2", "hello2").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
Expect(pubsub.Unsubscribe("mychannel", "mychannel2")).NotTo(HaveOccurred())
{ {
msgi, err := pubsub.ReceiveTimeout(time.Second) msgi, err := pubsub.ReceiveTimeout(time.Second)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -190,11 +199,8 @@ var _ = Describe("PubSub", func() {
Expect(subscr.Count).To(Equal(0)) Expect(subscr.Count).To(Equal(0))
} }
{ stats := client.Pool().Stats()
msgi, err := pubsub.ReceiveTimeout(time.Second) Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2)))
Expect(err.(net.Error).Timeout()).To(Equal(true))
Expect(msgi).NotTo(HaveOccurred())
}
}) })
It("should ping/pong", func() { It("should ping/pong", func() {
@ -277,6 +283,9 @@ var _ = Describe("PubSub", func() {
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
Eventually(done).Should(Receive()) Eventually(done).Should(Receive())
stats := client.Pool().Stats()
Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2)))
}) })
expectReceiveMessageOnError := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
@ -305,6 +314,9 @@ var _ = Describe("PubSub", func() {
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
Eventually(done).Should(Receive()) Eventually(done).Should(Receive())
stats := client.Pool().Stats()
Expect(stats.Requests - stats.Hits - stats.Waits).To(Equal(uint32(2)))
} }
It("Subscribe should reconnect on ReceiveMessage error", func() { It("Subscribe should reconnect on ReceiveMessage error", func() {

View File

@ -13,6 +13,8 @@ var Logger = log.New(os.Stderr, "redis: ", log.LstdFlags)
type baseClient struct { type baseClient struct {
connPool pool connPool pool
opt *Options opt *Options
onClose func() error // hook called when client is closed
} }
func (c *baseClient) String() string { func (c *baseClient) String() string {
@ -23,8 +25,8 @@ func (c *baseClient) conn() (*conn, bool, error) {
return c.connPool.Get() return c.connPool.Get()
} }
func (c *baseClient) putConn(cn *conn, err error) bool { func (c *baseClient) putConn(cn *conn, err error, allowTimeout bool) bool {
if isBadConn(err) { if isBadConn(err, allowTimeout) {
err = c.connPool.Remove(cn, err) err = c.connPool.Remove(cn, err)
if err != nil { if err != nil {
Logger.Printf("pool.Remove failed: %s", err) Logger.Printf("pool.Remove failed: %s", err)
@ -51,20 +53,16 @@ func (c *baseClient) process(cmd Cmder) {
return return
} }
if timeout := cmd.writeTimeout(); timeout != nil { readTimeout := cmd.readTimeout()
cn.WriteTimeout = *timeout if readTimeout != nil {
} else { cn.ReadTimeout = *readTimeout
cn.WriteTimeout = c.opt.WriteTimeout
}
if timeout := cmd.readTimeout(); timeout != nil {
cn.ReadTimeout = *timeout
} else { } else {
cn.ReadTimeout = c.opt.ReadTimeout cn.ReadTimeout = c.opt.ReadTimeout
} }
cn.WriteTimeout = c.opt.WriteTimeout
if err := cn.writeCmds(cmd); err != nil { if err := cn.writeCmds(cmd); err != nil {
c.putConn(cn, err) c.putConn(cn, err, false)
cmd.setErr(err) cmd.setErr(err)
if shouldRetry(err) { if shouldRetry(err) {
continue continue
@ -73,7 +71,7 @@ func (c *baseClient) process(cmd Cmder) {
} }
err = cmd.readReply(cn) err = cmd.readReply(cn)
c.putConn(cn, err) c.putConn(cn, err, readTimeout != nil)
if shouldRetry(err) { if shouldRetry(err) {
continue continue
} }
@ -87,7 +85,16 @@ func (c *baseClient) process(cmd Cmder) {
// It is rare to Close a Client, as the Client is meant to be // It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines. // long-lived and shared between many goroutines.
func (c *baseClient) Close() error { func (c *baseClient) Close() error {
return c.connPool.Close() var retErr error
if c.onClose != nil {
if err := c.onClose(); err != nil && retErr == nil {
retErr = err
}
}
if err := c.connPool.Close(); err != nil && retErr == nil {
retErr = err
}
return retErr
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -190,8 +197,10 @@ type Client struct {
func newClient(opt *Options, pool pool) *Client { func newClient(opt *Options, pool pool) *Client {
base := baseClient{opt: opt, connPool: pool} base := baseClient{opt: opt, connPool: pool}
return &Client{ return &Client{
baseClient: base, baseClient: base,
commandable: commandable{process: base.process}, commandable: commandable{
process: base.process,
},
} }
} }

View File

@ -326,7 +326,7 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) {
resetCmds(cmds) resetCmds(cmds)
} }
failedCmds, err := execCmds(cn, cmds) failedCmds, err := execCmds(cn, cmds)
client.putConn(cn, err) client.putConn(cn, err, false)
if err != nil && retErr == nil { if err != nil && retErr == nil {
retErr = err retErr = err
} }

View File

@ -65,18 +65,31 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
opt: opt, opt: opt,
} }
return newClient(opt, failover.Pool()) base := baseClient{
opt: opt,
connPool: failover.Pool(),
onClose: func() error {
return failover.Close()
},
}
return &Client{
baseClient: base,
commandable: commandable{
process: base.process,
},
}
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type sentinelClient struct { type sentinelClient struct {
baseClient
commandable commandable
*baseClient
} }
func newSentinel(opt *Options) *sentinelClient { func newSentinel(opt *Options) *sentinelClient {
base := &baseClient{ base := baseClient{
opt: opt, opt: opt,
connPool: newConnPool(opt), connPool: newConnPool(opt),
} }
@ -116,8 +129,12 @@ type sentinelFailover struct {
pool pool pool pool
poolOnce sync.Once poolOnce sync.Once
lock sync.RWMutex mu sync.RWMutex
_sentinel *sentinelClient sentinel *sentinelClient
}
func (d *sentinelFailover) Close() error {
return d.resetSentinel()
} }
func (d *sentinelFailover) dial() (net.Conn, error) { func (d *sentinelFailover) dial() (net.Conn, error) {
@ -137,15 +154,15 @@ func (d *sentinelFailover) Pool() pool {
} }
func (d *sentinelFailover) MasterAddr() (string, error) { func (d *sentinelFailover) MasterAddr() (string, error) {
defer d.lock.Unlock() defer d.mu.Unlock()
d.lock.Lock() d.mu.Lock()
// Try last working sentinel. // Try last working sentinel.
if d._sentinel != nil { if d.sentinel != nil {
addr, err := d._sentinel.GetMasterAddrByName(d.masterName).Result() addr, err := d.sentinel.GetMasterAddrByName(d.masterName).Result()
if err != nil { if err != nil {
Logger.Printf("sentinel: GetMasterAddrByName %q failed: %s", d.masterName, err) Logger.Printf("sentinel: GetMasterAddrByName %q failed: %s", d.masterName, err)
d.resetSentinel() d._resetSentinel()
} else { } else {
addr := net.JoinHostPort(addr[0], addr[1]) addr := net.JoinHostPort(addr[0], addr[1])
Logger.Printf("sentinel: %q addr is %s", d.masterName, addr) Logger.Printf("sentinel: %q addr is %s", d.masterName, addr)
@ -186,10 +203,26 @@ func (d *sentinelFailover) MasterAddr() (string, error) {
func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) { func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) {
d.discoverSentinels(sentinel) d.discoverSentinels(sentinel)
d._sentinel = sentinel d.sentinel = sentinel
go d.listen() go d.listen()
} }
func (d *sentinelFailover) resetSentinel() error {
d.mu.Lock()
err := d._resetSentinel()
d.mu.Unlock()
return err
}
func (d *sentinelFailover) _resetSentinel() error {
var err error
if d.sentinel != nil {
err = d.sentinel.Close()
d.sentinel = nil
}
return err
}
func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) {
sentinels, err := sentinel.Sentinels(d.masterName).Result() sentinels, err := sentinel.Sentinels(d.masterName).Result()
if err != nil { if err != nil {
@ -247,55 +280,41 @@ func (d *sentinelFailover) listen() {
var pubsub *PubSub var pubsub *PubSub
for { for {
if pubsub == nil { if pubsub == nil {
pubsub = d._sentinel.PubSub() pubsub = d.sentinel.PubSub()
if err := pubsub.Subscribe("+switch-master"); err != nil { if err := pubsub.Subscribe("+switch-master"); err != nil {
Logger.Printf("sentinel: Subscribe failed: %s", err) Logger.Printf("sentinel: Subscribe failed: %s", err)
d.lock.Lock()
d.resetSentinel() d.resetSentinel()
d.lock.Unlock()
return return
} }
} }
msg, err := pubsub.Receive() msg, err := pubsub.ReceiveMessage()
if err != nil { if err != nil {
Logger.Printf("sentinel: Receive failed: %s", err) Logger.Printf("sentinel: ReceiveMessage failed: %s", err)
pubsub.Close() pubsub.Close()
d.resetSentinel()
return return
} }
switch msg := msg.(type) { switch msg.Channel {
case *Message: case "+switch-master":
switch msg.Channel { parts := strings.Split(msg.Payload, " ")
case "+switch-master": if parts[0] != d.masterName {
parts := strings.Split(msg.Payload, " ") Logger.Printf("sentinel: ignore new %s addr", parts[0])
if parts[0] != d.masterName { continue
Logger.Printf("sentinel: ignore new %s addr", parts[0])
continue
}
addr := net.JoinHostPort(parts[3], parts[4])
Logger.Printf(
"sentinel: new %q addr is %s",
d.masterName, addr,
)
d.closeOldConns(addr)
default:
Logger.Printf("sentinel: unsupported message: %s", msg)
} }
case *Subscription:
// Ignore. addr := net.JoinHostPort(parts[3], parts[4])
default: Logger.Printf(
Logger.Printf("sentinel: unsupported message: %s", msg) "sentinel: new %q addr is %s",
d.masterName, addr,
)
d.closeOldConns(addr)
} }
} }
} }
func (d *sentinelFailover) resetSentinel() {
d._sentinel.Close()
d._sentinel = nil
}
func contains(slice []string, str string) bool { func contains(slice []string, str string) bool {
for _, s := range slice { for _, s := range slice {
if s == str { if s == str {