feat: enable custom retry behavior

Added a field ShouldRetry to all client option structs, to allow
custom behavior on retries. Existing shouldRetry function is renamed
to DefaultShouldRetry, and is the default as the name suggests.
This commit is contained in:
Conall OCallaghan 2022-02-25 17:33:18 -05:00 committed by conall o'callaghan
parent fa515240d4
commit 2a3e951cd4
11 changed files with 130 additions and 6 deletions

View File

@ -63,6 +63,7 @@ type ClusterOptions struct {
MaxRetries int MaxRetries int
MinRetryBackoff time.Duration MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration MaxRetryBackoff time.Duration
ShouldRetry ShouldRetryFunc
DialTimeout time.Duration DialTimeout time.Duration
ReadTimeout time.Duration ReadTimeout time.Duration
@ -129,6 +130,10 @@ func (opt *ClusterOptions) init() {
if opt.NewClient == nil { if opt.NewClient == nil {
opt.NewClient = NewClient opt.NewClient = NewClient
} }
if opt.ShouldRetry == nil {
opt.ShouldRetry = DefaultShouldRetry
}
} }
func (opt *ClusterOptions) clientOptions() *Options { func (opt *ClusterOptions) clientOptions() *Options {
@ -834,7 +839,7 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
continue continue
} }
if shouldRetry(lastErr, cmd.readTimeout() == nil) { if c.opt.ShouldRetry(lastErr, cmd.readTimeout() == nil) {
// First retry the same node. // First retry the same node.
if attempt == 0 { if attempt == 0 {
continue continue
@ -1497,7 +1502,7 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
continue continue
} }
if shouldRetry(err, true) { if c.opt.ShouldRetry(err, true) {
continue continue
} }

View File

@ -922,6 +922,29 @@ var _ = Describe("ClusterClient", func() {
Expect(info.Val()).Should(ContainSubstring("tcp_port:8221")) Expect(info.Val()).Should(ContainSubstring("tcp_port:8221"))
}) })
It("should support custom ShouldRetry", func() {
script := redis.NewScript(`
local k = KEYS[1]
local n = redis.call("incr", k)
if n == 1 then
return redis.error_reply("attempt 1 should fail")
end
redis.call("del", k)
return true
`)
script.Load(ctx, client)
opt := redisClusterOptions()
opt.ShouldRetry = func(err error, retryTimeout bool) bool {
if err.Error() == "attempt 1 should fail" {
return true
}
return redis.DefaultShouldRetry(err, retryTimeout)
}
client := cluster.newClusterClient(ctx, opt)
val, _ := script.Run(ctx, client, []string{"random_key"}).Result()
Expect(val).To(Equal(int64(1)))
})
assertClusterClient() assertClusterClient()
}) })

View File

@ -12,6 +12,7 @@ import (
// ErrClosed performs any operation on the closed client will return this error. // ErrClosed performs any operation on the closed client will return this error.
var ErrClosed = pool.ErrClosed var ErrClosed = pool.ErrClosed
var ErrPoolTimeout = pool.ErrPoolTimeout
type Error interface { type Error interface {
error error
@ -25,7 +26,7 @@ type Error interface {
var _ Error = proto.RedisError("") var _ Error = proto.RedisError("")
func shouldRetry(err error, retryTimeout bool) bool { func DefaultShouldRetry(err error, retryTimeout bool) bool {
switch err { switch err {
case io.EOF, io.ErrUnexpectedEOF: case io.EOF, io.ErrUnexpectedEOF:
return true return true

View File

@ -16,6 +16,8 @@ import (
"github.com/go-redis/redis/v8/internal/pool" "github.com/go-redis/redis/v8/internal/pool"
) )
type ShouldRetryFunc = func(error, bool) bool
// Limiter is the interface of a rate limiter or a circuit breaker. // Limiter is the interface of a rate limiter or a circuit breaker.
type Limiter interface { type Limiter interface {
// Allow returns nil if operation is allowed or an error otherwise. // Allow returns nil if operation is allowed or an error otherwise.
@ -64,6 +66,7 @@ type Options struct {
// Maximum backoff between each retry. // Maximum backoff between each retry.
// Default is 512 milliseconds; -1 disables backoff. // Default is 512 milliseconds; -1 disables backoff.
MaxRetryBackoff time.Duration MaxRetryBackoff time.Duration
ShouldRetry ShouldRetryFunc
// Dial timeout for establishing new connections. // Dial timeout for establishing new connections.
// Default is 5 seconds. // Default is 5 seconds.
@ -182,6 +185,9 @@ func (opt *Options) init() {
case 0: case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond opt.MaxRetryBackoff = 512 * time.Millisecond
} }
if opt.ShouldRetry == nil {
opt.ShouldRetry = DefaultShouldRetry
}
} }
func (opt *Options) clone() *Options { func (opt *Options) clone() *Options {

View File

@ -348,7 +348,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
return false, nil return false, nil
} }
retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) retry := c.opt.ShouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
return retry, err return retry, err
} }
@ -426,7 +426,7 @@ func (c *baseClient) _generalProcessPipeline(
canRetry, err = p(ctx, cn, cmds) canRetry, err = p(ctx, cn, cmds)
return err return err
}) })
if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { if lastErr == nil || !canRetry || !c.opt.ShouldRetry(lastErr, true) {
return lastErr return lastErr
} }
} }

