diff --git a/cluster.go b/cluster.go index b693fbe..6d311f0 100644 --- a/cluster.go +++ b/cluster.go @@ -754,7 +754,9 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { var err error if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return err + } } if node == nil { @@ -1049,7 +1051,9 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return err + } } failedCmds := newCmdsMap() @@ -1254,7 +1258,9 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return err + } } failedCmds := newCmdsMap() @@ -1376,6 +1382,10 @@ func (c *ClusterClient) txPipelineReadQueued( } func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { + return c.WatchContext(c.ctx, fn, keys...) +} + +func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error { if len(keys) == 0 { return fmt.Errorf("redis: Watch requires at least one key") } @@ -1395,10 +1405,12 @@ func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return err + } } - err = node.Client.Watch(fn, keys...) + err = node.Client.WatchContext(ctx, fn, keys...) if err == nil { break } diff --git a/internal/error.go b/internal/error.go index f9c9ebb..cd90d0d 100644 --- a/internal/error.go +++ b/internal/error.go @@ -1,6 +1,7 @@ package internal import ( + "context" "io" "net" "strings" @@ -9,10 +10,10 @@ import ( ) func IsRetryableError(err error, retryTimeout bool) bool { - if err == nil { + switch err { + case nil, context.Canceled, context.DeadlineExceeded: return false - } - if err == io.EOF { + case io.EOF: return true } if netErr, ok := err.(net.Error); ok { diff --git a/internal/util.go b/internal/util.go index ffd2353..6e47140 100644 --- a/internal/util.go +++ b/internal/util.go @@ -1,6 +1,23 @@ package internal -import "github.com/go-redis/redis/internal/util" +import ( + "context" + "time" + + "github.com/go-redis/redis/internal/util" +) + +func Sleep(ctx context.Context, dur time.Duration) error { + t := time.NewTimer(dur) + defer t.Stop() + + select { + case <-t.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} func ToLower(s string) string { if isLower(s) { diff --git a/redis.go b/redis.go index abdda7e..e2bf5e9 100644 --- a/redis.go +++ b/redis.go @@ -244,7 +244,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { func (c *baseClient) process(ctx context.Context, cmd Cmder) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return err + } } cn, err := c.getConn(ctx) @@ -331,7 +333,9 @@ func (c *baseClient) generalProcessPipeline( ) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return err + } } cn, err := c.getConn(ctx) diff --git a/ring.go b/ring.go index 2cdb01a..94ea77b 100644 --- a/ring.go +++ b/ring.go @@ -553,7 +553,9 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) { func (c *Ring) process(ctx context.Context, cmd Cmder) error { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return err + } } shard, err := c.cmdShard(cmd) @@ -626,7 +628,9 @@ func (c *Ring) generalProcessPipeline( for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { - time.Sleep(c.retryBackoff(attempt)) + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + return err + } } var mu sync.Mutex diff --git a/tx.go b/tx.go index 38a8855..327f1ab 100644 --- a/tx.go +++ b/tx.go @@ -22,13 +22,13 @@ type Tx struct { ctx context.Context } -func (c *Client) newTx() *Tx { +func (c *Client) newTx(ctx context.Context) *Tx { tx := Tx{ baseClient: baseClient{ opt: c.opt, connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), }, - ctx: c.ctx, + ctx: ctx, } tx.init() return &tx @@ -65,7 +65,11 @@ func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error { // // The transaction is automatically closed when fn exits. func (c *Client) Watch(fn func(*Tx) error, keys ...string) error { - tx := c.newTx() + return c.WatchContext(c.ctx, fn, keys...) +} + +func (c *Client) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error { + tx := c.newTx(ctx) if len(keys) > 0 { if err := tx.Watch(keys...).Err(); err != nil { _ = tx.Close()