diff --git a/cluster.go b/cluster.go index c5fcb9b..12721c3 100644 --- a/cluster.go +++ b/cluster.go @@ -1070,7 +1070,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmd(wr, cmds...) + return writeCmds(wr, cmds) }) if err != nil { return err diff --git a/command.go b/command.go index 0a4a345..8c5fbe2 100644 --- a/command.go +++ b/command.go @@ -41,16 +41,19 @@ func cmdsFirstErr(cmds []Cmder) error { return nil } -func writeCmd(wr *proto.Writer, cmds ...Cmder) error { +func writeCmds(wr *proto.Writer, cmds []Cmder) error { for _, cmd := range cmds { - err := wr.WriteArgs(cmd.Args()) - if err != nil { + if err := writeCmd(wr, cmd); err != nil { return err } } return nil } +func writeCmd(wr *proto.Writer, cmd Cmder) error { + return wr.WriteArgs(cmd.Args()) +} + func cmdString(cmd Cmder, val interface{}) string { ss := make([]string, 0, len(cmd.Args())) for _, arg := range cmd.Args() { diff --git a/redis.go b/redis.go index 2b50fff..caba631 100644 --- a/redis.go +++ b/redis.go @@ -411,7 +411,7 @@ func (c *baseClient) pipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { - return writeCmd(wr, cmds...) + return writeCmds(wr, cmds) }) if err != nil { return true, err @@ -453,12 +453,22 @@ func (c *baseClient) txPipelineProcessCmds( return false, err } +var ( + multi = NewStatusCmd("multi") + exec = NewSliceCmd("exec") +) + func txPipelineWriteMulti(wr *proto.Writer, cmds []Cmder) error { - multiExec := make([]Cmder, 0, len(cmds)+2) - multiExec = append(multiExec, NewStatusCmd("MULTI")) - multiExec = append(multiExec, cmds...) - multiExec = append(multiExec, NewSliceCmd("EXEC")) - return writeCmd(wr, multiExec...) + if err := writeCmd(wr, multi); err != nil { + return err + } + if err := writeCmds(wr, cmds); err != nil { + return err + } + if err := writeCmd(wr, exec); err != nil { + return err + } + return nil } func txPipelineReadQueued(rd *proto.Reader, cmds []Cmder) error {