diff --git a/bench_test.go b/bench_test.go index f670a7c5..1646724c 100644 --- a/bench_test.go +++ b/bench_test.go @@ -254,7 +254,7 @@ func BenchmarkClusterPing(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() b.ResetTimer() @@ -280,7 +280,7 @@ func BenchmarkClusterSetString(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() value := string(bytes.Repeat([]byte{'1'}, 10000)) @@ -308,7 +308,7 @@ func BenchmarkClusterReloadState(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() b.ResetTimer() diff --git a/cluster.go b/cluster.go index c5fcb9bb..2cf6e01b 100644 --- a/cluster.go +++ b/cluster.go @@ -773,13 +773,13 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error { if ask { pipe := node.Client.Pipeline() - _ = pipe.Process(NewCmd("ASKING")) + _ = pipe.Process(NewCmd("asking")) _ = pipe.Process(cmd) _, lastErr = pipe.ExecContext(ctx) _ = pipe.Close() ask = false } else { - lastErr = node.Client._process(ctx, cmd) + lastErr = node.Client.ProcessContext(ctx, cmd) } // If there is no error - we are done. @@ -840,6 +840,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + for _, master := range state.Masters { wg.Add(1) go func(node *clusterNode) { @@ -853,6 +854,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { } }(master) } + wg.Wait() select { @@ -873,6 +875,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + for _, slave := range state.Slaves { wg.Add(1) go func(node *clusterNode) { @@ -886,6 +889,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { } }(slave) } + wg.Wait() select { @@ -906,6 +910,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + worker := func(node *clusterNode) { defer wg.Done() err := fn(node.Client) @@ -927,6 +932,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { } wg.Wait() + select { case err := <-errCh: return err @@ -1068,18 +1074,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmd(wr, cmds...) - }) - if err != nil { - return err - } - - return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - return c.pipelineReadCmds(node, rd, cmds, failedCmds) - }) - }) + err := c._processPipelineNode(ctx, node, cmds, failedCmds) if err == nil { return } @@ -1142,6 +1137,25 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool { return true } +func (c *ClusterClient) _processPipelineNode( + ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, +) error { + return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }) + if err != nil { + return err + } + + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { + return c.pipelineReadCmds(node, rd, cmds, failedCmds) + }) + }) + }) +} + func (c *ClusterClient) pipelineReadCmds( node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, ) error { @@ -1186,7 +1200,7 @@ func (c *ClusterClient) checkMovedErr( } if ask { - failedCmds.Add(node, NewCmd("ASKING"), cmd) + failedCmds.Add(node, NewCmd("asking"), cmd) return true } @@ -1243,26 +1257,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return txPipelineWriteMulti(wr, cmds) - }) - if err != nil { - return err - } - - return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - err := c.txPipelineReadQueued(rd, cmds, failedCmds) - if err != nil { - moved, ask, addr := isMovedError(err) - if moved || ask { - return c.cmdsMoved(cmds, moved, ask, addr, failedCmds) - } - return err - } - return pipelineReadCmds(rd, cmds) - }) - }) + err := c._processTxPipelineNode(ctx, node, cmds, failedCmds) if err == nil { return } @@ -1296,11 +1291,42 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { return cmdsMap } +func (c *ClusterClient) _processTxPipelineNode( + ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, +) error { + return node.Client.hooks.processTxPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }) + if err != nil { + return err + } + + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + cmds = cmds[1 : len(cmds)-1] + + err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds) + if err != nil { + moved, ask, addr := isMovedError(err) + if moved || ask { + return c.cmdsMoved(cmds, moved, ask, addr, failedCmds) + } + return err + } + + return pipelineReadCmds(rd, cmds) + }) + }) + }) +} + func (c *ClusterClient) txPipelineReadQueued( - rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, + rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap, ) error { // Parse queued replies. - var statusCmd StatusCmd if err := statusCmd.readReply(rd); err != nil { return err } @@ -1352,7 +1378,7 @@ func (c *ClusterClient) cmdsMoved( if ask { for _, cmd := range cmds { - failedCmds.Add(node, NewCmd("ASKING"), cmd) + failedCmds.Add(node, NewCmd("asking"), cmd) } return nil } diff --git a/cluster_test.go b/cluster_test.go index 05eaeb4d..62fd6e0d 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -47,14 +47,14 @@ func (s *clusterScenario) addrs() []string { return addrs } -func (s *clusterScenario) clusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient { +func (s *clusterScenario) newClusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient { opt.Addrs = s.addrs() return redis.NewClusterClient(opt) } -func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { - client := s.clusterClientUnsafe(opt) +func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { + client := s.newClusterClientUnsafe(opt) err := eventually(func() error { if opt.ClusterSlots != nil { @@ -527,12 +527,184 @@ var _ = Describe("ClusterClient", func() { err := pubsub.Ping() Expect(err).NotTo(HaveOccurred()) }) + + It("supports Process hook", func() { + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + clusterHook := &hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcess") + return nil + }, + } + client.AddHook(clusterHook) + + nodeHook := &hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + return nil + }, + } + + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(nodeHook) + return nil + }) + + err = client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcess", + "shard.BeforeProcess", + "shard.AfterProcess", + "cluster.AfterProcess", + })) + + clusterHook.beforeProcess = nil + clusterHook.afterProcess = nil + nodeHook.beforeProcess = nil + nodeHook.afterProcess = nil + }) + + It("supports Pipeline hook", func() { + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + client.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + return nil + }, + }) + + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + return nil + }) + + _, err = client.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "cluster.AfterProcessPipeline", + })) + }) + + It("supports TxPipeline hook", func() { + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + client.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + return nil + }, + }) + + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + return nil + }) + + _, err = client.TxPipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "cluster.AfterProcessPipeline", + })) + }) } Describe("ClusterClient", func() { BeforeEach(func() { opt = redisClusterOptions() - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -555,7 +727,7 @@ var _ = Describe("ClusterClient", func() { It("returns an error when there are no attempts left", func() { opt := redisClusterOptions() opt.MaxRedirects = -1 - client := cluster.clusterClient(opt) + client := cluster.newClusterClient(opt) Eventually(func() error { return client.SwapNodes("A") @@ -707,7 +879,7 @@ var _ = Describe("ClusterClient", func() { opt = redisClusterOptions() opt.MinRetryBackoff = 250 * time.Millisecond opt.MaxRetryBackoff = time.Second - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -757,7 +929,7 @@ var _ = Describe("ClusterClient", func() { BeforeEach(func() { opt = redisClusterOptions() opt.RouteByLatency = true - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -813,7 +985,7 @@ var _ = Describe("ClusterClient", func() { }} return slots, nil } - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -867,7 +1039,7 @@ var _ = Describe("ClusterClient", func() { }} return slots, nil } - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -959,7 +1131,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() { opt.ReadTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond opt.MaxRedirects = 1 - client = cluster.clusterClientUnsafe(opt) + client = cluster.newClusterClientUnsafe(opt) }) AfterEach(func() { @@ -1028,7 +1200,7 @@ var _ = Describe("ClusterClient timeout", func() { opt.ReadTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond opt.MaxRedirects = 1 - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachNode(func(client *redis.Client) error { return client.ClientPause(pause).Err() diff --git a/command.go b/command.go index 0a4a345d..266e3892 100644 --- a/command.go +++ b/command.go @@ -15,6 +15,7 @@ import ( type Cmder interface { Name() string Args() []interface{} + String() string stringArg(int) string readTimeout() *time.Duration @@ -41,16 +42,19 @@ func cmdsFirstErr(cmds []Cmder) error { return nil } -func writeCmd(wr *proto.Writer, cmds ...Cmder) error { +func writeCmds(wr *proto.Writer, cmds []Cmder) error { for _, cmd := range cmds { - err := wr.WriteArgs(cmd.Args()) - if err != nil { + if err := writeCmd(wr, cmd); err != nil { return err } } return nil } +func writeCmd(wr *proto.Writer, cmd Cmder) error { + return wr.WriteArgs(cmd.Args()) +} + func cmdString(cmd Cmder, val interface{}) string { ss := make([]string, 0, len(cmd.Args())) for _, arg := range cmd.Args() { @@ -149,6 +153,10 @@ func NewCmd(args ...interface{}) *Cmd { } } +func (cmd *Cmd) String() string { + return cmdString(cmd, cmd.val) +} + func (cmd *Cmd) Val() interface{} { return cmd.val } @@ -157,7 +165,7 @@ func (cmd *Cmd) Result() (interface{}, error) { return cmd.val, cmd.err } -func (cmd *Cmd) String() (string, error) { +func (cmd *Cmd) Text() (string, error) { if cmd.err != nil { return "", cmd.err } diff --git a/example_test.go b/example_test.go index b38f71f6..801dc965 100644 --- a/example_test.go +++ b/example_test.go @@ -447,7 +447,7 @@ func Example_customCommand() { } func Example_customCommand2() { - v, err := rdb.Do("get", "key_does_not_exist").String() + v, err := rdb.Do("get", "key_does_not_exist").Text() fmt.Printf("%q %s", v, err) // Output: "" redis: nil } diff --git a/main_test.go b/main_test.go index 958a00d3..f7f028a1 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "errors" "fmt" "net" @@ -370,3 +371,41 @@ func (cn *badConn) Write([]byte) (int, error) { } return 0, badConnError("bad connection") } + +//------------------------------------------------------------------------------ + +type hook struct { + beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error) + afterProcess func(ctx context.Context, cmd redis.Cmder) error + + beforeProcessPipeline func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) + afterProcessPipeline func(ctx context.Context, cmds []redis.Cmder) error +} + +func (h *hook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + if h.beforeProcess != nil { + return h.beforeProcess(ctx, cmd) + } + return ctx, nil +} + +func (h *hook) AfterProcess(ctx context.Context, cmd redis.Cmder) error { + if h.afterProcess != nil { + return h.afterProcess(ctx, cmd) + } + return nil +} + +func (h *hook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + if h.beforeProcessPipeline != nil { + return h.beforeProcessPipeline(ctx, cmds) + } + return ctx, nil +} + +func (h *hook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error { + if h.afterProcessPipeline != nil { + return h.afterProcessPipeline(ctx, cmds) + } + return nil +} diff --git a/race_test.go b/race_test.go index afe06cf8..9a575555 100644 --- a/race_test.go +++ b/race_test.go @@ -299,7 +299,7 @@ var _ = Describe("cluster races", func() { BeforeEach(func() { opt := redisClusterOptions() - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) C, N = 10, 1000 if testing.Short() { diff --git a/redis.go b/redis.go index 2b50fff0..93032579 100644 --- a/redis.go +++ b/redis.go @@ -128,6 +128,13 @@ func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error { return firstErr } +func (hs hooks) processTxPipeline( + ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, +) error { + cmds = wrapMultiExec(cmds) + return hs.processPipeline(ctx, cmds, fn) +} + //------------------------------------------------------------------------------ type baseClient struct { @@ -411,7 +418,7 @@ func (c *baseClient) pipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmd(wr, cmds...) + return writeCmds(wr, cmds) }) if err != nil { return true, err @@ -437,41 +444,46 @@ func (c *baseClient) txPipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return txPipelineWriteMulti(wr, cmds) + return writeCmds(wr, cmds) }) if err != nil { return true, err } err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - err := txPipelineReadQueued(rd, cmds) + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + cmds = cmds[1 : len(cmds)-1] + + err := txPipelineReadQueued(rd, statusCmd, cmds) if err != nil { return err } + return pipelineReadCmds(rd, cmds) }) return false, err } -func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error { - multiExec := make([]Cmder, 0, len(cmds)+2) - multiExec = append(multiExec, NewStatusCmd("MULTI")) - multiExec = append(multiExec, cmds...) - multiExec = append(multiExec, NewSliceCmd("EXEC")) - return writeCmd(wr, multiExec...) +func wrapMultiExec(cmds []Cmder) []Cmder { + if len(cmds) == 0 { + panic("not reached") + } + cmds = append(cmds, make([]Cmder, 2)...) + copy(cmds[1:], cmds[:len(cmds)-2]) + cmds[0] = NewStatusCmd("multi") + cmds[len(cmds)-1] = NewSliceCmd("exec") + return cmds } -func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { +func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { // Parse queued replies. - var statusCmd StatusCmd - err := statusCmd.readReply(rd) - if err != nil { + if err := statusCmd.readReply(rd); err != nil { return err } for range cmds { - err = statusCmd.readReply(rd) - if err != nil && !isRedisError(err) { + if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { return err } } @@ -577,7 +589,7 @@ func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { } func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) + return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) } // Options returns read-only Options that were used to create the client. diff --git a/ring.go b/ring.go index b0c6dfc3..e1b49917 100644 --- a/ring.go +++ b/ring.go @@ -581,7 +581,7 @@ func (c *Ring) _process(ctx context.Context, cmd Cmder) error { return err } - lastErr = shard.Client._process(ctx, cmd) + lastErr = shard.Client.ProcessContext(ctx, cmd) if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) { return lastErr } @@ -646,10 +646,7 @@ func (c *Ring) generalProcessPipeline( go func(hash string, cmds []Cmder) { defer wg.Done() - err := c.processShardPipeline(ctx, hash, cmds, tx) - if err != nil { - setCmdsErr(cmds, err) - } + _ = c.processShardPipeline(ctx, hash, cmds, tx) }(hash, cmds) } @@ -663,15 +660,14 @@ func (c *Ring) processShardPipeline( //TODO: retry? shard, err := c.shards.GetByHash(hash) if err != nil { + setCmdsErr(cmds, err) return err } if tx { - err = shard.Client._generalProcessPipeline( - ctx, cmds, shard.Client.txPipelineProcessCmds) + err = shard.Client.processTxPipeline(ctx, cmds) } else { - err = shard.Client._generalProcessPipeline( - ctx, cmds, shard.Client.pipelineProcessCmds) + err = shard.Client.processPipeline(ctx, cmds) } return err } diff --git a/ring_test.go b/ring_test.go index ac9d9a9e..eef8dc25 100644 --- a/ring_test.go +++ b/ring_test.go @@ -195,6 +195,155 @@ var _ = Describe("Redis Ring", func() { Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set")) }) }) + + It("supports Process hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + ring.AddHook(&hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcess") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + return nil + }, + }) + return nil + }) + + err = ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcess", + "shard.BeforeProcess", + "shard.AfterProcess", + "ring.AfterProcess", + })) + }) + + It("supports Pipeline hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + ring.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcessPipeline") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + return nil + }) + + _, err = ring.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "ring.AfterProcessPipeline", + })) + }) + + It("supports TxPipeline hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + ring.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcessPipeline") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + return nil + }) + + _, err = ring.TxPipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "ring.AfterProcessPipeline", + })) + }) }) var _ = Describe("empty Redis Ring", func() { diff --git a/tx.go b/tx.go index 4af5ebe6..9ae15901 100644 --- a/tx.go +++ b/tx.go @@ -151,7 +151,7 @@ func (c *Tx) TxPipeline() Pipeliner { pipe := Pipeline{ ctx: c.ctx, exec: func(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) + return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) }, } pipe.init()