View File

@ -304,6 +304,30 @@ var _ = Describe("Client", func() {
err := client.Conn(ctx).Get(ctx, "this-key-does-not-exist").Err() err := client.Conn(ctx).Get(ctx, "this-key-does-not-exist").Err()
Expect(err).To(Equal(redis.Nil)) Expect(err).To(Equal(redis.Nil))
}) })
It("should support custom ShouldRetry", func() {
opt := redisOptions()
opt.ShouldRetry = func(err error, retryTimeout bool) bool {
if err.Error() == "attempt 1 should fail" {
return true
}
return redis.DefaultShouldRetry(err, retryTimeout)
}
opt.MaxRetries = 1
client := redis.NewClient(opt)
script := redis.NewScript(`
local k = KEYS[1]
local n = redis.call("incr", k)
if n == 1 then
return redis.error_reply("attempt 1 should fail")
end
redis.call("del", k)
return true
`)
script.Load(ctx, client)
val, _ := script.Run(ctx, client, []string{"random_key"}).Result()
Expect(val).To(Equal(int64(1)))
})
}) })
var _ = Describe("Client timeout", func() { var _ = Describe("Client timeout", func() {

View File

@ -74,6 +74,7 @@ type RingOptions struct {
MaxRetries int MaxRetries int
MinRetryBackoff time.Duration MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration MaxRetryBackoff time.Duration
ShouldRetry ShouldRetryFunc
DialTimeout time.Duration DialTimeout time.Duration
ReadTimeout time.Duration ReadTimeout time.Duration
@ -125,6 +126,9 @@ func (opt *RingOptions) init() {
case 0: case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond opt.MaxRetryBackoff = 512 * time.Millisecond
} }
if opt.ShouldRetry == nil {
opt.ShouldRetry = DefaultShouldRetry
}
} }
func (opt *RingOptions) clientOptions() *Options { func (opt *RingOptions) clientOptions() *Options {
@ -606,7 +610,7 @@ func (c *Ring) process(ctx context.Context, cmd Cmder) error {
} }
lastErr = shard.Client.Process(ctx, cmd) lastErr = shard.Client.Process(ctx, cmd)
if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) { if lastErr == nil || !c.opt.ShouldRetry(lastErr, cmd.readTimeout() == nil) {
return lastErr return lastErr
} }
} }

View File

