From 2e3402d33d8c2d63d027edf75d63592eb10cbe01 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Fri, 14 Feb 2020 12:44:03 +0200 Subject: [PATCH 1/3] Don't allocate tmp slice in txPipelineWriteMulti --- cluster.go | 2 +- command.go | 9 ++++++--- redis.go | 22 ++++++++++++++++------ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/cluster.go b/cluster.go index c5fcb9bb..12721c3f 100644 --- a/cluster.go +++ b/cluster.go @@ -1070,7 +1070,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmd(wr, cmds...) + return writeCmds(wr, cmds) }) if err != nil { return err diff --git a/command.go b/command.go index 0a4a345d..8c5fbe24 100644 --- a/command.go +++ b/command.go @@ -41,16 +41,19 @@ func cmdsFirstErr(cmds []Cmder) error { return nil } -func writeCmd(wr *proto.Writer, cmds ...Cmder) error { +func writeCmds(wr *proto.Writer, cmds []Cmder) error { for _, cmd := range cmds { - err := wr.WriteArgs(cmd.Args()) - if err != nil { + if err := writeCmd(wr, cmd); err != nil { return err } } return nil } +func writeCmd(wr *proto.Writer, cmd Cmder) error { + return wr.WriteArgs(cmd.Args()) +} + func cmdString(cmd Cmder, val interface{}) string { ss := make([]string, 0, len(cmd.Args())) for _, arg := range cmd.Args() { diff --git a/redis.go b/redis.go index 2b50fff0..caba631b 100644 --- a/redis.go +++ b/redis.go @@ -411,7 +411,7 @@ func (c *baseClient) pipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmd(wr, cmds...) + return writeCmds(wr, cmds) }) if err != nil { return true, err @@ -453,12 +453,22 @@ func (c *baseClient) txPipelineProcessCmds( return false, err } +var ( + multi = NewStatusCmd("multi") + exec = NewSliceCmd("exec") +) + func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error { - multiExec := make([]Cmder, 0, len(cmds)+2) - multiExec = append(multiExec, NewStatusCmd("MULTI")) - multiExec = append(multiExec, cmds...) - multiExec = append(multiExec, NewSliceCmd("EXEC")) - return writeCmd(wr, multiExec...) + if err := writeCmd(wr, multi); err != nil { + return err + } + if err := writeCmds(wr, cmds); err != nil { + return err + } + if err := writeCmd(wr, exec); err != nil { + return err + } + return nil } func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { From 49a0c8c3198bec92efef81a9a9e514920348e22e Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Fri, 14 Feb 2020 14:30:07 +0200 Subject: [PATCH 2/3] Add test for ring and cluster hooks --- cluster.go | 88 +++++++++++++++--------- cluster_test.go | 178 ++++++++++++++++++++++++++++++++++++++++++++++++ command.go | 7 +- example_test.go | 2 +- main_test.go | 39 +++++++++++ ring.go | 14 ++-- ring_test.go | 149 ++++++++++++++++++++++++++++++++++++++++ 7 files changed, 433 insertions(+), 44 deletions(-) diff --git a/cluster.go b/cluster.go index 12721c3f..5c6a97c6 100644 --- a/cluster.go +++ b/cluster.go @@ -779,7 +779,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error { _ = pipe.Close() ask = false } else { - lastErr = node.Client._process(ctx, cmd) + lastErr = node.Client.ProcessContext(ctx, cmd) } // If there is no error - we are done. 
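The core of PATCH 1 above is to stop building a temporary []Cmder of len(cmds)+2 just to wrap a transaction: MULTI, the queued commands, and EXEC are streamed to the protocol writer one after another. Below is a minimal, self-contained sketch of that pattern; the argWriter type and cmd struct are illustrative stand-ins for *proto.Writer and a real Cmder, not go-redis types.

package main

import "fmt"

// Cmder is a trimmed-down stand-in for go-redis's Cmder interface.
type Cmder interface {
	Args() []interface{}
}

type cmd struct{ args []interface{} }

func (c cmd) Args() []interface{} { return c.args }

// argWriter stands in for *proto.Writer; WriteArgs is all this sketch needs.
type argWriter struct{}

func (argWriter) WriteArgs(args []interface{}) error {
	fmt.Println(args...)
	return nil
}

// txPipelineWriteMulti streams MULTI, each queued command, and EXEC directly,
// instead of first allocating a combined slice of len(cmds)+2 commands.
func txPipelineWriteMulti(wr argWriter, cmds []Cmder) error {
	if err := wr.WriteArgs([]interface{}{"multi"}); err != nil {
		return err
	}
	for _, c := range cmds {
		if err := wr.WriteArgs(c.Args()); err != nil {
			return err
		}
	}
	return wr.WriteArgs([]interface{}{"exec"})
}

func main() {
	cmds := []Cmder{
		cmd{args: []interface{}{"set", "key", "value"}},
		cmd{args: []interface{}{"get", "key"}},
	}
	_ = txPipelineWriteMulti(argWriter{}, cmds)
}

The actual patch goes one step further and reuses package-level multi/exec command values; PATCH 3 later replaces that with wrapMultiExec so the pipeline hooks can observe the wrapping commands as well.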
@@ -840,6 +840,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + for _, master := range state.Masters { wg.Add(1) go func(node *clusterNode) { @@ -853,6 +854,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { } }(master) } + wg.Wait() select { @@ -873,6 +875,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + for _, slave := range state.Slaves { wg.Add(1) go func(node *clusterNode) { @@ -886,6 +889,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { } }(slave) } + wg.Wait() select { @@ -906,6 +910,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + worker := func(node *clusterNode) { defer wg.Done() err := fn(node.Client) @@ -927,6 +932,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { } wg.Wait() + select { case err := <-errCh: return err @@ -1068,18 +1074,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmds(wr, cmds) - }) - if err != nil { - return err - } - - return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - return c.pipelineReadCmds(node, rd, cmds, failedCmds) - }) - }) + err := c._processPipelineNode(ctx, node, cmds, failedCmds) if err == nil { return } @@ -1142,6 +1137,25 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool { return true } +func (c *ClusterClient) _processPipelineNode( + ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, +) error { + return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }) + if err != nil { + return err + } + + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { + return c.pipelineReadCmds(node, rd, cmds, failedCmds) + }) + }) + }) +} + func (c *ClusterClient) pipelineReadCmds( node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, ) error { @@ -1243,26 +1257,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return txPipelineWriteMulti(wr, cmds) - }) - if err != nil { - return err - } - - return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - err := c.txPipelineReadQueued(rd, cmds, failedCmds) - if err != nil { - moved, ask, addr := isMovedError(err) - if moved || ask { - return c.cmdsMoved(cmds, moved, ask, addr, failedCmds) - } - return err - } - return pipelineReadCmds(rd, cmds) - }) - }) + err := c._processTxPipelineNode(ctx, node, cmds, failedCmds) if err == nil { return } @@ -1296,6 +1291,33 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { return cmdsMap } +func (c *ClusterClient) _processTxPipelineNode( + 
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, +) error { + return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { + return txPipelineWriteMulti(wr, cmds) + }) + if err != nil { + return err + } + + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { + err := c.txPipelineReadQueued(rd, cmds, failedCmds) + if err != nil { + moved, ask, addr := isMovedError(err) + if moved || ask { + return c.cmdsMoved(cmds, moved, ask, addr, failedCmds) + } + return err + } + return pipelineReadCmds(rd, cmds) + }) + }) + }) +} + func (c *ClusterClient) txPipelineReadQueued( rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, ) error { diff --git a/cluster_test.go b/cluster_test.go index 05eaeb4d..51796302 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -527,6 +527,184 @@ var _ = Describe("ClusterClient", func() { err := pubsub.Ping() Expect(err).NotTo(HaveOccurred()) }) + + It("supports Process hook", func() { + var masters []*redis.Client + + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachMaster(func(master *redis.Client) error { + masters = append(masters, master) + return master.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + clusterHook := &hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcess") + return nil + }, + } + client.AddHook(clusterHook) + + masterHook := &hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + return nil + }, + } + + for _, master := range masters { + master.AddHook(masterHook) + } + + err = client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcess", + "shard.BeforeProcess", + "shard.AfterProcess", + "cluster.AfterProcess", + })) + + clusterHook.beforeProcess = nil + clusterHook.afterProcess = nil + masterHook.beforeProcess = nil + masterHook.afterProcess = nil + }) + + It("supports Pipeline hook", func() { + var masters []*redis.Client + + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachMaster(func(master *redis.Client) error { + masters = append(masters, master) + return master.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + client.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, 
"cluster.AfterProcessPipeline") + return nil + }, + }) + + for _, master := range masters { + master.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + } + + _, err = client.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "cluster.AfterProcessPipeline", + })) + }) + + It("supports TxPipeline hook", func() { + var masters []*redis.Client + + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachMaster(func(master *redis.Client) error { + masters = append(masters, master) + return master.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + client.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + return nil + }, + }) + + for _, master := range masters { + master.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + } + + _, err = client.TxPipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "cluster.AfterProcessPipeline", + })) + }) } Describe("ClusterClient", func() { diff --git a/command.go b/command.go index 8c5fbe24..266e3892 100644 --- a/command.go +++ b/command.go @@ -15,6 +15,7 @@ import ( type Cmder interface { Name() string Args() []interface{} + String() string stringArg(int) string readTimeout() *time.Duration @@ -152,6 +153,10 @@ func NewCmd(args ...interface{}) *Cmd { } } +func (cmd *Cmd) String() string { + return cmdString(cmd, cmd.val) +} + func (cmd *Cmd) Val() interface{} { return cmd.val } @@ -160,7 +165,7 @@ func (cmd *Cmd) Result() (interface{}, error) { return cmd.val, cmd.err } -func (cmd *Cmd) String() (string, error) { +func (cmd *Cmd) Text() (string, error) { if cmd.err != nil { return "", cmd.err } diff --git a/example_test.go b/example_test.go index b38f71f6..801dc965 100644 --- a/example_test.go +++ b/example_test.go @@ -447,7 +447,7 @@ func Example_customCommand() { } func 
Example_customCommand2() { - v, err := rdb.Do("get", "key_does_not_exist").String() + v, err := rdb.Do("get", "key_does_not_exist").Text() fmt.Printf("%q %s", v, err) // Output: "" redis: nil } diff --git a/main_test.go b/main_test.go index 958a00d3..f7f028a1 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "errors" "fmt" "net" @@ -370,3 +371,41 @@ func (cn *badConn) Write([]byte) (int, error) { } return 0, badConnError("bad connection") } + +//------------------------------------------------------------------------------ + +type hook struct { + beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error) + afterProcess func(ctx context.Context, cmd redis.Cmder) error + + beforeProcessPipeline func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) + afterProcessPipeline func(ctx context.Context, cmds []redis.Cmder) error +} + +func (h *hook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + if h.beforeProcess != nil { + return h.beforeProcess(ctx, cmd) + } + return ctx, nil +} + +func (h *hook) AfterProcess(ctx context.Context, cmd redis.Cmder) error { + if h.afterProcess != nil { + return h.afterProcess(ctx, cmd) + } + return nil +} + +func (h *hook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + if h.beforeProcessPipeline != nil { + return h.beforeProcessPipeline(ctx, cmds) + } + return ctx, nil +} + +func (h *hook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error { + if h.afterProcessPipeline != nil { + return h.afterProcessPipeline(ctx, cmds) + } + return nil +} diff --git a/ring.go b/ring.go index b0c6dfc3..40bafc28 100644 --- a/ring.go +++ b/ring.go @@ -581,7 +581,7 @@ func (c *Ring) _process(ctx context.Context, cmd Cmder) error { return err } - lastErr = shard.Client._process(ctx, cmd) + lastErr = shard.Client.ProcessContext(ctx, cmd) if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) { return lastErr } @@ -646,10 +646,7 @@ func (c *Ring) generalProcessPipeline( go func(hash string, cmds []Cmder) { defer wg.Done() - err := c.processShardPipeline(ctx, hash, cmds, tx) - if err != nil { - setCmdsErr(cmds, err) - } + _ = c.processShardPipeline(ctx, hash, cmds, tx) }(hash, cmds) } @@ -663,15 +660,14 @@ func (c *Ring) processShardPipeline( //TODO: retry? 
shard, err := c.shards.GetByHash(hash) if err != nil { + setCmdsErr(cmds, err) return err } if tx { - err = shard.Client._generalProcessPipeline( - ctx, cmds, shard.Client.txPipelineProcessCmds) + err = shard.Client.processPipeline(ctx, cmds) } else { - err = shard.Client._generalProcessPipeline( - ctx, cmds, shard.Client.pipelineProcessCmds) + err = shard.Client.processTxPipeline(ctx, cmds) } return err } diff --git a/ring_test.go b/ring_test.go index ac9d9a9e..a0fdd462 100644 --- a/ring_test.go +++ b/ring_test.go @@ -195,6 +195,155 @@ var _ = Describe("Redis Ring", func() { Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set")) }) }) + + It("supports Process hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + ring.AddHook(&hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcess") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + return nil + }, + }) + return nil + }) + + err = ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcess", + "shard.BeforeProcess", + "shard.AfterProcess", + "ring.AfterProcess", + })) + }) + + It("supports Pipeline hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + ring.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcessPipeline") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + return nil + }) + + _, err = ring.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "ring.AfterProcessPipeline", + })) + }) + + It("supports TxPipeline hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + 
ring.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcessPipeline") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + return nil + }) + + _, err = ring.TxPipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "ring.AfterProcessPipeline", + })) + }) }) var _ = Describe("empty Redis Ring", func() { From 218b17f0fc567ce62f7d6d0607caff1d8100354f Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Fri, 14 Feb 2020 15:37:35 +0200 Subject: [PATCH 3/3] Include multi & exec in pipeline hook --- bench_test.go | 6 ++-- cluster.go | 20 ++++++++----- cluster_test.go | 78 +++++++++++++++++++++++-------------------------- race_test.go | 2 +- redis.go | 50 ++++++++++++++++--------------- ring.go | 4 +-- ring_test.go | 8 ++--- tx.go | 2 +- 8 files changed, 85 insertions(+), 85 deletions(-) diff --git a/bench_test.go b/bench_test.go index f670a7c5..1646724c 100644 --- a/bench_test.go +++ b/bench_test.go @@ -254,7 +254,7 @@ func BenchmarkClusterPing(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() b.ResetTimer() @@ -280,7 +280,7 @@ func BenchmarkClusterSetString(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() value := string(bytes.Repeat([]byte{'1'}, 10000)) @@ -308,7 +308,7 @@ func BenchmarkClusterReloadState(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() b.ResetTimer() diff --git a/cluster.go b/cluster.go index 5c6a97c6..2cf6e01b 100644 --- a/cluster.go +++ b/cluster.go @@ -773,7 +773,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error { if ask { pipe := node.Client.Pipeline() - _ = pipe.Process(NewCmd("ASKING")) + _ = pipe.Process(NewCmd("asking")) _ = pipe.Process(cmd) _, lastErr = pipe.ExecContext(ctx) _ = pipe.Close() @@ -1200,7 +1200,7 @@ func (c *ClusterClient) checkMovedErr( } if ask { - failedCmds.Add(node, NewCmd("ASKING"), cmd) + failedCmds.Add(node, NewCmd("asking"), cmd) return true } @@ -1294,17 +1294,21 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { func (c *ClusterClient) _processTxPipelineNode( ctx 
context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) error { - return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.hooks.processTxPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return txPipelineWriteMulti(wr, cmds) + return writeCmds(wr, cmds) }) if err != nil { return err } return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - err := c.txPipelineReadQueued(rd, cmds, failedCmds) + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + cmds = cmds[1 : len(cmds)-1] + + err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds) if err != nil { moved, ask, addr := isMovedError(err) if moved || ask { @@ -1312,6 +1316,7 @@ func (c *ClusterClient) _processTxPipelineNode( } return err } + return pipelineReadCmds(rd, cmds) }) }) @@ -1319,10 +1324,9 @@ func (c *ClusterClient) _processTxPipelineNode( } func (c *ClusterClient) txPipelineReadQueued( - rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, + rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap, ) error { // Parse queued replies. - var statusCmd StatusCmd if err := statusCmd.readReply(rd); err != nil { return err } @@ -1374,7 +1378,7 @@ func (c *ClusterClient) cmdsMoved( if ask { for _, cmd := range cmds { - failedCmds.Add(node, NewCmd("ASKING"), cmd) + failedCmds.Add(node, NewCmd("asking"), cmd) } return nil } diff --git a/cluster_test.go b/cluster_test.go index 51796302..62fd6e0d 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -47,14 +47,14 @@ func (s *clusterScenario) addrs() []string { return addrs } -func (s *clusterScenario) clusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient { +func (s *clusterScenario) newClusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient { opt.Addrs = s.addrs() return redis.NewClusterClient(opt) } -func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { - client := s.clusterClientUnsafe(opt) +func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { + client := s.newClusterClientUnsafe(opt) err := eventually(func() error { if opt.ClusterSlots != nil { @@ -529,14 +529,11 @@ var _ = Describe("ClusterClient", func() { }) It("supports Process hook", func() { - var masters []*redis.Client - err := client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - err = client.ForEachMaster(func(master *redis.Client) error { - masters = append(masters, master) - return master.Ping().Err() + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() }) Expect(err).NotTo(HaveOccurred()) @@ -556,7 +553,7 @@ var _ = Describe("ClusterClient", func() { } client.AddHook(clusterHook) - masterHook := &hook{ + nodeHook := &hook{ beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { Expect(cmd.String()).To(Equal("ping: ")) stack = append(stack, "shard.BeforeProcess") @@ -569,9 +566,10 @@ var _ = Describe("ClusterClient", func() { }, } - for _, master := range masters { - master.AddHook(masterHook) - } + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(nodeHook) + return nil + }) err = client.Ping().Err() Expect(err).NotTo(HaveOccurred()) @@ -584,19 +582,16 @@ var _ = Describe("ClusterClient", func() { clusterHook.beforeProcess = nil 
clusterHook.afterProcess = nil - masterHook.beforeProcess = nil - masterHook.afterProcess = nil + nodeHook.beforeProcess = nil + nodeHook.afterProcess = nil }) It("supports Pipeline hook", func() { - var masters []*redis.Client - err := client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - err = client.ForEachMaster(func(master *redis.Client) error { - masters = append(masters, master) - return master.Ping().Err() + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() }) Expect(err).NotTo(HaveOccurred()) @@ -617,8 +612,8 @@ var _ = Describe("ClusterClient", func() { }, }) - for _, master := range masters { - master.AddHook(&hook{ + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(&hook{ beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { Expect(cmds).To(HaveLen(1)) Expect(cmds[0].String()).To(Equal("ping: ")) @@ -632,7 +627,8 @@ var _ = Describe("ClusterClient", func() { return nil }, }) - } + return nil + }) _, err = client.Pipelined(func(pipe redis.Pipeliner) error { pipe.Ping() @@ -648,14 +644,11 @@ var _ = Describe("ClusterClient", func() { }) It("supports TxPipeline hook", func() { - var masters []*redis.Client - err := client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - err = client.ForEachMaster(func(master *redis.Client) error { - masters = append(masters, master) - return master.Ping().Err() + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() }) Expect(err).NotTo(HaveOccurred()) @@ -676,22 +669,23 @@ var _ = Describe("ClusterClient", func() { }, }) - for _, master := range masters { - master.AddHook(&hook{ + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(&hook{ beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) stack = append(stack, "shard.BeforeProcessPipeline") return ctx, nil }, afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) stack = append(stack, "shard.AfterProcessPipeline") return nil }, }) - } + return nil + }) _, err = client.TxPipelined(func(pipe redis.Pipeliner) error { pipe.Ping() @@ -710,7 +704,7 @@ var _ = Describe("ClusterClient", func() { Describe("ClusterClient", func() { BeforeEach(func() { opt = redisClusterOptions() - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -733,7 +727,7 @@ var _ = Describe("ClusterClient", func() { It("returns an error when there are no attempts left", func() { opt := redisClusterOptions() opt.MaxRedirects = -1 - client := cluster.clusterClient(opt) + client := cluster.newClusterClient(opt) Eventually(func() error { return client.SwapNodes("A") @@ -885,7 +879,7 @@ var _ = Describe("ClusterClient", func() { opt = redisClusterOptions() opt.MinRetryBackoff = 250 * time.Millisecond opt.MaxRetryBackoff = time.Second - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -935,7 +929,7 @@ var _ = Describe("ClusterClient", func() { BeforeEach(func() { opt = 
redisClusterOptions() opt.RouteByLatency = true - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -991,7 +985,7 @@ var _ = Describe("ClusterClient", func() { }} return slots, nil } - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -1045,7 +1039,7 @@ var _ = Describe("ClusterClient", func() { }} return slots, nil } - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -1137,7 +1131,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() { opt.ReadTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond opt.MaxRedirects = 1 - client = cluster.clusterClientUnsafe(opt) + client = cluster.newClusterClientUnsafe(opt) }) AfterEach(func() { @@ -1206,7 +1200,7 @@ var _ = Describe("ClusterClient timeout", func() { opt.ReadTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond opt.MaxRedirects = 1 - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachNode(func(client *redis.Client) error { return client.ClientPause(pause).Err() diff --git a/race_test.go b/race_test.go index afe06cf8..9a575555 100644 --- a/race_test.go +++ b/race_test.go @@ -299,7 +299,7 @@ var _ = Describe("cluster races", func() { BeforeEach(func() { opt := redisClusterOptions() - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) C, N = 10, 1000 if testing.Short() { diff --git a/redis.go b/redis.go index caba631b..93032579 100644 --- a/redis.go +++ b/redis.go @@ -128,6 +128,13 @@ func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error { return firstErr } +func (hs hooks) processTxPipeline( + ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, +) error { + cmds = wrapMultiExec(cmds) + return hs.processPipeline(ctx, cmds, fn) +} + //------------------------------------------------------------------------------ type baseClient struct { @@ -437,51 +444,46 @@ func (c *baseClient) txPipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return txPipelineWriteMulti(wr, cmds) + return writeCmds(wr, cmds) }) if err != nil { return true, err } err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - err := txPipelineReadQueued(rd, cmds) + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + cmds = cmds[1 : len(cmds)-1] + + err := txPipelineReadQueued(rd, statusCmd, cmds) if err != nil { return err } + return pipelineReadCmds(rd, cmds) }) return false, err } -var ( - multi = NewStatusCmd("multi") - exec = NewSliceCmd("exec") -) - -func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error { - if err := writeCmd(wr, multi); err != nil { - return err +func wrapMultiExec(cmds []Cmder) []Cmder { + if len(cmds) == 0 { + panic("not reached") } - if err := writeCmds(wr, cmds); err != nil { - return err - } - if err := writeCmd(wr, exec); err != nil { - return err - } - return nil + cmds = append(cmds, make([]Cmder, 2)...) 
+ copy(cmds[1:], cmds[:len(cmds)-2]) + cmds[0] = NewStatusCmd("multi") + cmds[len(cmds)-1] = NewSliceCmd("exec") + return cmds } -func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { +func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { // Parse queued replies. - var statusCmd StatusCmd - err := statusCmd.readReply(rd) - if err != nil { + if err := statusCmd.readReply(rd); err != nil { return err } for range cmds { - err = statusCmd.readReply(rd) - if err != nil && !isRedisError(err) { + if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { return err } } @@ -587,7 +589,7 @@ func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { } func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) + return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) } // Options returns read-only Options that were used to create the client. diff --git a/ring.go b/ring.go index 40bafc28..e1b49917 100644 --- a/ring.go +++ b/ring.go @@ -665,9 +665,9 @@ func (c *Ring) processShardPipeline( } if tx { - err = shard.Client.processPipeline(ctx, cmds) - } else { err = shard.Client.processTxPipeline(ctx, cmds) + } else { + err = shard.Client.processPipeline(ctx, cmds) } return err } diff --git a/ring_test.go b/ring_test.go index a0fdd462..eef8dc25 100644 --- a/ring_test.go +++ b/ring_test.go @@ -317,14 +317,14 @@ var _ = Describe("Redis Ring", func() { ring.ForEachShard(func(shard *redis.Client) error { shard.AddHook(&hook{ beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) stack = append(stack, "shard.BeforeProcessPipeline") return ctx, nil }, afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) stack = append(stack, "shard.AfterProcessPipeline") return nil }, diff --git a/tx.go b/tx.go index 4af5ebe6..9ae15901 100644 --- a/tx.go +++ b/tx.go @@ -151,7 +151,7 @@ func (c *Tx) TxPipeline() Pipeliner { pipe := Pipeline{ ctx: c.ctx, exec: func(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) + return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) }, } pipe.init()
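PATCH 3's central trick is wrapMultiExec: instead of writing MULTI/EXEC separately, they are spliced into the cmds slice before the pipeline hooks run, so BeforeProcessPipeline/AfterProcessPipeline see all three commands (hence the shard-level expectations of HaveLen(3) and cmds[1] in the cluster and ring tests), and txPipelineReadQueued reuses cmds[0] as the status command after trimming the wrapper. A self-contained sketch of the splice, with strings standing in for Cmder values:

package main

import "fmt"

// wrapMultiExec mirrors the splice from PATCH 3: grow the slice by two, shift
// the original commands right by one, then put "multi" in front and "exec" at
// the end. copy is safe here because overlapping src/dst is well defined in Go.
func wrapMultiExec(cmds []string) []string {
	if len(cmds) == 0 {
		panic("not reached")
	}
	cmds = append(cmds, make([]string, 2)...)
	copy(cmds[1:], cmds[:len(cmds)-2])
	cmds[0] = "multi"
	cmds[len(cmds)-1] = "exec"
	return cmds
}

func main() {
	fmt.Println(wrapMultiExec([]string{"set key value", "get key"}))
	// Prints: [multi set key value get key exec]
}

Growing the slice by two and shifting with copy keeps the wrap to a single allocation, which preserves the allocation savings PATCH 1 introduced while still exposing MULTI and EXEC to the hooks.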