diff --git a/error.go b/error.go
index 9fe1376..2ed5dc6 100644
--- a/error.go
+++ b/error.go
@@ -65,8 +65,11 @@ func isRedisError(err error) bool {
 }
 
 func isBadConn(err error, allowTimeout bool) bool {
-	if err == nil {
+	switch err {
+	case nil:
 		return false
+	case context.Canceled, context.DeadlineExceeded:
+		return true
 	}
 
 	if isRedisError(err) {
diff --git a/main_test.go b/main_test.go
index 1adbcbe..636518e 100644
--- a/main_test.go
+++ b/main_test.go
@@ -175,8 +175,8 @@ func redisRingOptions() *redis.RingOptions {
 func performAsync(n int, cbs ...func(int)) *sync.WaitGroup {
 	var wg sync.WaitGroup
 	for _, cb := range cbs {
+		wg.Add(n)
 		for i := 0; i < n; i++ {
-			wg.Add(1)
 			go func(cb func(int), i int) {
 				defer GinkgoRecover()
 				defer wg.Done()
diff --git a/race_test.go b/race_test.go
index 606999b..72b89e0 100644
--- a/race_test.go
+++ b/race_test.go
@@ -2,6 +2,7 @@ package redis_test
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"net"
 	"strconv"
@@ -295,6 +296,25 @@ var _ = Describe("races", func() {
 			Expect(err).NotTo(HaveOccurred())
 		})
 	})
+
+	It("should abort on context timeout", func() {
+		opt := redisClusterOptions()
+		client := cluster.newClusterClient(ctx, opt)
+
+		ctx, cancel := context.WithCancel(context.Background())
+
+		wg := performAsync(C, func(_ int) {
+			_, err := client.XRead(ctx, &redis.XReadArgs{
+				Streams: []string{"test", "$"},
+				Block:   1 * time.Second,
+			}).Result()
+			Expect(err).To(Equal(context.Canceled))
+		})
+
+		time.Sleep(10 * time.Millisecond)
+		cancel()
+		wg.Wait()
+	})
 })
 
 var _ = Describe("cluster races", func() {
diff --git a/redis.go b/redis.go
index c5ba9e0..88cced7 100644
--- a/redis.go
+++ b/redis.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"sync/atomic"
 	"time"
 
 	"github.com/go-redis/redis/v8/internal"
@@ -130,20 +131,7 @@ func (hs hooks) processTxPipeline(
 }
 
 func (hs hooks) withContext(ctx context.Context, fn func() error) error {
-	done := ctx.Done()
-	if done == nil {
-		return fn()
-	}
-
-	errc := make(chan error, 1)
-	go func() { errc <- fn() }()
-
-	select {
-	case <-done:
-		return ctx.Err()
-	case err := <-errc:
-		return err
-	}
+	return fn()
 }
 
 //------------------------------------------------------------------------------
@@ -316,8 +304,24 @@ func (c *baseClient) withConn(
 			c.releaseConn(ctx, cn, err)
 		}()
 
-		err = fn(ctx, cn)
-		return err
+		done := ctx.Done()
+		if done == nil {
+			err = fn(ctx, cn)
+			return err
+		}
+
+		errc := make(chan error, 1)
+		go func() { errc <- fn(ctx, cn) }()
+
+		select {
+		case <-done:
+			_ = cn.Close()
+
+			err = ctx.Err()
+			return err
+		case err = <-errc:
+			return err
+		}
 	})
 }
 
@@ -334,7 +338,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
 			}
 		}
-		retryTimeout := true
+		retryTimeout := uint32(1)
 		err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
 			err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
 				return writeCmd(wr, cmd)
 			})
@@ -345,7 +349,11 @@
 
 			err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
 			if err != nil {
-				retryTimeout = cmd.readTimeout() == nil
+				if cmd.readTimeout() == nil {
+					atomic.StoreUint32(&retryTimeout, 1)
+				} else {
+					atomic.StoreUint32(&retryTimeout, 0)
+				}
 				return err
 			}
 
@@ -354,7 +362,7 @@
 			if err == nil {
 				return nil
 			}
-			retry = shouldRetry(err, retryTimeout)
+			retry = shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
 			return err
 		})
 		if err == nil || !retry {
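
Note on the pattern: the select-on-ctx.Done() logic removed from hooks.withContext reappears one level lower, in baseClient.withConn, with one important difference: when the context is cancelled first, the connection is closed rather than released back to the pool, because the goroutine running fn may still be reading from or writing to it. Below is a minimal, self-contained sketch of that pattern, not go-redis code; the withConn helper, the throwaway listener, and the blocking Read are illustrative stand-ins.

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

// withConn runs op against conn but returns as soon as ctx is cancelled.
// On cancellation the connection is closed instead of being reused, because
// the goroutine running op may still be using it.
func withConn(ctx context.Context, conn net.Conn, op func(net.Conn) error) error {
	if ctx.Done() == nil {
		// No cancellation or deadline is possible; skip the extra goroutine.
		return op(conn)
	}

	errc := make(chan error, 1) // buffered so the goroutine never leaks
	go func() { errc <- op(conn) }()

	select {
	case <-ctx.Done():
		_ = conn.Close() // unblocks the goroutine and poisons the connection
		return ctx.Err()
	case err := <-errc:
		return err
	}
}

func main() {
	// Hypothetical server that accepts a connection and never writes,
	// so the client read below blocks until the context is cancelled.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer ln.Close()
	go func() {
		c, err := ln.Accept()
		if err == nil {
			defer c.Close()
			time.Sleep(time.Second)
		}
	}()

	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		panic(err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(10 * time.Millisecond)
		cancel()
	}()

	err = withConn(ctx, conn, func(c net.Conn) error {
		_, err := c.Read(make([]byte, 1)) // blocks; aborted by cancel
		return err
	})
	fmt.Println(err) // context.Canceled
}

A consequence of running fn on its own goroutine is that any variable it writes and the caller later reads must be synchronized; that is why process switches retryTimeout from a plain bool to an atomic uint32.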