diff --git a/cluster.go b/cluster.go index a54f2f37..7347f46b 100644 --- a/cluster.go +++ b/cluster.go @@ -63,6 +63,7 @@ type ClusterOptions struct { MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration + ShouldRetry ShouldRetryFunc DialTimeout time.Duration ReadTimeout time.Duration @@ -129,6 +130,10 @@ func (opt *ClusterOptions) init() { if opt.NewClient == nil { opt.NewClient = NewClient } + + if opt.ShouldRetry == nil { + opt.ShouldRetry = DefaultShouldRetry + } } func (opt *ClusterOptions) clientOptions() *Options { @@ -834,7 +839,7 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { continue } - if shouldRetry(lastErr, cmd.readTimeout() == nil) { + if c.opt.ShouldRetry(lastErr, cmd.readTimeout() == nil) { // First retry the same node. if attempt == 0 { continue @@ -1497,7 +1502,7 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s continue } - if shouldRetry(err, true) { + if c.opt.ShouldRetry(err, true) { continue } diff --git a/cluster_test.go b/cluster_test.go index 6ee7364e..d4a05bbb 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -922,6 +922,29 @@ var _ = Describe("ClusterClient", func() { Expect(info.Val()).Should(ContainSubstring("tcp_port:8221")) }) + It("should support custom ShouldRetry", func() { + script := redis.NewScript(` + local k = KEYS[1] + local n = redis.call("incr", k) + if n == 1 then + return redis.error_reply("attempt 1 should fail") + end + redis.call("del", k) + return true + `) + script.Load(ctx, client) + opt := redisClusterOptions() + opt.ShouldRetry = func(err error, retryTimeout bool) bool { + if err.Error() == "attempt 1 should fail" { + return true + } + return redis.DefaultShouldRetry(err, retryTimeout) + } + client := cluster.newClusterClient(ctx, opt) + val, _ := script.Run(ctx, client, []string{"random_key"}).Result() + Expect(val).To(Equal(int64(1))) + }) + assertClusterClient() }) diff --git a/error.go b/error.go index 5025215b..611aea12 100644 --- a/error.go +++ b/error.go @@ -12,6 +12,7 @@ import ( // ErrClosed performs any operation on the closed client will return this error. var ErrClosed = pool.ErrClosed +var ErrPoolTimeout = pool.ErrPoolTimeout type Error interface { error @@ -25,7 +26,7 @@ type Error interface { var _ Error = proto.RedisError("") -func shouldRetry(err error, retryTimeout bool) bool { +func DefaultShouldRetry(err error, retryTimeout bool) bool { switch err { case io.EOF, io.ErrUnexpectedEOF: return true diff --git a/options.go b/options.go index a4abe32c..ac33cc6d 100644 --- a/options.go +++ b/options.go @@ -16,6 +16,8 @@ import ( "github.com/go-redis/redis/v8/internal/pool" ) +type ShouldRetryFunc = func(error, bool) bool + // Limiter is the interface of a rate limiter or a circuit breaker. type Limiter interface { // Allow returns nil if operation is allowed or an error otherwise. @@ -64,6 +66,7 @@ type Options struct { // Maximum backoff between each retry. // Default is 512 milliseconds; -1 disables backoff. MaxRetryBackoff time.Duration + ShouldRetry ShouldRetryFunc // Dial timeout for establishing new connections. // Default is 5 seconds. @@ -182,6 +185,9 @@ func (opt *Options) init() { case 0: opt.MaxRetryBackoff = 512 * time.Millisecond } + if opt.ShouldRetry == nil { + opt.ShouldRetry = DefaultShouldRetry + } } func (opt *Options) clone() *Options { diff --git a/redis.go b/redis.go index bcf8a2a9..c64c9156 100644 --- a/redis.go +++ b/redis.go @@ -348,7 +348,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool return false, nil } - retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) + retry := c.opt.ShouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) return retry, err } @@ -426,7 +426,7 @@ func (c *baseClient) _generalProcessPipeline( canRetry, err = p(ctx, cn, cmds) return err }) - if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { + if lastErr == nil || !canRetry || !c.opt.ShouldRetry(lastErr, true) { return lastErr } } diff --git a/redis_test.go b/redis_test.go index 05c6d791..0d77beea 100644 --- a/redis_test.go +++ b/redis_test.go @@ -304,6 +304,30 @@ var _ = Describe("Client", func() { err := client.Conn(ctx).Get(ctx, "this-key-does-not-exist").Err() Expect(err).To(Equal(redis.Nil)) }) + + It("should support custom ShouldRetry", func() { + opt := redisOptions() + opt.ShouldRetry = func(err error, retryTimeout bool) bool { + if err.Error() == "attempt 1 should fail" { + return true + } + return redis.DefaultShouldRetry(err, retryTimeout) + } + opt.MaxRetries = 1 + client := redis.NewClient(opt) + script := redis.NewScript(` + local k = KEYS[1] + local n = redis.call("incr", k) + if n == 1 then + return redis.error_reply("attempt 1 should fail") + end + redis.call("del", k) + return true + `) + script.Load(ctx, client) + val, _ := script.Run(ctx, client, []string{"random_key"}).Result() + Expect(val).To(Equal(int64(1))) + }) }) var _ = Describe("Client timeout", func() { diff --git a/ring.go b/ring.go index 11a20370..2cf2781d 100644 --- a/ring.go +++ b/ring.go @@ -74,6 +74,7 @@ type RingOptions struct { MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration + ShouldRetry ShouldRetryFunc DialTimeout time.Duration ReadTimeout time.Duration @@ -125,6 +126,9 @@ func (opt *RingOptions) init() { case 0: opt.MaxRetryBackoff = 512 * time.Millisecond } + if opt.ShouldRetry == nil { + opt.ShouldRetry = DefaultShouldRetry + } } func (opt *RingOptions) clientOptions() *Options { @@ -606,7 +610,7 @@ func (c *Ring) process(ctx context.Context, cmd Cmder) error { } lastErr = shard.Client.Process(ctx, cmd) - if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) { + if lastErr == nil || !c.opt.ShouldRetry(lastErr, cmd.readTimeout() == nil) { return lastErr } } diff --git a/ring_test.go b/ring_test.go index 03a49fd7..1135acf0 100644 --- a/ring_test.go +++ b/ring_test.go @@ -171,6 +171,30 @@ var _ = Describe("Redis Ring", func() { Expect(ringShard1.Info(ctx).Val()).ToNot(ContainSubstring("keys=")) Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=100")) }) + + It("should support custom ShouldRetry", func() { + opt := redisRingOptions() + opt.ShouldRetry = func(err error, retryTimeout bool) bool { + if err.Error() == "attempt 1 should fail" { + return true + } + return redis.DefaultShouldRetry(err, retryTimeout) + } + opt.MaxRetries = 1 + ring := redis.NewRing(opt) + script := redis.NewScript(` + local k = KEYS[1] + local n = redis.call("incr", k) + if n == 1 then + return redis.error_reply("attempt 1 should fail") + end + redis.call("del", k) + return true + `) + script.Load(ctx, ring) + val, _ := script.Run(ctx, ring, []string{"random_key"}).Result() + Expect(val).To(Equal(int64(1))) + }) }) Describe("new client callback", func() { diff --git a/sentinel.go b/sentinel.go index ec6221dc..7f9a1282 100644 --- a/sentinel.go +++ b/sentinel.go @@ -58,6 +58,7 @@ type FailoverOptions struct { MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration + ShouldRetry ShouldRetryFunc DialTimeout time.Duration ReadTimeout time.Duration @@ -90,6 +91,7 @@ func (opt *FailoverOptions) clientOptions() *Options { MaxRetries: opt.MaxRetries, MinRetryBackoff: opt.MinRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff, + ShouldRetry: opt.ShouldRetry, DialTimeout: opt.DialTimeout, ReadTimeout: opt.ReadTimeout, @@ -121,6 +123,7 @@ func (opt *FailoverOptions) sentinelOptions(addr string) *Options { MaxRetries: opt.MaxRetries, MinRetryBackoff: opt.MinRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff, + ShouldRetry: opt.ShouldRetry, DialTimeout: opt.DialTimeout, ReadTimeout: opt.ReadTimeout, @@ -153,6 +156,7 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions { MinRetryBackoff: opt.MinRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff, + ShouldRetry: opt.ShouldRetry, DialTimeout: opt.DialTimeout, ReadTimeout: opt.ReadTimeout, diff --git a/sentinel_test.go b/sentinel_test.go index 753e0fc2..2ccf9581 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -211,6 +211,35 @@ var _ = Describe("NewFailoverClusterClient", func() { _, err = startRedis(masterPort) Expect(err).NotTo(HaveOccurred()) }) + + It("should support custom ShouldRetry", func() { + opt := &redis.FailoverOptions{ + MasterName: sentinelName, + SentinelAddrs: sentinelAddrs, + + RouteRandomly: true, + MaxRetries: 1, + } + opt.ShouldRetry = func(err error, retryTimeout bool) bool { + if err.Error() == "attempt 1 should fail" { + return true + } + return redis.DefaultShouldRetry(err, retryTimeout) + } + client := redis.NewFailoverClusterClient(opt) + script := redis.NewScript(` + local k = KEYS[1] + local n = redis.call("incr", k) + if n == 1 then + return redis.error_reply("attempt 1 should fail") + end + redis.call("del", k) + return true + `) + script.Load(ctx, client) + val, _ := script.Run(ctx, client, []string{"random_key"}).Result() + Expect(val).To(Equal(int64(1))) + }) }) var _ = Describe("SentinelAclAuth", func() { diff --git a/universal.go b/universal.go index 1e962ab3..d26c7378 100644 --- a/universal.go +++ b/universal.go @@ -30,6 +30,7 @@ type UniversalOptions struct { MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration + ShouldRetry ShouldRetryFunc DialTimeout time.Duration ReadTimeout time.Duration @@ -82,6 +83,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions { MaxRetries: o.MaxRetries, MinRetryBackoff: o.MinRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff, + ShouldRetry: o.ShouldRetry, DialTimeout: o.DialTimeout, ReadTimeout: o.ReadTimeout, @@ -119,6 +121,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions { MaxRetries: o.MaxRetries, MinRetryBackoff: o.MinRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff, + ShouldRetry: o.ShouldRetry, DialTimeout: o.DialTimeout, ReadTimeout: o.ReadTimeout, @@ -155,6 +158,7 @@ func (o *UniversalOptions) Simple() *Options { MaxRetries: o.MaxRetries, MinRetryBackoff: o.MinRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff, + ShouldRetry: o.ShouldRetry, DialTimeout: o.DialTimeout, ReadTimeout: o.ReadTimeout,