Merge pull request from go-redis/fix/txPipelineWriteMulti

Don't allocate tmp slice in txPipelineWriteMulti
This commit is contained in:
Vladimir Mihailenco 2020-02-14 16:25:45 +02:00 committed by GitHub
commit 726f6807ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 486 additions and 84 deletions

View File

@ -254,7 +254,7 @@ func BenchmarkClusterPing(b *testing.B) {
} }
defer stopCluster(cluster) defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions()) client := cluster.newClusterClient(redisClusterOptions())
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()
@ -280,7 +280,7 @@ func BenchmarkClusterSetString(b *testing.B) {
} }
defer stopCluster(cluster) defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions()) client := cluster.newClusterClient(redisClusterOptions())
defer client.Close() defer client.Close()
value := string(bytes.Repeat([]byte{'1'}, 10000)) value := string(bytes.Repeat([]byte{'1'}, 10000))
@ -308,7 +308,7 @@ func BenchmarkClusterReloadState(b *testing.B) {
} }
defer stopCluster(cluster) defer stopCluster(cluster)
client := cluster.clusterClient(redisClusterOptions()) client := cluster.newClusterClient(redisClusterOptions())
defer client.Close() defer client.Close()
b.ResetTimer() b.ResetTimer()

View File

@ -773,13 +773,13 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
if ask { if ask {
pipe := node.Client.Pipeline() pipe := node.Client.Pipeline()
_ = pipe.Process(NewCmd("ASKING")) _ = pipe.Process(NewCmd("asking"))
_ = pipe.Process(cmd) _ = pipe.Process(cmd)
_, lastErr = pipe.ExecContext(ctx) _, lastErr = pipe.ExecContext(ctx)
_ = pipe.Close() _ = pipe.Close()
ask = false ask = false
} else { } else {
lastErr = node.Client._process(ctx, cmd) lastErr = node.Client.ProcessContext(ctx, cmd)
} }
// If there is no error - we are done. // If there is no error - we are done.
@ -840,6 +840,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
var wg sync.WaitGroup var wg sync.WaitGroup
errCh := make(chan error, 1) errCh := make(chan error, 1)
for _, master := range state.Masters { for _, master := range state.Masters {
wg.Add(1) wg.Add(1)
go func(node *clusterNode) { go func(node *clusterNode) {
@ -853,6 +854,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
} }
}(master) }(master)
} }
wg.Wait() wg.Wait()
select { select {
@ -873,6 +875,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
var wg sync.WaitGroup var wg sync.WaitGroup
errCh := make(chan error, 1) errCh := make(chan error, 1)
for _, slave := range state.Slaves { for _, slave := range state.Slaves {
wg.Add(1) wg.Add(1)
go func(node *clusterNode) { go func(node *clusterNode) {
@ -886,6 +889,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
} }
}(slave) }(slave)
} }
wg.Wait() wg.Wait()
select { select {
@ -906,6 +910,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
var wg sync.WaitGroup var wg sync.WaitGroup
errCh := make(chan error, 1) errCh := make(chan error, 1)
worker := func(node *clusterNode) { worker := func(node *clusterNode) {
defer wg.Done() defer wg.Done()
err := fn(node.Client) err := fn(node.Client)
@ -927,6 +932,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
} }
wg.Wait() wg.Wait()
select { select {
case err := <-errCh: case err := <-errCh:
return err return err
@ -1068,18 +1074,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
go func(node *clusterNode, cmds []Cmder) { go func(node *clusterNode, cmds []Cmder) {
defer wg.Done() defer wg.Done()
err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { err := c._processPipelineNode(ctx, node, cmds, failedCmds)
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmds...)
})
if err != nil {
return err
}
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
})
})
if err == nil { if err == nil {
return return
} }
@ -1142,6 +1137,25 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool {
return true return true
} }
func (c *ClusterClient) _processPipelineNode(
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) error {
return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
})
if err != nil {
return err
}
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
})
})
})
}
func (c *ClusterClient) pipelineReadCmds( func (c *ClusterClient) pipelineReadCmds(
node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
) error { ) error {
@ -1186,7 +1200,7 @@ func (c *ClusterClient) checkMovedErr(
} }
if ask { if ask {
failedCmds.Add(node, NewCmd("ASKING"), cmd) failedCmds.Add(node, NewCmd("asking"), cmd)
return true return true
} }
@ -1243,26 +1257,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
go func(node *clusterNode, cmds []Cmder) { go func(node *clusterNode, cmds []Cmder) {
defer wg.Done() defer wg.Done()
err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { err := c._processTxPipelineNode(ctx, node, cmds, failedCmds)
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds)
})
if err != nil {
return err
}
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := c.txPipelineReadQueued(rd, cmds, failedCmds)
if err != nil {
moved, ask, addr := isMovedError(err)
if moved || ask {
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
}
return err
}
return pipelineReadCmds(rd, cmds)
})
})
if err == nil { if err == nil {
return return
} }
@ -1296,11 +1291,42 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
return cmdsMap return cmdsMap
} }
func (c *ClusterClient) _processTxPipelineNode(
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) error {
return node.Client.hooks.processTxPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
})
if err != nil {
return err
}
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
statusCmd := cmds[0].(*StatusCmd)
// Trim multi and exec.
cmds = cmds[1 : len(cmds)-1]
err := c.txPipelineReadQueued(rd, statusCmd, cmds, failedCmds)
if err != nil {
moved, ask, addr := isMovedError(err)
if moved || ask {
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
}
return err
}
return pipelineReadCmds(rd, cmds)
})
})
})
}
func (c *ClusterClient) txPipelineReadQueued( func (c *ClusterClient) txPipelineReadQueued(
rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap,
) error { ) error {
// Parse queued replies. // Parse queued replies.
var statusCmd StatusCmd
if err := statusCmd.readReply(rd); err != nil { if err := statusCmd.readReply(rd); err != nil {
return err return err
} }
@ -1352,7 +1378,7 @@ func (c *ClusterClient) cmdsMoved(
if ask { if ask {
for _, cmd := range cmds { for _, cmd := range cmds {
failedCmds.Add(node, NewCmd("ASKING"), cmd) failedCmds.Add(node, NewCmd("asking"), cmd)
} }
return nil return nil
} }

View File

@ -47,14 +47,14 @@ func (s *clusterScenario) addrs() []string {
return addrs return addrs
} }
func (s *clusterScenario) clusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient { func (s *clusterScenario) newClusterClientUnsafe(opt *redis.ClusterOptions) *redis.ClusterClient {
opt.Addrs = s.addrs() opt.Addrs = s.addrs()
return redis.NewClusterClient(opt) return redis.NewClusterClient(opt)
} }
func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { func (s *clusterScenario) newClusterClient(opt *redis.ClusterOptions) *redis.ClusterClient {
client := s.clusterClientUnsafe(opt) client := s.newClusterClientUnsafe(opt)
err := eventually(func() error { err := eventually(func() error {
if opt.ClusterSlots != nil { if opt.ClusterSlots != nil {
@ -527,12 +527,184 @@ var _ = Describe("ClusterClient", func() {
err := pubsub.Ping() err := pubsub.Ping()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("supports Process hook", func() {
err := client.Ping().Err()
Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error {
return node.Ping().Err()
})
Expect(err).NotTo(HaveOccurred())
var stack []string
clusterHook := &hook{
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
Expect(cmd.String()).To(Equal("ping: "))
stack = append(stack, "cluster.BeforeProcess")
return ctx, nil
},
afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
Expect(cmd.String()).To(Equal("ping: PONG"))
stack = append(stack, "cluster.AfterProcess")
return nil
},
}
client.AddHook(clusterHook)
nodeHook := &hook{
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
Expect(cmd.String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcess")
return ctx, nil
},
afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
Expect(cmd.String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcess")
return nil
},
}
_ = client.ForEachNode(func(node *redis.Client) error {
node.AddHook(nodeHook)
return nil
})
err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"cluster.BeforeProcess",
"shard.BeforeProcess",
"shard.AfterProcess",
"cluster.AfterProcess",
}))
clusterHook.beforeProcess = nil
clusterHook.afterProcess = nil
nodeHook.beforeProcess = nil
nodeHook.afterProcess = nil
})
It("supports Pipeline hook", func() {
err := client.Ping().Err()
Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error {
return node.Ping().Err()
})
Expect(err).NotTo(HaveOccurred())
var stack []string
client.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "cluster.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "cluster.AfterProcessPipeline")
return nil
},
})
_ = client.ForEachNode(func(node *redis.Client) error {
node.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline")
return nil
},
})
return nil
})
_, err = client.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"cluster.BeforeProcessPipeline",
"shard.BeforeProcessPipeline",
"shard.AfterProcessPipeline",
"cluster.AfterProcessPipeline",
}))
})
It("supports TxPipeline hook", func() {
err := client.Ping().Err()
Expect(err).NotTo(HaveOccurred())
err = client.ForEachNode(func(node *redis.Client) error {
return node.Ping().Err()
})
Expect(err).NotTo(HaveOccurred())
var stack []string
client.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "cluster.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "cluster.AfterProcessPipeline")
return nil
},
})
_ = client.ForEachNode(func(node *redis.Client) error {
node.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(3))
Expect(cmds[1].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(3))
Expect(cmds[1].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline")
return nil
},
})
return nil
})
_, err = client.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"cluster.BeforeProcessPipeline",
"shard.BeforeProcessPipeline",
"shard.AfterProcessPipeline",
"cluster.AfterProcessPipeline",
}))
})
} }
Describe("ClusterClient", func() { Describe("ClusterClient", func() {
BeforeEach(func() { BeforeEach(func() {
opt = redisClusterOptions() opt = redisClusterOptions()
client = cluster.clusterClient(opt) client = cluster.newClusterClient(opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB().Err()
@ -555,7 +727,7 @@ var _ = Describe("ClusterClient", func() {
It("returns an error when there are no attempts left", func() { It("returns an error when there are no attempts left", func() {
opt := redisClusterOptions() opt := redisClusterOptions()
opt.MaxRedirects = -1 opt.MaxRedirects = -1
client := cluster.clusterClient(opt) client := cluster.newClusterClient(opt)
Eventually(func() error { Eventually(func() error {
return client.SwapNodes("A") return client.SwapNodes("A")
@ -707,7 +879,7 @@ var _ = Describe("ClusterClient", func() {
opt = redisClusterOptions() opt = redisClusterOptions()
opt.MinRetryBackoff = 250 * time.Millisecond opt.MinRetryBackoff = 250 * time.Millisecond
opt.MaxRetryBackoff = time.Second opt.MaxRetryBackoff = time.Second
client = cluster.clusterClient(opt) client = cluster.newClusterClient(opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB().Err()
@ -757,7 +929,7 @@ var _ = Describe("ClusterClient", func() {
BeforeEach(func() { BeforeEach(func() {
opt = redisClusterOptions() opt = redisClusterOptions()
opt.RouteByLatency = true opt.RouteByLatency = true
client = cluster.clusterClient(opt) client = cluster.newClusterClient(opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB().Err()
@ -813,7 +985,7 @@ var _ = Describe("ClusterClient", func() {
}} }}
return slots, nil return slots, nil
} }
client = cluster.clusterClient(opt) client = cluster.newClusterClient(opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB().Err()
@ -867,7 +1039,7 @@ var _ = Describe("ClusterClient", func() {
}} }}
return slots, nil return slots, nil
} }
client = cluster.clusterClient(opt) client = cluster.newClusterClient(opt)
err := client.ForEachMaster(func(master *redis.Client) error { err := client.ForEachMaster(func(master *redis.Client) error {
return master.FlushDB().Err() return master.FlushDB().Err()
@ -959,7 +1131,7 @@ var _ = Describe("ClusterClient with unavailable Cluster", func() {
opt.ReadTimeout = 250 * time.Millisecond opt.ReadTimeout = 250 * time.Millisecond
opt.WriteTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond
opt.MaxRedirects = 1 opt.MaxRedirects = 1
client = cluster.clusterClientUnsafe(opt) client = cluster.newClusterClientUnsafe(opt)
}) })
AfterEach(func() { AfterEach(func() {
@ -1028,7 +1200,7 @@ var _ = Describe("ClusterClient timeout", func() {
opt.ReadTimeout = 250 * time.Millisecond opt.ReadTimeout = 250 * time.Millisecond
opt.WriteTimeout = 250 * time.Millisecond opt.WriteTimeout = 250 * time.Millisecond
opt.MaxRedirects = 1 opt.MaxRedirects = 1
client = cluster.clusterClient(opt) client = cluster.newClusterClient(opt)
err := client.ForEachNode(func(client *redis.Client) error { err := client.ForEachNode(func(client *redis.Client) error {
return client.ClientPause(pause).Err() return client.ClientPause(pause).Err()

View File

@ -15,6 +15,7 @@ import (
type Cmder interface { type Cmder interface {
Name() string Name() string
Args() []interface{} Args() []interface{}
String() string
stringArg(int) string stringArg(int) string
readTimeout() *time.Duration readTimeout() *time.Duration
@ -41,16 +42,19 @@ func cmdsFirstErr(cmds []Cmder) error {
return nil return nil
} }
func writeCmd(wr *proto.Writer, cmds ...Cmder) error { func writeCmds(wr *proto.Writer, cmds []Cmder) error {
for _, cmd := range cmds { for _, cmd := range cmds {
err := wr.WriteArgs(cmd.Args()) if err := writeCmd(wr, cmd); err != nil {
if err != nil {
return err return err
} }
} }
return nil return nil
} }
func writeCmd(wr *proto.Writer, cmd Cmder) error {
return wr.WriteArgs(cmd.Args())
}
func cmdString(cmd Cmder, val interface{}) string { func cmdString(cmd Cmder, val interface{}) string {
ss := make([]string, 0, len(cmd.Args())) ss := make([]string, 0, len(cmd.Args()))
for _, arg := range cmd.Args() { for _, arg := range cmd.Args() {
@ -149,6 +153,10 @@ func NewCmd(args ...interface{}) *Cmd {
} }
} }
func (cmd *Cmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *Cmd) Val() interface{} { func (cmd *Cmd) Val() interface{} {
return cmd.val return cmd.val
} }
@ -157,7 +165,7 @@ func (cmd *Cmd) Result() (interface{}, error) {
return cmd.val, cmd.err return cmd.val, cmd.err
} }
func (cmd *Cmd) String() (string, error) { func (cmd *Cmd) Text() (string, error) {
if cmd.err != nil { if cmd.err != nil {
return "", cmd.err return "", cmd.err
} }

View File

@ -447,7 +447,7 @@ func Example_customCommand() {
} }
func Example_customCommand2() { func Example_customCommand2() {
v, err := rdb.Do("get", "key_does_not_exist").String() v, err := rdb.Do("get", "key_does_not_exist").Text()
fmt.Printf("%q %s", v, err) fmt.Printf("%q %s", v, err)
// Output: "" redis: nil // Output: "" redis: nil
} }

View File

@ -1,6 +1,7 @@
package redis_test package redis_test
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -370,3 +371,41 @@ func (cn *badConn) Write([]byte) (int, error) {
} }
return 0, badConnError("bad connection") return 0, badConnError("bad connection")
} }
//------------------------------------------------------------------------------
type hook struct {
beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error)
afterProcess func(ctx context.Context, cmd redis.Cmder) error
beforeProcessPipeline func(ctx context.Context, cmds []redis.Cmder) (context.Context, error)
afterProcessPipeline func(ctx context.Context, cmds []redis.Cmder) error
}
func (h *hook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
if h.beforeProcess != nil {
return h.beforeProcess(ctx, cmd)
}
return ctx, nil
}
func (h *hook) AfterProcess(ctx context.Context, cmd redis.Cmder) error {
if h.afterProcess != nil {
return h.afterProcess(ctx, cmd)
}
return nil
}
func (h *hook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
if h.beforeProcessPipeline != nil {
return h.beforeProcessPipeline(ctx, cmds)
}
return ctx, nil
}
func (h *hook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error {
if h.afterProcessPipeline != nil {
return h.afterProcessPipeline(ctx, cmds)
}
return nil
}

View File

@ -299,7 +299,7 @@ var _ = Describe("cluster races", func() {
BeforeEach(func() { BeforeEach(func() {
opt := redisClusterOptions() opt := redisClusterOptions()
client = cluster.clusterClient(opt) client = cluster.newClusterClient(opt)
C, N = 10, 1000 C, N = 10, 1000
if testing.Short() { if testing.Short() {

View File

@ -128,6 +128,13 @@ func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error {
return firstErr return firstErr
} }
func (hs hooks) processTxPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
cmds = wrapMultiExec(cmds)
return hs.processPipeline(ctx, cmds, fn)
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type baseClient struct { type baseClient struct {
@ -411,7 +418,7 @@ func (c *baseClient) pipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder, ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) { ) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmds...) return writeCmds(wr, cmds)
}) })
if err != nil { if err != nil {
return true, err return true, err
@ -437,41 +444,46 @@ func (c *baseClient) txPipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder, ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) { ) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds) return writeCmds(wr, cmds)
}) })
if err != nil { if err != nil {
return true, err return true, err
} }
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := txPipelineReadQueued(rd, cmds) statusCmd := cmds[0].(*StatusCmd)
// Trim multi and exec.
cmds = cmds[1 : len(cmds)-1]
err := txPipelineReadQueued(rd, statusCmd, cmds)
if err != nil { if err != nil {
return err return err
} }
return pipelineReadCmds(rd, cmds) return pipelineReadCmds(rd, cmds)
}) })
return false, err return false, err
} }
func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error { func wrapMultiExec(cmds []Cmder) []Cmder {
multiExec := make([]Cmder, 0, len(cmds)+2) if len(cmds) == 0 {
multiExec = append(multiExec, NewStatusCmd("MULTI")) panic("not reached")
multiExec = append(multiExec, cmds...) }
multiExec = append(multiExec, NewSliceCmd("EXEC")) cmds = append(cmds, make([]Cmder, 2)...)
return writeCmd(wr, multiExec...) copy(cmds[1:], cmds[:len(cmds)-2])
cmds[0] = NewStatusCmd("multi")
cmds[len(cmds)-1] = NewSliceCmd("exec")
return cmds
} }
func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error { func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
// Parse queued replies. // Parse queued replies.
var statusCmd StatusCmd if err := statusCmd.readReply(rd); err != nil {
err := statusCmd.readReply(rd)
if err != nil {
return err return err
} }
for range cmds { for range cmds {
err = statusCmd.readReply(rd) if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
if err != nil && !isRedisError(err) {
return err return err
} }
} }
@ -577,7 +589,7 @@ func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
} }
func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
} }
// Options returns read-only Options that were used to create the client. // Options returns read-only Options that were used to create the client.

14
ring.go
View File

@ -581,7 +581,7 @@ func (c *Ring) _process(ctx context.Context, cmd Cmder) error {
return err return err
} }
lastErr = shard.Client._process(ctx, cmd) lastErr = shard.Client.ProcessContext(ctx, cmd)
if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) { if lastErr == nil || !isRetryableError(lastErr, cmd.readTimeout() == nil) {
return lastErr return lastErr
} }
@ -646,10 +646,7 @@ func (c *Ring) generalProcessPipeline(
go func(hash string, cmds []Cmder) { go func(hash string, cmds []Cmder) {
defer wg.Done() defer wg.Done()
err := c.processShardPipeline(ctx, hash, cmds, tx) _ = c.processShardPipeline(ctx, hash, cmds, tx)
if err != nil {
setCmdsErr(cmds, err)
}
}(hash, cmds) }(hash, cmds)
} }
@ -663,15 +660,14 @@ func (c *Ring) processShardPipeline(
//TODO: retry? //TODO: retry?
shard, err := c.shards.GetByHash(hash) shard, err := c.shards.GetByHash(hash)
if err != nil { if err != nil {
setCmdsErr(cmds, err)
return err return err
} }
if tx { if tx {
err = shard.Client._generalProcessPipeline( err = shard.Client.processTxPipeline(ctx, cmds)
ctx, cmds, shard.Client.txPipelineProcessCmds)
} else { } else {
err = shard.Client._generalProcessPipeline( err = shard.Client.processPipeline(ctx, cmds)
ctx, cmds, shard.Client.pipelineProcessCmds)
} }
return err return err
} }

View File

@ -195,6 +195,155 @@ var _ = Describe("Redis Ring", func() {
Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set")) Expect(err).To(MatchError("ERR Client sent AUTH, but no password is set"))
}) })
}) })
It("supports Process hook", func() {
err := ring.Ping().Err()
Expect(err).NotTo(HaveOccurred())
var stack []string
ring.AddHook(&hook{
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
Expect(cmd.String()).To(Equal("ping: "))
stack = append(stack, "ring.BeforeProcess")
return ctx, nil
},
afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
Expect(cmd.String()).To(Equal("ping: PONG"))
stack = append(stack, "ring.AfterProcess")
return nil
},
})
ring.ForEachShard(func(shard *redis.Client) error {
shard.AddHook(&hook{
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
Expect(cmd.String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcess")
return ctx, nil
},
afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
Expect(cmd.String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcess")
return nil
},
})
return nil
})
err = ring.Ping().Err()
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"ring.BeforeProcess",
"shard.BeforeProcess",
"shard.AfterProcess",
"ring.AfterProcess",
}))
})
It("supports Pipeline hook", func() {
err := ring.Ping().Err()
Expect(err).NotTo(HaveOccurred())
var stack []string
ring.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "ring.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "ring.AfterProcessPipeline")
return nil
},
})
ring.ForEachShard(func(shard *redis.Client) error {
shard.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline")
return nil
},
})
return nil
})
_, err = ring.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"ring.BeforeProcessPipeline",
"shard.BeforeProcessPipeline",
"shard.AfterProcessPipeline",
"ring.AfterProcessPipeline",
}))
})
It("supports TxPipeline hook", func() {
err := ring.Ping().Err()
Expect(err).NotTo(HaveOccurred())
var stack []string
ring.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "ring.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "ring.AfterProcessPipeline")
return nil
},
})
ring.ForEachShard(func(shard *redis.Client) error {
shard.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(3))
Expect(cmds[1].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(3))
Expect(cmds[1].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline")
return nil
},
})
return nil
})
_, err = ring.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"ring.BeforeProcessPipeline",
"shard.BeforeProcessPipeline",
"shard.AfterProcessPipeline",
"ring.AfterProcessPipeline",
}))
})
}) })
var _ = Describe("empty Redis Ring", func() { var _ = Describe("empty Redis Ring", func() {

2
tx.go
View File

@ -151,7 +151,7 @@ func (c *Tx) TxPipeline() Pipeliner {
pipe := Pipeline{ pipe := Pipeline{
ctx: c.ctx, ctx: c.ctx,
exec: func(ctx context.Context, cmds []Cmder) error { exec: func(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline) return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}, },
} }
pipe.init() pipe.init()