diff --git a/redis.go b/redis.go index 617bf973..472b3247 100644 --- a/redis.go +++ b/redis.go @@ -49,7 +49,13 @@ func (hs hooks) process( ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ) error { if len(hs.hooks) == 0 { - return fn(ctx, cmd) + return hs.withContext(ctx, func() error { + err := fn(ctx, cmd) + if err != nil { + cmd.SetErr(err) + } + return err + }) } var hookIndex int @@ -63,7 +69,13 @@ func (hs hooks) process( } if retErr == nil { - retErr = fn(ctx, cmd) + retErr = hs.withContext(ctx, func() error { + err := fn(ctx, cmd) + if err != nil { + cmd.SetErr(err) + } + return err + }) } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -80,7 +92,13 @@ func (hs hooks) processPipeline( ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ) error { if len(hs.hooks) == 0 { - return fn(ctx, cmds) + return hs.withContext(ctx, func() error { + err := fn(ctx, cmds) + if err != nil { + setCmdsErr(cmds, err) + } + return err + }) } var hookIndex int @@ -94,7 +112,13 @@ func (hs hooks) processPipeline( } if retErr == nil { - retErr = fn(ctx, cmds) + retErr = hs.withContext(ctx, func() error { + err := fn(ctx, cmds) + if err != nil { + setCmdsErr(cmds, err) + } + return err + }) } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -114,6 +138,22 @@ func (hs hooks) processTxPipeline( return hs.processPipeline(ctx, cmds, fn) } +func (hs hooks) withContext(ctx context.Context, fn func() error) error { + if ctx.Done() == nil { + return fn() + } + + errc := make(chan error, 1) + go func() { errc <- fn() }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errc: + return err + } +} + //------------------------------------------------------------------------------ type baseClient struct { diff --git a/redis_test.go b/redis_test.go index 044a7c3e..c00afc0d 100644 --- a/redis_test.go +++ b/redis_test.go @@ -389,3 +389,28 @@ var _ = Describe("Client OnConnect", func() { Expect(name).To(Equal("on_connect")) }) }) + +var _ = Describe("Client context cancelation", func() { + var opt *redis.Options + var client *redis.Client + + BeforeEach(func() { + opt = redisOptions() + opt.ReadTimeout = -1 + opt.WriteTimeout = -1 + client = redis.NewClient(opt) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("Blocking operation cancelation", func() { + ctx, cancel := context.WithCancel(ctx) + cancel() + + err := client.BLPop(ctx, 1*time.Second, "test").Err() + Expect(err).To(HaveOccurred()) + Expect(err).To(BeIdenticalTo(context.Canceled)) + }) +})