@ -171,6 +171,30 @@ var _ = Describe("Redis Ring", func() {
Expect(ringShard1.Info(ctx).Val()).ToNot(ContainSubstring("keys=")) Expect(ringShard1.Info(ctx).Val()).ToNot(ContainSubstring("keys="))
Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=100")) Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=100"))
}) })
It("should support custom ShouldRetry", func() {
opt := redisRingOptions()
opt.ShouldRetry = func(err error, retryTimeout bool) bool {
if err.Error() == "attempt 1 should fail" {
return true
}
return redis.DefaultShouldRetry(err, retryTimeout)
}
opt.MaxRetries = 1
ring := redis.NewRing(opt)
script := redis.NewScript(`
local k = KEYS[1]
local n = redis.call("incr", k)
if n == 1 then
return redis.error_reply("attempt 1 should fail")
end
redis.call("del", k)
return true
`)
script.Load(ctx, ring)
val, _ := script.Run(ctx, ring, []string{"random_key"}).Result()
Expect(val).To(Equal(int64(1)))
})
}) })
Describe("new client callback", func() { Describe("new client callback", func() {

View File

@ -58,6 +58,7 @@ type FailoverOptions struct {
MaxRetries int MaxRetries int
MinRetryBackoff time.Duration MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration MaxRetryBackoff time.Duration
ShouldRetry ShouldRetryFunc
DialTimeout time.Duration DialTimeout time.Duration
ReadTimeout time.Duration ReadTimeout time.Duration
@ -90,6 +91,7 @@ func (opt *FailoverOptions) clientOptions() *Options {
MaxRetries: opt.MaxRetries, MaxRetries: opt.MaxRetries,
MinRetryBackoff: opt.MinRetryBackoff, MinRetryBackoff: opt.MinRetryBackoff,
MaxRetryBackoff: opt.MaxRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff,
ShouldRetry: opt.ShouldRetry,
DialTimeout: opt.DialTimeout, DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout, ReadTimeout: opt.ReadTimeout,
@ -121,6 +123,7 @@ func (opt *FailoverOptions) sentinelOptions(addr string) *Options {
MaxRetries: opt.MaxRetries, MaxRetries: opt.MaxRetries,
MinRetryBackoff: opt.MinRetryBackoff, MinRetryBackoff: opt.MinRetryBackoff,
MaxRetryBackoff: opt.MaxRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff,
ShouldRetry: opt.ShouldRetry,
DialTimeout: opt.DialTimeout, DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout, ReadTimeout: opt.ReadTimeout,
@ -153,6 +156,7 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions {
MinRetryBackoff: opt.MinRetryBackoff, MinRetryBackoff: opt.MinRetryBackoff,
MaxRetryBackoff: opt.MaxRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff,
ShouldRetry: opt.ShouldRetry,
DialTimeout: opt.DialTimeout, DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout, ReadTimeout: opt.ReadTimeout,

View File

@ -211,6 +211,35 @@ var _ = Describe("NewFailoverClusterClient", func() {
_, err = startRedis(masterPort) _, err = startRedis(masterPort)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("should support custom ShouldRetry", func() {
opt := &redis.FailoverOptions{
MasterName: sentinelName,
SentinelAddrs: sentinelAddrs,
RouteRandomly: true,
MaxRetries: 1,
}
opt.ShouldRetry = func(err error, retryTimeout bool) bool {
if err.Error() == "attempt 1 should fail" {
return true
}
return redis.DefaultShouldRetry(err, retryTimeout)
}
client := redis.NewFailoverClusterClient(opt)
script := redis.NewScript(`
local k = KEYS[1]
local n = redis.call("incr", k)
if n == 1 then
return redis.error_reply("attempt 1 should fail")
end
redis.call("del", k)
return true
`)
script.Load(ctx, client)
val, _ := script.Run(ctx, client, []string{"random_key"}).Result()
Expect(val).To(Equal(int64(1)))
})
}) })
var _ = Describe("SentinelAclAuth", func() { var _ = Describe("SentinelAclAuth", func() {

View File

@ -30,6 +30,7 @@ type UniversalOptions struct {
MaxRetries int MaxRetries int
MinRetryBackoff time.Duration MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration MaxRetryBackoff time.Duration
ShouldRetry ShouldRetryFunc
DialTimeout time.Duration DialTimeout time.Duration
ReadTimeout time.Duration ReadTimeout time.Duration
@ -82,6 +83,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
MaxRetries: o.MaxRetries, MaxRetries: o.MaxRetries,
MinRetryBackoff: o.MinRetryBackoff, MinRetryBackoff: o.MinRetryBackoff,
MaxRetryBackoff: o.MaxRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff,
ShouldRetry: o.ShouldRetry,
DialTimeout: o.DialTimeout, DialTimeout: o.DialTimeout,
ReadTimeout: o.ReadTimeout, ReadTimeout: o.ReadTimeout,
@ -119,6 +121,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
MaxRetries: o.MaxRetries, MaxRetries: o.MaxRetries,
MinRetryBackoff: o.MinRetryBackoff, MinRetryBackoff: o.MinRetryBackoff,
MaxRetryBackoff: o.MaxRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff,
ShouldRetry: o.ShouldRetry,
DialTimeout: o.DialTimeout, DialTimeout: o.DialTimeout,
ReadTimeout: o.ReadTimeout, ReadTimeout: o.ReadTimeout,
@ -155,6 +158,7 @@ func (o *UniversalOptions) Simple() *Options {
MaxRetries: o.MaxRetries, MaxRetries: o.MaxRetries,
MinRetryBackoff: o.MinRetryBackoff, MinRetryBackoff: o.MinRetryBackoff,
MaxRetryBackoff: o.MaxRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff,
ShouldRetry: o.ShouldRetry,
DialTimeout: o.DialTimeout, DialTimeout: o.DialTimeout,
ReadTimeout: o.ReadTimeout, ReadTimeout: o.ReadTimeout,