From eda1f9c6ad9a2b84f5e175732bbeb096a51b93c0 Mon Sep 17 00:00:00 2001
From: Pavlov Aleksey <irishgreenhedgehog@gmail.com>
Date: Mon, 14 Sep 2020 21:27:26 +0300
Subject: [PATCH 1/3] add context cancelation support for blocking operations

---
 redis.go      | 48 ++++++++++++++++++++++++++++++++++++++++++++----
 redis_test.go | 25 +++++++++++++++++++++++++
 2 files changed, 69 insertions(+), 4 deletions(-)

diff --git a/redis.go b/redis.go
index 617bf973..472b3247 100644
--- a/redis.go
+++ b/redis.go
@@ -49,7 +49,13 @@ func (hs hooks) process(
 	ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
 ) error {
 	if len(hs.hooks) == 0 {
-		return fn(ctx, cmd)
+		return hs.withContext(ctx, func() error {
+			err := fn(ctx, cmd)
+			if err != nil {
+				cmd.SetErr(err)
+			}
+			return err
+		})
 	}
 
 	var hookIndex int
@@ -63,7 +69,13 @@ func (hs hooks) process(
 	}
 
 	if retErr == nil {
-		retErr = fn(ctx, cmd)
+		retErr = hs.withContext(ctx, func() error {
+			err := fn(ctx, cmd)
+			if err != nil {
+				cmd.SetErr(err)
+			}
+			return err
+		})
 	}
 
 	for hookIndex--; hookIndex >= 0; hookIndex-- {
@@ -80,7 +92,13 @@ func (hs hooks) processPipeline(
 	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
 ) error {
 	if len(hs.hooks) == 0 {
-		return fn(ctx, cmds)
+		return hs.withContext(ctx, func() error {
+			err := fn(ctx, cmds)
+			if err != nil {
+				setCmdsErr(cmds, err)
+			}
+			return err
+		})
 	}
 
 	var hookIndex int
@@ -94,7 +112,13 @@ func (hs hooks) processPipeline(
 	}
 
 	if retErr == nil {
-		retErr = fn(ctx, cmds)
+		retErr = hs.withContext(ctx, func() error {
+			err := fn(ctx, cmds)
+			if err != nil {
+				setCmdsErr(cmds, err)
+			}
+			return err
+		})
 	}
 
 	for hookIndex--; hookIndex >= 0; hookIndex-- {
@@ -114,6 +138,22 @@ func (hs hooks) processTxPipeline(
 	return hs.processPipeline(ctx, cmds, fn)
 }
 
+func (hs hooks) withContext(ctx context.Context, fn func() error) error {
+	if ctx.Done() == nil {
+		return fn()
+	}
+
+	errc := make(chan error, 1)
+	go func() { errc <- fn() }()
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case err := <-errc:
+		return err
+	}
+}
+
 //------------------------------------------------------------------------------
 
 type baseClient struct {
diff --git a/redis_test.go b/redis_test.go
index 044a7c3e..c00afc0d 100644
--- a/redis_test.go
+++ b/redis_test.go
@@ -389,3 +389,28 @@ var _ = Describe("Client OnConnect", func() {
 		Expect(name).To(Equal("on_connect"))
 	})
 })
+
+var _ = Describe("Client context cancelation", func() {
+	var opt *redis.Options
+	var client *redis.Client
+
+	BeforeEach(func() {
+		opt = redisOptions()
+		opt.ReadTimeout = -1
+		opt.WriteTimeout = -1
+		client = redis.NewClient(opt)
+	})
+
+	AfterEach(func() {
+		Expect(client.Close()).NotTo(HaveOccurred())
+	})
+
+	It("Blocking operation cancelation", func() {
+		ctx, cancel := context.WithCancel(ctx)
+		cancel()
+
+		err := client.BLPop(ctx, 1*time.Second, "test").Err()
+		Expect(err).To(HaveOccurred())
+		Expect(err).To(BeIdenticalTo(context.Canceled))
+	})
+})

From 297e671f5eade43a511e3e42439f43fc7ce1d060 Mon Sep 17 00:00:00 2001
From: Vladimir Mihailenco <vladimir.webdev@gmail.com>
Date: Thu, 17 Sep 2020 11:23:34 +0300
Subject: [PATCH 2/3] Properly propagate context error

---
 redis.go | 61 ++++++++++++++++++++++++++++----------------------------
 1 file changed, 31 insertions(+), 30 deletions(-)

diff --git a/redis.go b/redis.go
index 472b3247..e15da91e 100644
--- a/redis.go
+++ b/redis.go
@@ -49,13 +49,13 @@ func (hs hooks) process(
 	ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
 ) error {
 	if len(hs.hooks) == 0 {
-		return hs.withContext(ctx, func() error {
-			err := fn(ctx, cmd)
-			if err != nil {
-				cmd.SetErr(err)
-			}
-			return err
+		err, canceled := hs.withContext(ctx, func() error {
+			return fn(ctx, cmd)
 		})
+		if canceled {
+			cmd.SetErr(err)
+		}
+		return err
 	}
 
 	var hookIndex int
@@ -69,13 +69,13 @@ func (hs hooks) process(
 	}
 
 	if retErr == nil {
-		retErr = hs.withContext(ctx, func() error {
-			err := fn(ctx, cmd)
-			if err != nil {
-				cmd.SetErr(err)
-			}
-			return err
+		var canceled bool
+		retErr, canceled = hs.withContext(ctx, func() error {
+			return fn(ctx, cmd)
 		})
+		if canceled {
+			cmd.SetErr(retErr)
+		}
 	}
 
 	for hookIndex--; hookIndex >= 0; hookIndex-- {
@@ -92,13 +92,13 @@ func (hs hooks) processPipeline(
 	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
 ) error {
 	if len(hs.hooks) == 0 {
-		return hs.withContext(ctx, func() error {
-			err := fn(ctx, cmds)
-			if err != nil {
-				setCmdsErr(cmds, err)
-			}
-			return err
+		err, canceled := hs.withContext(ctx, func() error {
+			return fn(ctx, cmds)
 		})
+		if canceled {
+			setCmdsErr(cmds, err)
+		}
+		return err
 	}
 
 	var hookIndex int
@@ -112,13 +112,13 @@ func (hs hooks) processPipeline(
 	}
 
 	if retErr == nil {
-		retErr = hs.withContext(ctx, func() error {
-			err := fn(ctx, cmds)
-			if err != nil {
-				setCmdsErr(cmds, err)
-			}
-			return err
+		var canceled bool
+		retErr, canceled = hs.withContext(ctx, func() error {
+			return fn(ctx, cmds)
 		})
+		if canceled {
+			setCmdsErr(cmds, retErr)
+		}
 	}
 
 	for hookIndex--; hookIndex >= 0; hookIndex-- {
@@ -138,19 +138,20 @@ func (hs hooks) processTxPipeline(
 	return hs.processPipeline(ctx, cmds, fn)
 }
 
-func (hs hooks) withContext(ctx context.Context, fn func() error) error {
-	if ctx.Done() == nil {
-		return fn()
+func (hs hooks) withContext(ctx context.Context, fn func() error) (_ error, canceled bool) {
+	done := ctx.Done()
+	if done == nil {
+		return fn(), false
 	}
 
 	errc := make(chan error, 1)
 	go func() { errc <- fn() }()
 
 	select {
-	case <-ctx.Done():
-		return ctx.Err()
+	case <-done:
+		return ctx.Err(), true
 	case err := <-errc:
-		return err
+		return err, false
 	}
 }
 

From c5d4b71f6661a2236b52bfbe5ed09bf9d3198319 Mon Sep 17 00:00:00 2001
From: Vladimir Mihailenco <vladimir.webdev@gmail.com>
Date: Thu, 17 Sep 2020 12:27:16 +0300
Subject: [PATCH 3/3] Fix race

---
 cluster.go      |  12 +---
 command.go      | 162 +++++++++++++++++++++++-------------------------
 command_test.go |   2 +-
 redis.go        |  42 ++++---------
 ring.go         |  15 +----
 sentinel.go     |   3 +-
 6 files changed, 96 insertions(+), 140 deletions(-)

diff --git a/cluster.go b/cluster.go
index be8217b1..d17c7479 100644
--- a/cluster.go
+++ b/cluster.go
@@ -751,15 +751,6 @@ func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error {
 }
 
 func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
-	err := c._process(ctx, cmd)
-	if err != nil {
-		cmd.SetErr(err)
-		return err
-	}
-	return nil
-}
-
-func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
 	cmdInfo := c.cmdInfo(cmd.Name())
 	slot := c.cmdSlot(cmd)
 
@@ -1197,9 +1188,12 @@ func (c *ClusterClient) pipelineReadCmds(
 ) error {
 	for _, cmd := range cmds {
 		err := cmd.readReply(rd)
+		cmd.SetErr(err)
+
 		if err == nil {
 			continue
 		}
+
 		if c.checkMovedErr(ctx, cmd, err, failedCmds) {
 			continue
 		}
diff --git a/command.go b/command.go
index 55a5bd5c..4879cfa0 100644
--- a/command.go
+++ b/command.go
@@ -299,9 +299,9 @@ func (cmd *Cmd) Bool() (bool, error) {
 	}
 }
 
-func (cmd *Cmd) readReply(rd *proto.Reader) error {
-	cmd.val, cmd.err = rd.ReadReply(sliceParser)
-	return cmd.err
+func (cmd *Cmd) readReply(rd *proto.Reader) (err error) {
+	cmd.val, err = rd.ReadReply(sliceParser)
+	return err
 }
 
 // sliceParser implements proto.MultiBulkParse.
@@ -357,10 +357,9 @@ func (cmd *SliceCmd) String() string {
 }
 
 func (cmd *SliceCmd) readReply(rd *proto.Reader) error {
-	var v interface{}
-	v, cmd.err = rd.ReadArrayReply(sliceParser)
-	if cmd.err != nil {
-		return cmd.err
+	v, err := rd.ReadArrayReply(sliceParser)
+	if err != nil {
+		return err
 	}
 	cmd.val = v.([]interface{})
 	return nil
@@ -397,9 +396,9 @@ func (cmd *StatusCmd) String() string {
 	return cmdString(cmd, cmd.val)
 }
 
-func (cmd *StatusCmd) readReply(rd *proto.Reader) error {
-	cmd.val, cmd.err = rd.ReadString()
-	return cmd.err
+func (cmd *StatusCmd) readReply(rd *proto.Reader) (err error) {
+	cmd.val, err = rd.ReadString()
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -437,9 +436,9 @@ func (cmd *IntCmd) String() string {
 	return cmdString(cmd, cmd.val)
 }
 
-func (cmd *IntCmd) readReply(rd *proto.Reader) error {
-	cmd.val, cmd.err = rd.ReadIntReply()
-	return cmd.err
+func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) {
+	cmd.val, err = rd.ReadIntReply()
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -474,7 +473,7 @@ func (cmd *IntSliceCmd) String() string {
 }
 
 func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]int64, n)
 		for i := 0; i < len(cmd.val); i++ {
 			num, err := rd.ReadIntReply()
@@ -485,7 +484,7 @@ func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -522,10 +521,9 @@ func (cmd *DurationCmd) String() string {
 }
 
 func (cmd *DurationCmd) readReply(rd *proto.Reader) error {
-	var n int64
-	n, cmd.err = rd.ReadIntReply()
-	if cmd.err != nil {
-		return cmd.err
+	n, err := rd.ReadIntReply()
+	if err != nil {
+		return err
 	}
 	switch n {
 	// -2 if the key does not exist
@@ -570,7 +568,7 @@ func (cmd *TimeCmd) String() string {
 }
 
 func (cmd *TimeCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		if n != 2 {
 			return nil, fmt.Errorf("got %d elements, expected 2", n)
 		}
@@ -588,7 +586,7 @@ func (cmd *TimeCmd) readReply(rd *proto.Reader) error {
 		cmd.val = time.Unix(sec, microsec*1000)
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -623,17 +621,15 @@ func (cmd *BoolCmd) String() string {
 }
 
 func (cmd *BoolCmd) readReply(rd *proto.Reader) error {
-	var v interface{}
-	v, cmd.err = rd.ReadReply(nil)
+	v, err := rd.ReadReply(nil)
 	// `SET key value NX` returns nil when key already exists. But
 	// `SETNX key value` returns bool (0/1). So convert nil to bool.
-	if cmd.err == Nil {
+	if err == Nil {
 		cmd.val = false
-		cmd.err = nil
 		return nil
 	}
-	if cmd.err != nil {
-		return cmd.err
+	if err != nil {
+		return err
 	}
 	switch v := v.(type) {
 	case int64:
@@ -643,8 +639,7 @@ func (cmd *BoolCmd) readReply(rd *proto.Reader) error {
 		cmd.val = v == "OK"
 		return nil
 	default:
-		cmd.err = fmt.Errorf("got %T, wanted int64 or string", v)
-		return cmd.err
+		return fmt.Errorf("got %T, wanted int64 or string", v)
 	}
 }
 
@@ -736,9 +731,9 @@ func (cmd *StringCmd) String() string {
 	return cmdString(cmd, cmd.val)
 }
 
-func (cmd *StringCmd) readReply(rd *proto.Reader) error {
-	cmd.val, cmd.err = rd.ReadString()
-	return cmd.err
+func (cmd *StringCmd) readReply(rd *proto.Reader) (err error) {
+	cmd.val, err = rd.ReadString()
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -772,9 +767,9 @@ func (cmd *FloatCmd) String() string {
 	return cmdString(cmd, cmd.val)
 }
 
-func (cmd *FloatCmd) readReply(rd *proto.Reader) error {
-	cmd.val, cmd.err = rd.ReadFloatReply()
-	return cmd.err
+func (cmd *FloatCmd) readReply(rd *proto.Reader) (err error) {
+	cmd.val, err = rd.ReadFloatReply()
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -813,7 +808,7 @@ func (cmd *StringSliceCmd) ScanSlice(container interface{}) error {
 }
 
 func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]string, n)
 		for i := 0; i < len(cmd.val); i++ {
 			switch s, err := rd.ReadString(); {
@@ -827,7 +822,7 @@ func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -862,7 +857,7 @@ func (cmd *BoolSliceCmd) String() string {
 }
 
 func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]bool, n)
 		for i := 0; i < len(cmd.val); i++ {
 			n, err := rd.ReadIntReply()
@@ -873,7 +868,7 @@ func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -908,7 +903,7 @@ func (cmd *StringStringMapCmd) String() string {
 }
 
 func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make(map[string]string, n/2)
 		for i := int64(0); i < n; i += 2 {
 			key, err := rd.ReadString()
@@ -925,7 +920,7 @@ func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -960,7 +955,7 @@ func (cmd *StringIntMapCmd) String() string {
 }
 
 func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make(map[string]int64, n/2)
 		for i := int64(0); i < n; i += 2 {
 			key, err := rd.ReadString()
@@ -977,7 +972,7 @@ func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -1012,7 +1007,7 @@ func (cmd *StringStructMapCmd) String() string {
 }
 
 func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make(map[string]struct{}, n)
 		for i := int64(0); i < n; i++ {
 			key, err := rd.ReadString()
@@ -1023,7 +1018,7 @@ func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -1063,10 +1058,9 @@ func (cmd *XMessageSliceCmd) String() string {
 }
 
 func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) error {
-	var v interface{}
-	v, cmd.err = rd.ReadArrayReply(xMessageSliceParser)
-	if cmd.err != nil {
-		return cmd.err
+	v, err := rd.ReadArrayReply(xMessageSliceParser)
+	if err != nil {
+		return err
 	}
 	cmd.val = v.([]XMessage)
 	return nil
@@ -1163,7 +1157,7 @@ func (cmd *XStreamSliceCmd) String() string {
 }
 
 func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]XStream, n)
 		for i := 0; i < len(cmd.val); i++ {
 			i := i
@@ -1194,7 +1188,7 @@ func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -1235,7 +1229,7 @@ func (cmd *XPendingCmd) String() string {
 }
 
 func (cmd *XPendingCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		if n != 4 {
 			return nil, fmt.Errorf("got %d, wanted 4", n)
 		}
@@ -1296,7 +1290,7 @@ func (cmd *XPendingCmd) readReply(rd *proto.Reader) error {
 
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -1337,7 +1331,7 @@ func (cmd *XPendingExtCmd) String() string {
 }
 
 func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]XPendingExt, 0, n)
 		for i := int64(0); i < n; i++ {
 			_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
@@ -1379,7 +1373,7 @@ func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -1420,18 +1414,17 @@ func (cmd *XInfoGroupsCmd) String() string {
 }
 
 func (cmd *XInfoGroupsCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(
-		func(rd *proto.Reader, n int64) (interface{}, error) {
-			for i := int64(0); i < n; i++ {
-				v, err := rd.ReadReply(xGroupInfoParser)
-				if err != nil {
-					return nil, err
-				}
-				cmd.val = append(cmd.val, v.(XInfoGroups))
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+		for i := int64(0); i < n; i++ {
+			v, err := rd.ReadReply(xGroupInfoParser)
+			if err != nil {
+				return nil, err
 			}
-			return nil, nil
-		})
-	return nil
+			cmd.val = append(cmd.val, v.(XInfoGroups))
+		}
+		return nil, nil
+	})
+	return err
 }
 
 func xGroupInfoParser(rd *proto.Reader, n int64) (interface{}, error) {
@@ -1507,7 +1500,7 @@ func (cmd *ZSliceCmd) String() string {
 }
 
 func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]Z, n/2)
 		for i := 0; i < len(cmd.val); i++ {
 			member, err := rd.ReadString()
@@ -1527,7 +1520,7 @@ func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -1562,7 +1555,7 @@ func (cmd *ZWithKeyCmd) String() string {
 }
 
 func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		if n != 3 {
 			return nil, fmt.Errorf("got %d elements, expected 3", n)
 		}
@@ -1587,7 +1580,7 @@ func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) error {
 
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -1625,9 +1618,9 @@ func (cmd *ScanCmd) String() string {
 	return cmdString(cmd, cmd.page)
 }
 
-func (cmd *ScanCmd) readReply(rd *proto.Reader) error {
-	cmd.page, cmd.cursor, cmd.err = rd.ReadScanReply()
-	return cmd.err
+func (cmd *ScanCmd) readReply(rd *proto.Reader) (err error) {
+	cmd.page, cmd.cursor, err = rd.ReadScanReply()
+	return err
 }
 
 // Iterator creates a new ScanIterator.
@@ -1680,7 +1673,7 @@ func (cmd *ClusterSlotsCmd) String() string {
 }
 
 func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]ClusterSlot, n)
 		for i := 0; i < len(cmd.val); i++ {
 			n, err := rd.ReadArrayLen()
@@ -1742,7 +1735,7 @@ func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -1834,10 +1827,9 @@ func (cmd *GeoLocationCmd) String() string {
 }
 
 func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error {
-	var v interface{}
-	v, cmd.err = rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q))
-	if cmd.err != nil {
-		return cmd.err
+	v, err := rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q))
+	if err != nil {
+		return err
 	}
 	cmd.locations = v.([]GeoLocation)
 	return nil
@@ -1947,7 +1939,7 @@ func (cmd *GeoPosCmd) String() string {
 }
 
 func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]*GeoPos, n)
 		for i := 0; i < len(cmd.val); i++ {
 			i := i
@@ -1978,7 +1970,7 @@ func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 //------------------------------------------------------------------------------
@@ -2024,7 +2016,7 @@ func (cmd *CommandsInfoCmd) String() string {
 }
 
 func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make(map[string]*CommandInfo, n)
 		for i := int64(0); i < n; i++ {
 			v, err := rd.ReadReply(commandInfoParser)
@@ -2036,7 +2028,7 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
 
 func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) {
@@ -2211,7 +2203,7 @@ func (cmd *SlowLogCmd) String() string {
 }
 
 func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error {
-	_, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
+	_, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) {
 		cmd.val = make([]SlowLog, n)
 		for i := 0; i < len(cmd.val); i++ {
 			n, err := rd.ReadArrayLen()
@@ -2281,5 +2273,5 @@ func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error {
 		}
 		return nil, nil
 	})
-	return cmd.err
+	return err
 }
diff --git a/command_test.go b/command_test.go
index d80b7444..d110d0c3 100644
--- a/command_test.go
+++ b/command_test.go
@@ -86,7 +86,7 @@ var _ = Describe("Cmd", func() {
 		Expect(tm2).To(BeTemporally("==", tm))
 	})
 
-	It("allow to set custom error", func() {
+	It("allows to set custom error", func() {
 		e := errors.New("custom error")
 		cmd := redis.Cmd{}
 		cmd.SetErr(e)
diff --git a/redis.go b/redis.go
index e15da91e..0921359e 100644
--- a/redis.go
+++ b/redis.go
@@ -49,12 +49,10 @@ func (hs hooks) process(
 	ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
 ) error {
 	if len(hs.hooks) == 0 {
-		err, canceled := hs.withContext(ctx, func() error {
+		err := hs.withContext(ctx, func() error {
 			return fn(ctx, cmd)
 		})
-		if canceled {
-			cmd.SetErr(err)
-		}
+		cmd.SetErr(err)
 		return err
 	}
 
@@ -69,13 +67,10 @@ func (hs hooks) process(
 	}
 
 	if retErr == nil {
-		var canceled bool
-		retErr, canceled = hs.withContext(ctx, func() error {
+		retErr = hs.withContext(ctx, func() error {
 			return fn(ctx, cmd)
 		})
-		if canceled {
-			cmd.SetErr(retErr)
-		}
+		cmd.SetErr(retErr)
 	}
 
 	for hookIndex--; hookIndex >= 0; hookIndex-- {
@@ -92,12 +87,9 @@ func (hs hooks) processPipeline(
 	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
 ) error {
 	if len(hs.hooks) == 0 {
-		err, canceled := hs.withContext(ctx, func() error {
+		err := hs.withContext(ctx, func() error {
 			return fn(ctx, cmds)
 		})
-		if canceled {
-			setCmdsErr(cmds, err)
-		}
 		return err
 	}
 
@@ -112,13 +104,9 @@ func (hs hooks) processPipeline(
 	}
 
 	if retErr == nil {
-		var canceled bool
-		retErr, canceled = hs.withContext(ctx, func() error {
+		retErr = hs.withContext(ctx, func() error {
 			return fn(ctx, cmds)
 		})
-		if canceled {
-			setCmdsErr(cmds, retErr)
-		}
 	}
 
 	for hookIndex--; hookIndex >= 0; hookIndex-- {
@@ -138,10 +126,10 @@ func (hs hooks) processTxPipeline(
 	return hs.processPipeline(ctx, cmds, fn)
 }
 
-func (hs hooks) withContext(ctx context.Context, fn func() error) (_ error, canceled bool) {
+func (hs hooks) withContext(ctx context.Context, fn func() error) error {
 	done := ctx.Done()
 	if done == nil {
-		return fn(), false
+		return fn()
 	}
 
 	errc := make(chan error, 1)
@@ -149,9 +137,9 @@ func (hs hooks) withContext(ctx context.Context, fn func() error) (_ error, canc
 
 	select {
 	case <-done:
-		return ctx.Err(), true
+		return ctx.Err()
 	case err := <-errc:
-		return err, false
+		return err
 	}
 }
 
@@ -324,15 +312,6 @@ func (c *baseClient) withConn(
 }
 
 func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
-	err := c._process(ctx, cmd)
-	if err != nil {
-		cmd.SetErr(err)
-		return err
-	}
-	return nil
-}
-
-func (c *baseClient) _process(ctx context.Context, cmd Cmder) error {
 	var lastErr error
 	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
 		attempt := attempt
@@ -476,6 +455,7 @@ func (c *baseClient) pipelineProcessCmds(
 func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
 	for _, cmd := range cmds {
 		err := cmd.readReply(rd)
+		cmd.SetErr(err)
 		if err != nil && !isRedisError(err) {
 			return err
 		}
diff --git a/ring.go b/ring.go
index e0b433e1..86fce524 100644
--- a/ring.go
+++ b/ring.go
@@ -588,15 +588,6 @@ func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) {
 }
 
 func (c *Ring) process(ctx context.Context, cmd Cmder) error {
-	err := c._process(ctx, cmd)
-	if err != nil {
-		cmd.SetErr(err)
-		return err
-	}
-	return nil
-}
-
-func (c *Ring) _process(ctx context.Context, cmd Cmder) error {
 	var lastErr error
 	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
 		if attempt > 0 {
@@ -694,11 +685,9 @@ func (c *Ring) processShardPipeline(
 	}
 
 	if tx {
-		err = shard.Client.processTxPipeline(ctx, cmds)
-	} else {
-		err = shard.Client.processPipeline(ctx, cmds)
+		return shard.Client.processTxPipeline(ctx, cmds)
 	}
-	return err
+	return shard.Client.processPipeline(ctx, cmds)
 }
 
 func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
diff --git a/sentinel.go b/sentinel.go
index f58970f8..f911622a 100644
--- a/sentinel.go
+++ b/sentinel.go
@@ -224,6 +224,7 @@ func masterSlaveDialer(
 // SentinelClient is a client for a Redis Sentinel.
 type SentinelClient struct {
 	*baseClient
+	hooks
 	ctx context.Context
 }
 
@@ -253,7 +254,7 @@ func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
 }
 
 func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
-	return c.baseClient.process(ctx, cmd)
+	return c.hooks.process(ctx, cmd, c.baseClient.process)
 }
 
 func (c *SentinelClient) pubSub() *PubSub {