Cancel sleep when context is cancelled

This commit is contained in:
Vladimir Mihailenco 2019-07-30 12:13:00 +03:00
parent 6d8db67ef5
commit c837612911
6 changed files with 58 additions and 16 deletions

View File

@ -754,7 +754,9 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
var err error var err error
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
} }
if node == nil { if node == nil {
@ -1049,7 +1051,9 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
} }
failedCmds := newCmdsMap() failedCmds := newCmdsMap()
@ -1254,7 +1258,9 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
} }
failedCmds := newCmdsMap() failedCmds := newCmdsMap()
@ -1376,6 +1382,10 @@ func (c *ClusterClient) txPipelineReadQueued(
} }
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
return c.WatchContext(c.ctx, fn, keys...)
}
func (c *ClusterClient) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 { if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key") return fmt.Errorf("redis: Watch requires at least one key")
} }
@ -1395,10 +1405,12 @@ func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
} }
err = node.Client.Watch(fn, keys...) err = node.Client.WatchContext(ctx, fn, keys...)
if err == nil { if err == nil {
break break
} }

View File

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"io" "io"
"net" "net"
"strings" "strings"
@ -9,10 +10,10 @@ import (
) )
func IsRetryableError(err error, retryTimeout bool) bool { func IsRetryableError(err error, retryTimeout bool) bool {
if err == nil { switch err {
case nil, context.Canceled, context.DeadlineExceeded:
return false return false
} case io.EOF:
if err == io.EOF {
return true return true
} }
if netErr, ok := err.(net.Error); ok { if netErr, ok := err.(net.Error); ok {

View File

@ -1,6 +1,23 @@
package internal package internal
import "github.com/go-redis/redis/internal/util" import (
"context"
"time"
"github.com/go-redis/redis/internal/util"
)
func Sleep(ctx context.Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func ToLower(s string) string { func ToLower(s string) string {
if isLower(s) { if isLower(s) {

View File

@ -244,7 +244,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
func (c *baseClient) process(ctx context.Context, cmd Cmder) error { func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
} }
cn, err := c.getConn(ctx) cn, err := c.getConn(ctx)
@ -331,7 +333,9 @@ func (c *baseClient) generalProcessPipeline(
) error { ) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
} }
cn, err := c.getConn(ctx) cn, err := c.getConn(ctx)

View File

@ -553,7 +553,9 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
func (c *Ring) process(ctx context.Context, cmd Cmder) error { func (c *Ring) process(ctx context.Context, cmd Cmder) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
} }
shard, err := c.cmdShard(cmd) shard, err := c.cmdShard(cmd)
@ -626,7 +628,9 @@ func (c *Ring) generalProcessPipeline(
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(c.retryBackoff(attempt)) if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
} }
var mu sync.Mutex var mu sync.Mutex

10
tx.go
View File

@ -22,13 +22,13 @@ type Tx struct {
ctx context.Context ctx context.Context
} }
func (c *Client) newTx() *Tx { func (c *Client) newTx(ctx context.Context) *Tx {
tx := Tx{ tx := Tx{
baseClient: baseClient{ baseClient: baseClient{
opt: c.opt, opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true),
}, },
ctx: c.ctx, ctx: ctx,
} }
tx.init() tx.init()
return &tx return &tx
@ -65,7 +65,11 @@ func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error {
// //
// The transaction is automatically closed when fn exits. // The transaction is automatically closed when fn exits.
func (c *Client) Watch(fn func(*Tx) error, keys ...string) error { func (c *Client) Watch(fn func(*Tx) error, keys ...string) error {
tx := c.newTx() return c.WatchContext(c.ctx, fn, keys...)
}
func (c *Client) WatchContext(ctx context.Context, fn func(*Tx) error, keys ...string) error {
tx := c.newTx(ctx)
if len(keys) > 0 { if len(keys) > 0 {
if err := tx.Watch(keys...).Err(); err != nil { if err := tx.Watch(keys...).Err(); err != nil {
_ = tx.Close() _ = tx.Close()