diff --git a/bench_test.go b/bench_test.go index f670a7c..1646724 100644 --- a/bench_test.go +++ b/bench_test.go @@ -254,7 +254,7 @@ func BenchmarkClusterPing(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() b.ResetTimer() @@ -280,7 +280,7 @@ func BenchmarkClusterSetString(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() value := string(bytes.Repeat([]byte{'1'}, 10000)) @@ -308,7 +308,7 @@ func BenchmarkClusterReloadState(b *testing.B) { } defer stopCluster(cluster) - client := cluster.clusterClient(redisClusterOptions()) + client := cluster.newClusterClient(redisClusterOptions()) defer client.Close() b.ResetTimer() diff --git a/cluster.go b/cluster.go index 5c6a97c..2cf6e01 100644 --- a/cluster.go +++ b/cluster.go @@ -773,7 +773,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error { if ask { pipe := node.Client.Pipeline() - _ = pipe.Process(NewCmd("ASKING")) + _ = pipe.Process(NewCmd("asking")) _ = pipe.Process(cmd) _, lastErr = pipe.ExecContext(ctx) _ = pipe.Close() @@ -1200,7 +1200,7 @@ func (c *ClusterClient) checkMovedErr( } if ask { - failedCmds.Add(node, NewCmd("ASKING"), cmd) + failedCmds.Add(node, NewCmd("asking"), cmd) return true } @@ -1294,17 +1294,21 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder { func (c *ClusterClient) _processTxPipelineNode( ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap, ) error { - return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { + return node.Client.hooks.processTxPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return txPipelineWriteMulti(wr, cmds) + return writeCmds(wr, cmds) }) if err != nil { return err } return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - err := c.txPipelineReadQueued(rd, cmds, failedCmds) + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + cmds = cmds[1 : len(cmds)-1] + + err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds) if err != nil { moved, ask, addr := isMovedError(err) if moved || ask { @@ -1312,6 +1316,7 @@ func (c *ClusterClient) _processTxPipelineNode( } return err } + return pipelineReadCmds(rd, cmds) }) }) @@ -1319,10 +1324,9 @@ func (c *ClusterClient) _processTxPipelineNode( } func (c *ClusterClient) txPipelineReadQueued( - rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, + rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap, ) error { // Parse queued replies. - var statusCmd StatusCmd if err := statusCmd.readReply(rd); err != nil { return err } @@ -1374,7 +1378,7 @@ func (c *ClusterClient) cmdsMoved( if ask { for _, cmd := range cmds { - failedCmds.Add(node, NewCmd("ASKING"), cmd) + failedCmds.Add(node, NewCmd("asking"), cmd) } return nil } diff --git a/cluster_test.go b/cluster_test.go index 5179630..62fd6e0 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -47,14 +47,14 @@ func (s *clusterScenario) addrs() []string { return addrs } -func (s *clusterScenario) clusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient { +func (s *clusterScenario) newClusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient { opt.Addrs = s.addrs() return redis.NewClusterClient(opt) } -func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { - client := s.clusterClientUnsafe(opt) +func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { + client := s.newClusterClientUnsafe(opt) err := eventually(func() error { if opt.ClusterSlots != nil { @@ -529,14 +529,11 @@ var _ = Describe("ClusterClient", func() { }) It("supports Process hook", func() { - var masters []*redis.Client - err := client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - err = client.ForEachMaster(func(master *redis.Client) error { - masters = append(masters, master) - return master.Ping().Err() + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() }) Expect(err).NotTo(HaveOccurred()) @@ -556,7 +553,7 @@ var _ = Describe("ClusterClient", func() { } client.AddHook(clusterHook) - masterHook := &hook{ + nodeHook := &hook{ beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { Expect(cmd.String()).To(Equal("ping: ")) stack = append(stack, "shard.BeforeProcess") @@ -569,9 +566,10 @@ var _ = Describe("ClusterClient", func() { }, } - for _, master := range masters { - master.AddHook(masterHook) - } + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(nodeHook) + return nil + }) err = client.Ping().Err() Expect(err).NotTo(HaveOccurred()) @@ -584,19 +582,16 @@ var _ = Describe("ClusterClient", func() { clusterHook.beforeProcess = nil clusterHook.afterProcess = nil - masterHook.beforeProcess = nil - masterHook.afterProcess = nil + nodeHook.beforeProcess = nil + nodeHook.afterProcess = nil }) It("supports Pipeline hook", func() { - var masters []*redis.Client - err := client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - err = client.ForEachMaster(func(master *redis.Client) error { - masters = append(masters, master) - return master.Ping().Err() + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() }) Expect(err).NotTo(HaveOccurred()) @@ -617,8 +612,8 @@ var _ = Describe("ClusterClient", func() { }, }) - for _, master := range masters { - master.AddHook(&hook{ + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(&hook{ beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { Expect(cmds).To(HaveLen(1)) Expect(cmds[0].String()).To(Equal("ping: ")) @@ -632,7 +627,8 @@ var _ = Describe("ClusterClient", func() { return nil }, }) - } + return nil + }) _, err = client.Pipelined(func(pipe redis.Pipeliner) error { pipe.Ping() @@ -648,14 +644,11 @@ var _ = Describe("ClusterClient", func() { }) It("supports TxPipeline hook", func() { - var masters []*redis.Client - err := client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - err = client.ForEachMaster(func(master *redis.Client) error { - masters = append(masters, master) - return master.Ping().Err() + err = client.ForEachNode(func(node *redis.Client) error { + return node.Ping().Err() }) Expect(err).NotTo(HaveOccurred()) @@ -676,22 +669,23 @@ var _ = Describe("ClusterClient", func() { }, }) - for _, master := range masters { - master.AddHook(&hook{ + _ = client.ForEachNode(func(node *redis.Client) error { + node.AddHook(&hook{ beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) stack = append(stack, "shard.BeforeProcessPipeline") return ctx, nil }, afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) stack = append(stack, "shard.AfterProcessPipeline") return nil }, }) - } + return nil + }) _, err = client.TxPipelined(func(pipe redis.Pipeliner) error { pipe.Ping() @@ -710,7 +704,7 @@ var _ = Describe("ClusterClient", func() { Describe("ClusterClient", func() { BeforeEach(func() { opt = redisClusterOptions() - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -733,7 +727,7 @@ var _ = Describe("ClusterClient", func() { It("returns an error when there are no attempts left", func() { opt := redisClusterOptions() opt.MaxRedirects = -1 - client := cluster.clusterClient(opt) + client := cluster.newClusterClient(opt) Eventually(func() error { return client.SwapNodes("A") @@ -885,7 +879,7 @@ var _ = Describe("ClusterClient", func() { opt = redisClusterOptions() opt.MinRetryBackoff = 250 * time.Millisecond opt.MaxRetryBackoff = time.Second - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -935,7 +929,7 @@ var _ = Describe("ClusterClient", func() { BeforeEach(func() { opt = redisClusterOptions() opt.RouteByLatency = true - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -991,7 +985,7 @@ var _ = Describe("ClusterClient", func() { }} return slots, nil } - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -1045,7 +1039,7 @@ var _ = Describe("ClusterClient", func() { }} return slots, nil } - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() @@ -1137,7 +1131,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() { opt.ReadTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond opt.MaxRedirects = 1 - client = cluster.clusterClientUnsafe(opt) + client = cluster.newClusterClientUnsafe(opt) }) AfterEach(func() { @@ -1206,7 +1200,7 @@ var _ = Describe("ClusterClient timeout", func() { opt.ReadTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond opt.MaxRedirects = 1 - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) err := client.ForEachNode(func(client *redis.Client) error { return client.ClientPause(pause).Err() diff --git a/race_test.go b/race_test.go index afe06cf..9a57555 100644 --- a/race_test.go +++ b/race_test.go @@ -299,7 +299,7 @@ var _ = Describe("cluster races", func() { BeforeEach(func() { opt := redisClusterOptions() - client = cluster.clusterClient(opt) + client = cluster.newClusterClient(opt) C, N = 10, 1000 if testing.Short() { diff --git a/redis.go b/redis.go index caba631..9303257 100644 --- a/redis.go +++ b/redis.go @@ -128,6 +128,13 @@ func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error { return firstErr } +func (hs hooks) processTxPipeline( + ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, +) error { + cmds = wrapMultiExec(cmds) + return hs.processPipeline(ctx, cmds, fn) +} + //------------------------------------------------------------------------------ type baseClient struct { @@ -437,51 +444,46 @@ func (c *baseClient) txPipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return txPipelineWriteMulti(wr, cmds) + return writeCmds(wr, cmds) }) if err != nil { return true, err } err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { - err := txPipelineReadQueued(rd, cmds) + statusCmd := cmds[0].(*StatusCmd) + // Trim multi and exec. + cmds = cmds[1 : len(cmds)-1] + + err := txPipelineReadQueued(rd, statusCmd, cmds) if err != nil { return err } + return pipelineReadCmds(rd, cmds) }) return false, err } -var ( - multi = NewStatusCmd("multi") - exec = NewSliceCmd("exec") -) - -func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error { - if err := writeCmd(wr, multi); err != nil { - return err +func wrapMultiExec(cmds []Cmder) []Cmder { + if len(cmds) == 0 { + panic("not reached") } - if err := writeCmds(wr, cmds); err != nil { - return err - } - if err := writeCmd(wr, exec); err != nil { - return err - } - return nil + cmds = append(cmds, make([]Cmder, 2)...) + copy(cmds[1:], cmds[:len(cmds)-2]) + cmds[0] = NewStatusCmd("multi") + cmds[len(cmds)-1] = NewSliceCmd("exec") + return cmds } -func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { +func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { // Parse queued replies. - var statusCmd StatusCmd - err := statusCmd.readReply(rd) - if err != nil { + if err := statusCmd.readReply(rd); err != nil { return err } for range cmds { - err = statusCmd.readReply(rd) - if err != nil && !isRedisError(err) { + if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { return err } } @@ -587,7 +589,7 @@ func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { } func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) + return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) } // Options returns read-only Options that were used to create the client. diff --git a/ring.go b/ring.go index 40bafc2..e1b4991 100644 --- a/ring.go +++ b/ring.go @@ -665,9 +665,9 @@ func (c *Ring) processShardPipeline( } if tx { - err = shard.Client.processPipeline(ctx, cmds) - } else { err = shard.Client.processTxPipeline(ctx, cmds) + } else { + err = shard.Client.processPipeline(ctx, cmds) } return err } diff --git a/ring_test.go b/ring_test.go index a0fdd46..eef8dc2 100644 --- a/ring_test.go +++ b/ring_test.go @@ -317,14 +317,14 @@ var _ = Describe("Redis Ring", func() { ring.ForEachShard(func(shard *redis.Client) error { shard.AddHook(&hook{ beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: ")) + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: ")) stack = append(stack, "shard.BeforeProcessPipeline") return ctx, nil }, afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error { - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].String()).To(Equal("ping: PONG")) + Expect(cmds).To(HaveLen(3)) + Expect(cmds[1].String()).To(Equal("ping: PONG")) stack = append(stack, "shard.AfterProcessPipeline") return nil }, diff --git a/tx.go b/tx.go index 4af5ebe..9ae1590 100644 --- a/tx.go +++ b/tx.go @@ -151,7 +151,7 @@ func (c *Tx) TxPipeline() Pipeliner { pipe := Pipeline{ ctx: c.ctx, exec: func(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) + return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) }, } pipe.init()