diff --git a/cluster.go b/cluster.go index 12721c3f..5c6a97c6 100644 --- a/cluster.go +++ b/cluster.go @@ -779,7 +779,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error { _ = pipe.Close() ask = false } else { - lastErr = node.Client._process(ctx, cmd) + lastErr = node.Client.ProcessContext(ctx, cmd) } // If there is no error - we are done. @@ -840,6 +840,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + for _, master := range state.Masters { wg.Add(1) go func(node *clusterNode) { @@ -853,6 +854,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { } }(master) } + wg.Wait() select { @@ -873,6 +875,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + for _, slave := range state.Slaves { wg.Add(1) go func(node *clusterNode) { @@ -886,6 +889,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { } }(slave) } + wg.Wait() select { @@ -906,6 +910,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) + worker := func(node *clusterNode) { defer wg.Done() err := fn(node.Client) @@ -927,6 +932,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { } wg.Wait() + select { case err := <-errCh: return err @@ -1068,18 +1074,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmds(wr, cmds) - }) - if err != nil { - return err - } - - return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - return c.pipelineReadCmds(node, rd, cmds, failedCmds) - }) - }) + err := c._processPipelineNode(ctx, node, cmds, failedCmds) if err == nil { return } @@ -1142,6 +1137,25 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool { return true } +func (c *ClusterClient) _processPipelineNode( + ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, +) error { + return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmds(wr, cmds) + }) + if err != nil { + return err + } + + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { + return c.pipelineReadCmds(node, rd, cmds, failedCmds) + }) + }) + }) +} + func (c *ClusterClient) pipelineReadCmds( node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, ) error { @@ -1243,26 +1257,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er go func(node *clusterNode, cmds []Cmder) { defer wg.Done() - err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { - err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return txPipelineWriteMulti(wr, cmds) - }) - if err != nil { - return err - } - - return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - err := c.txPipelineReadQueued(rd, cmds, failedCmds) - if err != nil { - moved, ask, addr := isMovedError(err) - if moved || ask { - return c.cmdsMoved(cmds, moved, ask, addr, failedCmds) - } - return err - } - return pipelineReadCmds(rd, cmds) - }) - }) + err := c._processTxPipelineNode(ctx, node, cmds, failedCmds) if err == nil { return } @@ -1296,6 +1291,33 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { return cmdsMap } +func (c *ClusterClient) _processTxPipelineNode( + ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, +) error { + return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { + return txPipelineWriteMulti(wr, cmds) + }) + if err != nil { + return err + } + + return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { + err := c.txPipelineReadQueued(rd, cmds, failedCmds) + if err != nil { + moved, ask, addr := isMovedError(err) + if moved || ask { + return c.cmdsMoved(cmds, moved, ask, addr, failedCmds) + } + return err + } + return pipelineReadCmds(rd, cmds) + }) + }) + }) +} + func (c *ClusterClient) txPipelineReadQueued( rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, ) error { diff --git a/cluster_test.go b/cluster_test.go index 05eaeb4d..51796302 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -527,6 +527,184 @@ var _ = Describe("ClusterClient", func() { err := pubsub.Ping() Expect(err).NotTo(HaveOccurred()) }) + + It("supports Process hook", func() { + var masters []*redis.Client + + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachMaster(func(master *redis.Client) error { + masters = append(masters, master) + return master.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + clusterHook := &hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcess") + return nil + }, + } + client.AddHook(clusterHook) + + masterHook := &hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + return nil + }, + } + + for _, master := range masters { + master.AddHook(masterHook) + } + + err = client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcess", + "shard.BeforeProcess", + "shard.AfterProcess", + "cluster.AfterProcess", + })) + + clusterHook.beforeProcess = nil + clusterHook.afterProcess = nil + masterHook.beforeProcess = nil + masterHook.afterProcess = nil + }) + + It("supports Pipeline hook", func() { + var masters []*redis.Client + + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachMaster(func(master *redis.Client) error { + masters = append(masters, master) + return master.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + client.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + return nil + }, + }) + + for _, master := range masters { + master.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + } + + _, err = client.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "cluster.AfterProcessPipeline", + })) + }) + + It("supports TxPipeline hook", func() { + var masters []*redis.Client + + err := client.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachMaster(func(master *redis.Client) error { + masters = append(masters, master) + return master.Ping().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + client.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "cluster.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "cluster.AfterProcessPipeline") + return nil + }, + }) + + for _, master := range masters { + master.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + } + + _, err = client.TxPipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "cluster.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "cluster.AfterProcessPipeline", + })) + }) } Describe("ClusterClient", func() { diff --git a/command.go b/command.go index 8c5fbe24..266e3892 100644 --- a/command.go +++ b/command.go @@ -15,6 +15,7 @@ import ( type Cmder interface { Name() string Args() []interface{} + String() string stringArg(int) string readTimeout() *time.Duration @@ -152,6 +153,10 @@ func NewCmd(args ...interface{}) *Cmd { } } +func (cmd *Cmd) String() string { + return cmdString(cmd, cmd.val) +} + func (cmd *Cmd) Val() interface{} { return cmd.val } @@ -160,7 +165,7 @@ func (cmd *Cmd) Result() (interface{}, error) { return cmd.val, cmd.err } -func (cmd *Cmd) String() (string, error) { +func (cmd *Cmd) Text() (string, error) { if cmd.err != nil { return "", cmd.err } diff --git a/example_test.go b/example_test.go index b38f71f6..801dc965 100644 --- a/example_test.go +++ b/example_test.go @@ -447,7 +447,7 @@ func Example_customCommand() { } func Example_customCommand2() { - v, err := rdb.Do("get", "key_does_not_exist").String() + v, err := rdb.Do("get", "key_does_not_exist").Text() fmt.Printf("%q %s", v, err) // Output: "" redis: nil } diff --git a/main_test.go b/main_test.go index 958a00d3..f7f028a1 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "errors" "fmt" "net" @@ -370,3 +371,41 @@ func (cn *badConn) Write([]byte) (int, error) { } return 0, badConnError("bad connection") } + +//------------------------------------------------------------------------------ + +type hook struct { + beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error) + afterProcess func(ctx context.Context, cmd redis.Cmder) error + + beforeProcessPipeline func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) + afterProcessPipeline func(ctx context.Context, cmds []redis.Cmder) error +} + +func (h *hook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + if h.beforeProcess != nil { + return h.beforeProcess(ctx, cmd) + } + return ctx, nil +} + +func (h *hook) AfterProcess(ctx context.Context, cmd redis.Cmder) error { + if h.afterProcess != nil { + return h.afterProcess(ctx, cmd) + } + return nil +} + +func (h *hook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + if h.beforeProcessPipeline != nil { + return h.beforeProcessPipeline(ctx, cmds) + } + return ctx, nil +} + +func (h *hook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error { + if h.afterProcessPipeline != nil { + return h.afterProcessPipeline(ctx, cmds) + } + return nil +} diff --git a/ring.go b/ring.go index b0c6dfc3..40bafc28 100644 --- a/ring.go +++ b/ring.go @@ -581,7 +581,7 @@ func (c *Ring) _process(ctx context.Context, cmd Cmder) error { return err } - lastErr = shard.Client._process(ctx, cmd) + lastErr = shard.Client.ProcessContext(ctx, cmd) if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) { return lastErr } @@ -646,10 +646,7 @@ func (c *Ring) generalProcessPipeline( go func(hash string, cmds []Cmder) { defer wg.Done() - err := c.processShardPipeline(ctx, hash, cmds, tx) - if err != nil { - setCmdsErr(cmds, err) - } + _ = c.processShardPipeline(ctx, hash, cmds, tx) }(hash, cmds) } @@ -663,15 +660,14 @@ func (c *Ring) processShardPipeline( //TODO: retry? shard, err := c.shards.GetByHash(hash) if err != nil { + setCmdsErr(cmds, err) return err } if tx { - err = shard.Client._generalProcessPipeline( - ctx, cmds, shard.Client.txPipelineProcessCmds) + err = shard.Client.processPipeline(ctx, cmds) } else { - err = shard.Client._generalProcessPipeline( - ctx, cmds, shard.Client.pipelineProcessCmds) + err = shard.Client.processTxPipeline(ctx, cmds) } return err } diff --git a/ring_test.go b/ring_test.go index ac9d9a9e..a0fdd462 100644 --- a/ring_test.go +++ b/ring_test.go @@ -195,6 +195,155 @@ var _ = Describe("Redis Ring", func() { Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set")) }) }) + + It("supports Process hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + ring.AddHook(&hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcess") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { + Expect(cmd.String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcess") + return ctx, nil + }, + afterProcess: func(ctx context.Context, cmd redis.Cmder) error { + Expect(cmd.String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcess") + return nil + }, + }) + return nil + }) + + err = ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcess", + "shard.BeforeProcess", + "shard.AfterProcess", + "ring.AfterProcess", + })) + }) + + It("supports Pipeline hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + ring.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcessPipeline") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + return nil + }) + + _, err = ring.Pipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "ring.AfterProcessPipeline", + })) + }) + + It("supports TxPipeline hook", func() { + err := ring.Ping().Err() + Expect(err).NotTo(HaveOccurred()) + + var stack []string + + ring.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "ring.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "ring.AfterProcessPipeline") + return nil + }, + }) + + ring.ForEachShard(func(shard *redis.Client) error { + shard.AddHook(&hook{ + beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: ")) + stack = append(stack, "shard.BeforeProcessPipeline") + return ctx, nil + }, + afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].String()).To(Equal("ping: PONG")) + stack = append(stack, "shard.AfterProcessPipeline") + return nil + }, + }) + return nil + }) + + _, err = ring.TxPipelined(func(pipe redis.Pipeliner) error { + pipe.Ping() + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(stack).To(Equal([]string{ + "ring.BeforeProcessPipeline", + "shard.BeforeProcessPipeline", + "shard.AfterProcessPipeline", + "ring.AfterProcessPipeline", + })) + }) }) var _ = Describe("empty Redis Ring", func() {