diff --git a/command.go b/command.go index 977a313..fcea708 100644 --- a/command.go +++ b/command.go @@ -47,6 +47,12 @@ func setCmdsErr(cmds []Cmder, e error) { } } +func resetCmds(cmds []Cmder) { + for _, cmd := range cmds { + cmd.reset() + } +} + func cmdString(cmd Cmder, val interface{}) string { s := strings.Join(cmd.args(), " ") if err := cmd.Err(); err != nil { diff --git a/conn.go b/conn.go index 6ce5580..751bb54 100644 --- a/conn.go +++ b/conn.go @@ -7,14 +7,18 @@ import ( "gopkg.in/bufio.v1" ) +var ( + zeroTime = time.Time{} +) + type conn struct { netcn net.Conn rd *bufio.Reader buf []byte usedAt time.Time - readTimeout time.Duration - writeTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration } func newConnDialer(opt *options) func() (*conn, error) { @@ -70,8 +74,8 @@ func (cn *conn) writeCmds(cmds ...Cmder) error { } func (cn *conn) Read(b []byte) (int, error) { - if cn.readTimeout != 0 { - cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout)) + if cn.ReadTimeout != 0 { + cn.netcn.SetReadDeadline(time.Now().Add(cn.ReadTimeout)) } else { cn.netcn.SetReadDeadline(zeroTime) } @@ -79,8 +83,8 @@ func (cn *conn) Read(b []byte) (int, error) { } func (cn *conn) Write(b []byte) (int, error) { - if cn.writeTimeout != 0 { - cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout)) + if cn.WriteTimeout != 0 { + cn.netcn.SetWriteDeadline(time.Now().Add(cn.WriteTimeout)) } else { cn.netcn.SetWriteDeadline(zeroTime) } diff --git a/error.go b/error.go index 0e031f3..7709c51 100644 --- a/error.go +++ b/error.go @@ -26,7 +26,7 @@ func (err redisError) Error() string { } func isNetworkError(err error) bool { - if _, ok := err.(*net.OpError); ok || err == io.EOF { + if _, ok := err.(net.Error); ok || err == io.EOF { return true } return false @@ -53,3 +53,11 @@ func isMovedError(err error) (moved bool, ask bool, addr string) { return } + +// shouldRetry reports whether failed command should be retried. +func shouldRetry(err error) bool { + if err == nil { + return false + } + return isNetworkError(err) +} diff --git a/export_test.go b/export_test.go index 53519d5..f468729 100644 --- a/export_test.go +++ b/export_test.go @@ -1,9 +1,15 @@ package redis +import "net" + func (c *baseClient) Pool() pool { return c.connPool } +func (cn *conn) SetNetConn(netcn net.Conn) { + cn.netcn = netcn +} + func HashSlot(key string) int { return hashSlot(key) } diff --git a/pipeline.go b/pipeline.go index 8cfd5a1..62cf7fd 100644 --- a/pipeline.go +++ b/pipeline.go @@ -50,26 +50,38 @@ func (c *Pipeline) Discard() error { // Exec always returns list of commands and error of the first failed // command if any. -func (c *Pipeline) Exec() ([]Cmder, error) { +func (c *Pipeline) Exec() (cmds []Cmder, retErr error) { if c.closed { return nil, errClosed } if len(c.cmds) == 0 { - return []Cmder{}, nil + return c.cmds, nil } - cmds := c.cmds + cmds = c.cmds c.cmds = make([]Cmder, 0, 0) - cn, err := c.client.conn() - if err != nil { - setCmdsErr(cmds, err) - return cmds, err + for i := 0; i <= c.client.opt.MaxRetries; i++ { + if i > 0 { + resetCmds(cmds) + } + + cn, err := c.client.conn() + if err != nil { + setCmdsErr(cmds, err) + return cmds, err + } + + retErr = c.execCmds(cn, cmds) + c.client.putConn(cn, err) + if shouldRetry(err) { + continue + } + + break } - err = c.execCmds(cn, cmds) - c.client.putConn(cn, err) - return cmds, err + return cmds, retErr } func (c *Pipeline) execCmds(cn *conn, cmds []Cmder) error { @@ -79,17 +91,11 @@ func (c *Pipeline) execCmds(cn *conn, cmds []Cmder) error { } var firstCmdErr error - for i, cmd := range cmds { + for _, cmd := range cmds { err := cmd.parseReply(cn.rd) - if err == nil { - continue - } - if firstCmdErr == nil { + if err != nil && firstCmdErr == nil { firstCmdErr = err } - if isNetworkError(err) { - setCmdsErr(cmds[i:], err) - } } return firstCmdErr diff --git a/pool.go b/pool.go index 5194bc8..8af50e5 100644 --- a/pool.go +++ b/pool.go @@ -16,10 +16,6 @@ var ( errPoolTimeout = errors.New("redis: connection pool timeout") ) -var ( - zeroTime = time.Time{} -) - type pool interface { First() *conn Get() (*conn, error) diff --git a/pubsub.go b/pubsub.go index 1f4e911..5a32220 100644 --- a/pubsub.go +++ b/pubsub.go @@ -63,7 +63,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { if err != nil { return nil, err } - cn.readTimeout = timeout + cn.ReadTimeout = timeout cmd := NewSliceCmd() if err := cmd.parseReply(cn.rd); err != nil { @@ -92,6 +92,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { Payload: reply[3].(string), }, nil } + return nil, fmt.Errorf("redis: unsupported message name: %q", msgName) } diff --git a/redis.go b/redis.go index a2fe519..6d774e4 100644 --- a/redis.go +++ b/redis.go @@ -32,32 +32,46 @@ func (c *baseClient) putConn(cn *conn, ei error) { } func (c *baseClient) process(cmd Cmder) { - cn, err := c.conn() - if err != nil { - cmd.setErr(err) - return - } + for i := 0; i <= c.opt.MaxRetries; i++ { + if i > 0 { + cmd.reset() + } - if timeout := cmd.writeTimeout(); timeout != nil { - cn.writeTimeout = *timeout - } else { - cn.writeTimeout = c.opt.WriteTimeout - } + cn, err := c.conn() + if err != nil { + cmd.setErr(err) + return + } - if timeout := cmd.readTimeout(); timeout != nil { - cn.readTimeout = *timeout - } else { - cn.readTimeout = c.opt.ReadTimeout - } + if timeout := cmd.writeTimeout(); timeout != nil { + cn.WriteTimeout = *timeout + } else { + cn.WriteTimeout = c.opt.WriteTimeout + } - if err := cn.writeCmds(cmd); err != nil { + if timeout := cmd.readTimeout(); timeout != nil { + cn.ReadTimeout = *timeout + } else { + cn.ReadTimeout = c.opt.ReadTimeout + } + + if err := cn.writeCmds(cmd); err != nil { + c.putConn(cn, err) + cmd.setErr(err) + if shouldRetry(err) { + continue + } + return + } + + err = cmd.parseReply(cn.rd) c.putConn(cn, err) - cmd.setErr(err) + if shouldRetry(err) { + continue + } + return } - - err = cmd.parseReply(cn.rd) - c.putConn(cn, err) } // Close closes the client, releasing any open resources. @@ -105,6 +119,10 @@ type Options struct { // than specified in this option. // Default: 0 = no eviction IdleTimeout time.Duration + + // MaxRetries specifies maximum number of times client will retry + // failed command. Default is to not retry failed command. + MaxRetries int } func (opt *Options) getDialer() func() (net.Conn, error) { @@ -157,6 +175,8 @@ func (opt *Options) options() *options { DialTimeout: opt.getDialTimeout(), ReadTimeout: opt.ReadTimeout, WriteTimeout: opt.WriteTimeout, + + MaxRetries: opt.MaxRetries, } } @@ -172,6 +192,8 @@ type options struct { DialTimeout time.Duration ReadTimeout time.Duration WriteTimeout time.Duration + + MaxRetries int } func (opt *options) connPoolOptions() *connPoolOptions { diff --git a/redis_test.go b/redis_test.go index 58676bc..0198aba 100644 --- a/redis_test.go +++ b/redis_test.go @@ -124,6 +124,23 @@ var _ = Describe("Client", func() { Expect(db1.FlushDb().Err()).NotTo(HaveOccurred()) }) + It("should retry command on network error", func() { + Expect(client.Close()).NotTo(HaveOccurred()) + + client = redis.NewClient(&redis.Options{ + Addr: redisAddr, + MaxRetries: 1, + }) + + // Put bad connection in the pool. + cn, err := client.Pool().Get() + Expect(err).NotTo(HaveOccurred()) + cn.SetNetConn(newBadNetConn()) + Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) + + err = client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + }) }) //------------------------------------------------------------------------------ @@ -266,6 +283,24 @@ func BenchmarkPipeline(b *testing.B) { //------------------------------------------------------------------------------ +type badNetConn struct { + net.TCPConn +} + +var _ net.Conn = &badNetConn{} + +func newBadNetConn() net.Conn { + return &badNetConn{} +} + +func (badNetConn) Read([]byte) (int, error) { + return 0, net.UnknownNetworkError("badNetConn") +} + +func (badNetConn) Write([]byte) (int, error) { + return 0, net.UnknownNetworkError("badNetConn") +} + // Replaces ginkgo's Eventually. func waitForSubstring(fn func() string, substr string, timeout time.Duration) error { var s string