From eda1f9c6ad9a2b84f5e175732bbeb096a51b93c0 Mon Sep 17 00:00:00 2001 From: Pavlov Aleksey <irishgreenhedgehog@gmail.com> Date: Mon, 14 Sep 2020 21:27:26 +0300 Subject: [PATCH 1/3] add context cancelation support for blocking operations --- redis.go | 48 ++++++++++++++++++++++++++++++++++++++++++++---- redis_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/redis.go b/redis.go index 617bf973..472b3247 100644 --- a/redis.go +++ b/redis.go @@ -49,7 +49,13 @@ func (hs hooks) process( ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ) error { if len(hs.hooks) == 0 { - return fn(ctx, cmd) + return hs.withContext(ctx, func() error { + err := fn(ctx, cmd) + if err != nil { + cmd.SetErr(err) + } + return err + }) } var hookIndex int @@ -63,7 +69,13 @@ func (hs hooks) process( } if retErr == nil { - retErr = fn(ctx, cmd) + retErr = hs.withContext(ctx, func() error { + err := fn(ctx, cmd) + if err != nil { + cmd.SetErr(err) + } + return err + }) } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -80,7 +92,13 @@ func (hs hooks) processPipeline( ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ) error { if len(hs.hooks) == 0 { - return fn(ctx, cmds) + return hs.withContext(ctx, func() error { + err := fn(ctx, cmds) + if err != nil { + setCmdsErr(cmds, err) + } + return err + }) } var hookIndex int @@ -94,7 +112,13 @@ func (hs hooks) processPipeline( } if retErr == nil { - retErr = fn(ctx, cmds) + retErr = hs.withContext(ctx, func() error { + err := fn(ctx, cmds) + if err != nil { + setCmdsErr(cmds, err) + } + return err + }) } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -114,6 +138,22 @@ func (hs hooks) processTxPipeline( return hs.processPipeline(ctx, cmds, fn) } +func (hs hooks) withContext(ctx context.Context, fn func() error) error { + if ctx.Done() == nil { + return fn() + } + + errc := make(chan error, 1) + go func() { errc <- fn() }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errc: + return err + } +} + //------------------------------------------------------------------------------ type baseClient struct { diff --git a/redis_test.go b/redis_test.go index 044a7c3e..c00afc0d 100644 --- a/redis_test.go +++ b/redis_test.go @@ -389,3 +389,28 @@ var _ = Describe("Client OnConnect", func() { Expect(name).To(Equal("on_connect")) }) }) + +var _ = Describe("Client context cancelation", func() { + var opt *redis.Options + var client *redis.Client + + BeforeEach(func() { + opt = redisOptions() + opt.ReadTimeout = -1 + opt.WriteTimeout = -1 + client = redis.NewClient(opt) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("Blocking operation cancelation", func() { + ctx, cancel := context.WithCancel(ctx) + cancel() + + err := client.BLPop(ctx, 1*time.Second, "test").Err() + Expect(err).To(HaveOccurred()) + Expect(err).To(BeIdenticalTo(context.Canceled)) + }) +}) From 297e671f5eade43a511e3e42439f43fc7ce1d060 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco <vladimir.webdev@gmail.com> Date: Thu, 17 Sep 2020 11:23:34 +0300 Subject: [PATCH 2/3] Properly propagate context error --- redis.go | 61 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/redis.go b/redis.go index 472b3247..e15da91e 100644 --- a/redis.go +++ b/redis.go @@ -49,13 +49,13 @@ func (hs hooks) process( ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ) error { if len(hs.hooks) == 0 { - return hs.withContext(ctx, func() error { - err := fn(ctx, cmd) - if err != nil { - cmd.SetErr(err) - } - return err + err, canceled := hs.withContext(ctx, func() error { + return fn(ctx, cmd) }) + if canceled { + cmd.SetErr(err) + } + return err } var hookIndex int @@ -69,13 +69,13 @@ func (hs hooks) process( } if retErr == nil { - retErr = hs.withContext(ctx, func() error { - err := fn(ctx, cmd) - if err != nil { - cmd.SetErr(err) - } - return err + var canceled bool + retErr, canceled = hs.withContext(ctx, func() error { + return fn(ctx, cmd) }) + if canceled { + cmd.SetErr(retErr) + } } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -92,13 +92,13 @@ func (hs hooks) processPipeline( ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ) error { if len(hs.hooks) == 0 { - return hs.withContext(ctx, func() error { - err := fn(ctx, cmds) - if err != nil { - setCmdsErr(cmds, err) - } - return err + err, canceled := hs.withContext(ctx, func() error { + return fn(ctx, cmds) }) + if canceled { + setCmdsErr(cmds, err) + } + return err } var hookIndex int @@ -112,13 +112,13 @@ func (hs hooks) processPipeline( } if retErr == nil { - retErr = hs.withContext(ctx, func() error { - err := fn(ctx, cmds) - if err != nil { - setCmdsErr(cmds, err) - } - return err + var canceled bool + retErr, canceled = hs.withContext(ctx, func() error { + return fn(ctx, cmds) }) + if canceled { + setCmdsErr(cmds, retErr) + } } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -138,19 +138,20 @@ func (hs hooks) processTxPipeline( return hs.processPipeline(ctx, cmds, fn) } -func (hs hooks) withContext(ctx context.Context, fn func() error) error { - if ctx.Done() == nil { - return fn() +func (hs hooks) withContext(ctx context.Context, fn func() error) (_ error, canceled bool) { + done := ctx.Done() + if done == nil { + return fn(), false } errc := make(chan error, 1) go func() { errc <- fn() }() select { - case <-ctx.Done(): - return ctx.Err() + case <-done: + return ctx.Err(), true case err := <-errc: - return err + return err, false } } From c5d4b71f6661a2236b52bfbe5ed09bf9d3198319 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco <vladimir.webdev@gmail.com> Date: Thu, 17 Sep 2020 12:27:16 +0300 Subject: [PATCH 3/3] Fix race --- cluster.go | 12 +--- command.go | 162 +++++++++++++++++++++++------------------------- command_test.go | 2 +- redis.go | 42 ++++--------- ring.go | 15 +---- sentinel.go | 3 +- 6 files changed, 96 insertions(+), 140 deletions(-) diff --git a/cluster.go b/cluster.go index be8217b1..d17c7479 100644 --- a/cluster.go +++ b/cluster.go @@ -751,15 +751,6 @@ func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error { } func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { - err := c._process(ctx, cmd) - if err != nil { - cmd.SetErr(err) - return err - } - return nil -} - -func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error { cmdInfo := c.cmdInfo(cmd.Name()) slot := c.cmdSlot(cmd) @@ -1197,9 +1188,12 @@ func (c *ClusterClient) pipelineReadCmds( ) error { for _, cmd := range cmds { err := cmd.readReply(rd) + cmd.SetErr(err) + if err == nil { continue } + if c.checkMovedErr(ctx, cmd, err, failedCmds) { continue } diff --git a/command.go b/command.go index 55a5bd5c..4879cfa0 100644 --- a/command.go +++ b/command.go @@ -299,9 +299,9 @@ func (cmd *Cmd) Bool() (bool, error) { } } -func (cmd *Cmd) readReply(rd *proto.Reader) error { - cmd.val, cmd.err = rd.ReadReply(sliceParser) - return cmd.err +func (cmd *Cmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadReply(sliceParser) + return err } // sliceParser implements proto.MultiBulkParse. @@ -357,10 +357,9 @@ func (cmd *SliceCmd) String() string { } func (cmd *SliceCmd) readReply(rd *proto.Reader) error { - var v interface{} - v, cmd.err = rd.ReadArrayReply(sliceParser) - if cmd.err != nil { - return cmd.err + v, err := rd.ReadArrayReply(sliceParser) + if err != nil { + return err } cmd.val = v.([]interface{}) return nil @@ -397,9 +396,9 @@ func (cmd *StatusCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StatusCmd) readReply(rd *proto.Reader) error { - cmd.val, cmd.err = rd.ReadString() - return cmd.err +func (cmd *StatusCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadString() + return err } //------------------------------------------------------------------------------ @@ -437,9 +436,9 @@ func (cmd *IntCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *IntCmd) readReply(rd *proto.Reader) error { - cmd.val, cmd.err = rd.ReadIntReply() - return cmd.err +func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadIntReply() + return err } //------------------------------------------------------------------------------ @@ -474,7 +473,7 @@ func (cmd *IntSliceCmd) String() string { } func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]int64, n) for i := 0; i < len(cmd.val); i++ { num, err := rd.ReadIntReply() @@ -485,7 +484,7 @@ func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -522,10 +521,9 @@ func (cmd *DurationCmd) String() string { } func (cmd *DurationCmd) readReply(rd *proto.Reader) error { - var n int64 - n, cmd.err = rd.ReadIntReply() - if cmd.err != nil { - return cmd.err + n, err := rd.ReadIntReply() + if err != nil { + return err } switch n { // -2 if the key does not exist @@ -570,7 +568,7 @@ func (cmd *TimeCmd) String() string { } func (cmd *TimeCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { if n != 2 { return nil, fmt.Errorf("got %d elements, expected 2", n) } @@ -588,7 +586,7 @@ func (cmd *TimeCmd) readReply(rd *proto.Reader) error { cmd.val = time.Unix(sec, microsec*1000) return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -623,17 +621,15 @@ func (cmd *BoolCmd) String() string { } func (cmd *BoolCmd) readReply(rd *proto.Reader) error { - var v interface{} - v, cmd.err = rd.ReadReply(nil) + v, err := rd.ReadReply(nil) // `SET key value NX` returns nil when key already exists. But // `SETNX key value` returns bool (0/1). So convert nil to bool. - if cmd.err == Nil { + if err == Nil { cmd.val = false - cmd.err = nil return nil } - if cmd.err != nil { - return cmd.err + if err != nil { + return err } switch v := v.(type) { case int64: @@ -643,8 +639,7 @@ func (cmd *BoolCmd) readReply(rd *proto.Reader) error { cmd.val = v == "OK" return nil default: - cmd.err = fmt.Errorf("got %T, wanted int64 or string", v) - return cmd.err + return fmt.Errorf("got %T, wanted int64 or string", v) } } @@ -736,9 +731,9 @@ func (cmd *StringCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringCmd) readReply(rd *proto.Reader) error { - cmd.val, cmd.err = rd.ReadString() - return cmd.err +func (cmd *StringCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadString() + return err } //------------------------------------------------------------------------------ @@ -772,9 +767,9 @@ func (cmd *FloatCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *FloatCmd) readReply(rd *proto.Reader) error { - cmd.val, cmd.err = rd.ReadFloatReply() - return cmd.err +func (cmd *FloatCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadFloatReply() + return err } //------------------------------------------------------------------------------ @@ -813,7 +808,7 @@ func (cmd *StringSliceCmd) ScanSlice(container interface{}) error { } func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]string, n) for i := 0; i < len(cmd.val); i++ { switch s, err := rd.ReadString(); { @@ -827,7 +822,7 @@ func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -862,7 +857,7 @@ func (cmd *BoolSliceCmd) String() string { } func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]bool, n) for i := 0; i < len(cmd.val); i++ { n, err := rd.ReadIntReply() @@ -873,7 +868,7 @@ func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -908,7 +903,7 @@ func (cmd *StringStringMapCmd) String() string { } func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make(map[string]string, n/2) for i := int64(0); i < n; i += 2 { key, err := rd.ReadString() @@ -925,7 +920,7 @@ func (cmd *StringStringMapCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -960,7 +955,7 @@ func (cmd *StringIntMapCmd) String() string { } func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make(map[string]int64, n/2) for i := int64(0); i < n; i += 2 { key, err := rd.ReadString() @@ -977,7 +972,7 @@ func (cmd *StringIntMapCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -1012,7 +1007,7 @@ func (cmd *StringStructMapCmd) String() string { } func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make(map[string]struct{}, n) for i := int64(0); i < n; i++ { key, err := rd.ReadString() @@ -1023,7 +1018,7 @@ func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -1063,10 +1058,9 @@ func (cmd *XMessageSliceCmd) String() string { } func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) error { - var v interface{} - v, cmd.err = rd.ReadArrayReply(xMessageSliceParser) - if cmd.err != nil { - return cmd.err + v, err := rd.ReadArrayReply(xMessageSliceParser) + if err != nil { + return err } cmd.val = v.([]XMessage) return nil @@ -1163,7 +1157,7 @@ func (cmd *XStreamSliceCmd) String() string { } func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]XStream, n) for i := 0; i < len(cmd.val); i++ { i := i @@ -1194,7 +1188,7 @@ func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -1235,7 +1229,7 @@ func (cmd *XPendingCmd) String() string { } func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { if n != 4 { return nil, fmt.Errorf("got %d, wanted 4", n) } @@ -1296,7 +1290,7 @@ func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -1337,7 +1331,7 @@ func (cmd *XPendingExtCmd) String() string { } func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]XPendingExt, 0, n) for i := int64(0); i < n; i++ { _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { @@ -1379,7 +1373,7 @@ func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -1420,18 +1414,17 @@ func (cmd *XInfoGroupsCmd) String() string { } func (cmd *XInfoGroupsCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply( - func(rd *proto.Reader, n int64) (interface{}, error) { - for i := int64(0); i < n; i++ { - v, err := rd.ReadReply(xGroupInfoParser) - if err != nil { - return nil, err - } - cmd.val = append(cmd.val, v.(XInfoGroups)) + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + for i := int64(0); i < n; i++ { + v, err := rd.ReadReply(xGroupInfoParser) + if err != nil { + return nil, err } - return nil, nil - }) - return nil + cmd.val = append(cmd.val, v.(XInfoGroups)) + } + return nil, nil + }) + return err } func xGroupInfoParser(rd *proto.Reader, n int64) (interface{}, error) { @@ -1507,7 +1500,7 @@ func (cmd *ZSliceCmd) String() string { } func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]Z, n/2) for i := 0; i < len(cmd.val); i++ { member, err := rd.ReadString() @@ -1527,7 +1520,7 @@ func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -1562,7 +1555,7 @@ func (cmd *ZWithKeyCmd) String() string { } func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { if n != 3 { return nil, fmt.Errorf("got %d elements, expected 3", n) } @@ -1587,7 +1580,7 @@ func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) error { return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -1625,9 +1618,9 @@ func (cmd *ScanCmd) String() string { return cmdString(cmd, cmd.page) } -func (cmd *ScanCmd) readReply(rd *proto.Reader) error { - cmd.page, cmd.cursor, cmd.err = rd.ReadScanReply() - return cmd.err +func (cmd *ScanCmd) readReply(rd *proto.Reader) (err error) { + cmd.page, cmd.cursor, err = rd.ReadScanReply() + return err } // Iterator creates a new ScanIterator. @@ -1680,7 +1673,7 @@ func (cmd *ClusterSlotsCmd) String() string { } func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]ClusterSlot, n) for i := 0; i < len(cmd.val); i++ { n, err := rd.ReadArrayLen() @@ -1742,7 +1735,7 @@ func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -1834,10 +1827,9 @@ func (cmd *GeoLocationCmd) String() string { } func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { - var v interface{} - v, cmd.err = rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) - if cmd.err != nil { - return cmd.err + v, err := rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) + if err != nil { + return err } cmd.locations = v.([]GeoLocation) return nil @@ -1947,7 +1939,7 @@ func (cmd *GeoPosCmd) String() string { } func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]*GeoPos, n) for i := 0; i < len(cmd.val); i++ { i := i @@ -1978,7 +1970,7 @@ func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } //------------------------------------------------------------------------------ @@ -2024,7 +2016,7 @@ func (cmd *CommandsInfoCmd) String() string { } func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make(map[string]*CommandInfo, n) for i := int64(0); i < n; i++ { v, err := rd.ReadReply(commandInfoParser) @@ -2036,7 +2028,7 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } func commandInfoParser(rd *proto.Reader, n int64) (interface{}, error) { @@ -2211,7 +2203,7 @@ func (cmd *SlowLogCmd) String() string { } func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error { - _, cmd.err = rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { + _, err := rd.ReadArrayReply(func(rd *proto.Reader, n int64) (interface{}, error) { cmd.val = make([]SlowLog, n) for i := 0; i < len(cmd.val); i++ { n, err := rd.ReadArrayLen() @@ -2281,5 +2273,5 @@ func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error { } return nil, nil }) - return cmd.err + return err } diff --git a/command_test.go b/command_test.go index d80b7444..d110d0c3 100644 --- a/command_test.go +++ b/command_test.go @@ -86,7 +86,7 @@ var _ = Describe("Cmd", func() { Expect(tm2).To(BeTemporally("==", tm)) }) - It("allow to set custom error", func() { + It("allows to set custom error", func() { e := errors.New("custom error") cmd := redis.Cmd{} cmd.SetErr(e) diff --git a/redis.go b/redis.go index e15da91e..0921359e 100644 --- a/redis.go +++ b/redis.go @@ -49,12 +49,10 @@ func (hs hooks) process( ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, ) error { if len(hs.hooks) == 0 { - err, canceled := hs.withContext(ctx, func() error { + err := hs.withContext(ctx, func() error { return fn(ctx, cmd) }) - if canceled { - cmd.SetErr(err) - } + cmd.SetErr(err) return err } @@ -69,13 +67,10 @@ func (hs hooks) process( } if retErr == nil { - var canceled bool - retErr, canceled = hs.withContext(ctx, func() error { + retErr = hs.withContext(ctx, func() error { return fn(ctx, cmd) }) - if canceled { - cmd.SetErr(retErr) - } + cmd.SetErr(retErr) } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -92,12 +87,9 @@ func (hs hooks) processPipeline( ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, ) error { if len(hs.hooks) == 0 { - err, canceled := hs.withContext(ctx, func() error { + err := hs.withContext(ctx, func() error { return fn(ctx, cmds) }) - if canceled { - setCmdsErr(cmds, err) - } return err } @@ -112,13 +104,9 @@ func (hs hooks) processPipeline( } if retErr == nil { - var canceled bool - retErr, canceled = hs.withContext(ctx, func() error { + retErr = hs.withContext(ctx, func() error { return fn(ctx, cmds) }) - if canceled { - setCmdsErr(cmds, retErr) - } } for hookIndex--; hookIndex >= 0; hookIndex-- { @@ -138,10 +126,10 @@ func (hs hooks) processTxPipeline( return hs.processPipeline(ctx, cmds, fn) } -func (hs hooks) withContext(ctx context.Context, fn func() error) (_ error, canceled bool) { +func (hs hooks) withContext(ctx context.Context, fn func() error) error { done := ctx.Done() if done == nil { - return fn(), false + return fn() } errc := make(chan error, 1) @@ -149,9 +137,9 @@ func (hs hooks) withContext(ctx context.Context, fn func() error) (_ error, canc select { case <-done: - return ctx.Err(), true + return ctx.Err() case err := <-errc: - return err, false + return err } } @@ -324,15 +312,6 @@ func (c *baseClient) withConn( } func (c *baseClient) process(ctx context.Context, cmd Cmder) error { - err := c._process(ctx, cmd) - if err != nil { - cmd.SetErr(err) - return err - } - return nil -} - -func (c *baseClient) _process(ctx context.Context, cmd Cmder) error { var lastErr error for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { attempt := attempt @@ -476,6 +455,7 @@ func (c *baseClient) pipelineProcessCmds( func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { for _, cmd := range cmds { err := cmd.readReply(rd) + cmd.SetErr(err) if err != nil && !isRedisError(err) { return err } diff --git a/ring.go b/ring.go index e0b433e1..86fce524 100644 --- a/ring.go +++ b/ring.go @@ -588,15 +588,6 @@ func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) { } func (c *Ring) process(ctx context.Context, cmd Cmder) error { - err := c._process(ctx, cmd) - if err != nil { - cmd.SetErr(err) - return err - } - return nil -} - -func (c *Ring) _process(ctx context.Context, cmd Cmder) error { var lastErr error for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { @@ -694,11 +685,9 @@ func (c *Ring) processShardPipeline( } if tx { - err = shard.Client.processTxPipeline(ctx, cmds) - } else { - err = shard.Client.processPipeline(ctx, cmds) + return shard.Client.processTxPipeline(ctx, cmds) } - return err + return shard.Client.processPipeline(ctx, cmds) } func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { diff --git a/sentinel.go b/sentinel.go index f58970f8..f911622a 100644 --- a/sentinel.go +++ b/sentinel.go @@ -224,6 +224,7 @@ func masterSlaveDialer( // SentinelClient is a client for a Redis Sentinel. type SentinelClient struct { *baseClient + hooks ctx context.Context } @@ -253,7 +254,7 @@ func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient { } func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { - return c.baseClient.process(ctx, cmd) + return c.hooks.process(ctx, cmd, c.baseClient.process) } func (c *SentinelClient) pubSub() *PubSub {