diff --git a/cluster_pipeline.go b/cluster_pipeline.go index 73d272b..5299b5f 100644 --- a/cluster_pipeline.go +++ b/cluster_pipeline.go @@ -79,7 +79,7 @@ func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { if err != nil { retErr = err } - client.putConn(cn, err) + client.putConn(cn, err, false) } cmdsMap = failedCmds diff --git a/command.go b/command.go index 1986c66..31516f6 100644 --- a/command.go +++ b/command.go @@ -32,7 +32,6 @@ type Cmder interface { setErr(error) reset() - writeTimeout() *time.Duration readTimeout() *time.Duration clusterKey() string @@ -82,7 +81,7 @@ type baseCmd struct { _clusterKeyPos int - _writeTimeout, _readTimeout *time.Duration + _readTimeout *time.Duration } func (cmd *baseCmd) Err() error { @@ -104,10 +103,6 @@ func (cmd *baseCmd) setReadTimeout(d time.Duration) { cmd._readTimeout = &d } -func (cmd *baseCmd) writeTimeout() *time.Duration { - return cmd._writeTimeout -} - func (cmd *baseCmd) clusterKey() string { if cmd._clusterKeyPos > 0 && cmd._clusterKeyPos < len(cmd._args) { return fmt.Sprint(cmd._args[cmd._clusterKeyPos]) @@ -115,10 +110,6 @@ func (cmd *baseCmd) clusterKey() string { return "" } -func (cmd *baseCmd) setWriteTimeout(d time.Duration) { - cmd._writeTimeout = &d -} - func (cmd *baseCmd) setErr(e error) { cmd.err = e } diff --git a/error.go b/error.go index dce10a3..3f2a560 100644 --- a/error.go +++ b/error.go @@ -33,15 +33,17 @@ func isNetworkError(err error) bool { return ok } -func isBadConn(err error) bool { +func isBadConn(err error, allowTimeout bool) bool { if err == nil { return false } if _, ok := err.(redisError); ok { return false } - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return false + if allowTimeout { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return false + } } return true } diff --git a/multi.go b/multi.go index 7ffc7e0..1a13d04 100644 --- a/multi.go +++ b/multi.go @@ -133,7 +133,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) { } err = c.execCmds(cn, cmds) - c.base.putConn(cn, err) + c.base.putConn(cn, err, false) return retCmds, err } diff --git a/pipeline.go b/pipeline.go index 8c800be..8caae6b 100644 --- a/pipeline.go +++ b/pipeline.go @@ -98,7 +98,7 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { resetCmds(failedCmds) } failedCmds, err = execCmds(cn, failedCmds) - pipe.client.putConn(cn, err) + pipe.client.putConn(cn, err, false) if err != nil && retErr == nil { retErr = err } diff --git a/pubsub.go b/pubsub.go index bde81b5..1b422ec 100644 --- a/pubsub.go +++ b/pubsub.go @@ -297,7 +297,7 @@ func (c *PubSub) ReceiveMessage() (*Message, error) { } func (c *PubSub) putConn(cn *conn, err error) { - if !c.base.putConn(cn, err) { + if !c.base.putConn(cn, err, true) { c.nsub = 0 } } diff --git a/redis.go b/redis.go index 488bfb9..5af7d68 100644 --- a/redis.go +++ b/redis.go @@ -23,8 +23,8 @@ func (c *baseClient) conn() (*conn, bool, error) { return c.connPool.Get() } -func (c *baseClient) putConn(cn *conn, err error) bool { - if isBadConn(err) { +func (c *baseClient) putConn(cn *conn, err error, allowTimeout bool) bool { + if isBadConn(err, allowTimeout) { err = c.connPool.Remove(cn, err) if err != nil { Logger.Printf("pool.Remove failed: %s", err) @@ -51,20 +51,16 @@ func (c *baseClient) process(cmd Cmder) { return } - if timeout := cmd.writeTimeout(); timeout != nil { - cn.WriteTimeout = *timeout - } else { - cn.WriteTimeout = c.opt.WriteTimeout - } - - if timeout := cmd.readTimeout(); timeout != nil { - cn.ReadTimeout = *timeout + readTimeout := cmd.readTimeout() + if readTimeout != nil { + cn.ReadTimeout = *readTimeout } else { cn.ReadTimeout = c.opt.ReadTimeout } + cn.WriteTimeout = c.opt.WriteTimeout if err := cn.writeCmds(cmd); err != nil { - c.putConn(cn, err) + c.putConn(cn, err, false) cmd.setErr(err) if shouldRetry(err) { continue @@ -73,7 +69,7 @@ func (c *baseClient) process(cmd Cmder) { } err = cmd.readReply(cn) - c.putConn(cn, err) + c.putConn(cn, err, readTimeout != nil) if shouldRetry(err) { continue } diff --git a/ring.go b/ring.go index 1d9b902..f1ae8ad 100644 --- a/ring.go +++ b/ring.go @@ -326,7 +326,7 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { resetCmds(cmds) } failedCmds, err := execCmds(cn, cmds) - client.putConn(cn, err) + client.putConn(cn, err, false) if err != nil && retErr == nil { retErr = err }