Merge pull request #444 from go-redis/feature/tx-pipeline

Add TxPipeline.
This commit is contained in:
Vladimir Mihailenco 2016-12-16 11:40:46 +02:00 committed by GitHub
commit 152cc1ee34
13 changed files with 577 additions and 590 deletions

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"fmt"
"math/rand" "math/rand"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -9,6 +10,7 @@ import (
"gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/hashtag" "gopkg.in/redis.v5/internal/hashtag"
"gopkg.in/redis.v5/internal/pool" "gopkg.in/redis.v5/internal/pool"
"gopkg.in/redis.v5/internal/proto"
) )
var errClusterNoNodes = internal.RedisError("redis: cluster has no nodes") var errClusterNoNodes = internal.RedisError("redis: cluster has no nodes")
@ -415,10 +417,6 @@ func (c *ClusterClient) Process(cmd Cmder) error {
var ask bool var ask bool
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
if attempt > 0 {
cmd.reset()
}
if ask { if ask {
pipe := node.Client.Pipeline() pipe := node.Client.Pipeline()
pipe.Process(NewCmd("ASKING")) pipe.Process(NewCmd("ASKING"))
@ -653,111 +651,252 @@ func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
} }
func (c *ClusterClient) pipelineExec(cmds []Cmder) error { func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
var firstErr error cmdsMap, err := c.mapCmdsByNode(cmds)
setFirstErr := func(err error) {
if firstErr == nil {
firstErr = err
}
}
state := c.state()
cmdsMap := make(map[*clusterNode][]Cmder)
for _, cmd := range cmds {
_, node, err := c.cmdSlotAndNode(state, cmd)
if err != nil { if err != nil {
cmd.setErr(err) return err
setFirstErr(err)
continue
}
cmdsMap[node] = append(cmdsMap[node], cmd)
} }
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { for i := 0; i <= c.opt.MaxRedirects; i++ {
failedCmds := make(map[*clusterNode][]Cmder) failedCmds := make(map[*clusterNode][]Cmder)
for node, cmds := range cmdsMap { for node, cmds := range cmdsMap {
cn, _, err := node.Client.conn() cn, _, err := node.Client.conn()
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
setFirstErr(err)
continue continue
} }
failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) err = c.pipelineProcessCmds(cn, cmds, failedCmds)
node.Client.putConn(cn, err, false) node.Client.putConn(cn, err, false)
if err != nil {
setFirstErr(err)
}
} }
if len(failedCmds) == 0 {
break
}
cmdsMap = failedCmds cmdsMap = failedCmds
} }
var firstErr error
for _, cmd := range cmds {
if err := cmd.Err(); err != nil {
firstErr = err
break
}
}
return firstErr return firstErr
} }
func (c *ClusterClient) execClusterCmds( func (c *ClusterClient) mapCmdsByNode(cmds []Cmder) (map[*clusterNode][]Cmder, error) {
state := c.state()
cmdsMap := make(map[*clusterNode][]Cmder)
for _, cmd := range cmds {
_, node, err := c.cmdSlotAndNode(state, cmd)
if err != nil {
return nil, err
}
cmdsMap[node] = append(cmdsMap[node], cmd)
}
return cmdsMap, nil
}
func (c *ClusterClient) pipelineProcessCmds(
cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) (map[*clusterNode][]Cmder, error) { ) error {
cn.SetWriteTimeout(c.opt.WriteTimeout) cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil { if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
return failedCmds, err return err
}
var firstErr error
setFirstErr := func(err error) {
if firstErr == nil {
firstErr = err
}
} }
// Set read timeout for all commands. // Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout) cn.SetReadTimeout(c.opt.ReadTimeout)
for i, cmd := range cmds { return c.pipelineReadCmds(cn, cmds, failedCmds)
}
func (c *ClusterClient) pipelineReadCmds(
cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) error {
var firstErr error
for _, cmd := range cmds {
err := cmd.readReply(cn) err := cmd.readReply(cn)
if err == nil { if err == nil {
continue continue
} }
if i == 0 && internal.IsRetryableError(err) { if firstErr == nil {
node, err := c.nodes.Random() firstErr = err
if err != nil {
setFirstErr(err)
continue
} }
cmd.reset() err = c.checkMovedErr(cmd, failedCmds)
failedCmds[node] = append(failedCmds[node], cmds...) if err != nil && firstErr == nil {
break firstErr = err
} }
}
return firstErr
}
moved, ask, addr := internal.IsMovedError(err) func (c *ClusterClient) checkMovedErr(cmd Cmder, failedCmds map[*clusterNode][]Cmder) error {
moved, ask, addr := internal.IsMovedError(cmd.Err())
if moved { if moved {
c.lazyReloadSlots() c.lazyReloadSlots()
node, err := c.nodes.Get(addr) node, err := c.nodes.Get(addr)
if err != nil { if err != nil {
setFirstErr(err) return err
continue
} }
cmd.reset()
failedCmds[node] = append(failedCmds[node], cmd) failedCmds[node] = append(failedCmds[node], cmd)
} else if ask { }
if ask {
node, err := c.nodes.Get(addr) node, err := c.nodes.Get(addr)
if err != nil { if err != nil {
setFirstErr(err) return err
}
failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd)
}
return nil
}
func (c *ClusterClient) TxPipeline() *Pipeline {
pipe := Pipeline{
exec: c.txPipelineExec,
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
return &pipe
}
func (c *ClusterClient) TxPipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
}
func (c *ClusterClient) txPipelineExec(cmds []Cmder) error {
cmdsMap, err := c.mapCmdsBySlot(cmds)
if err != nil {
return err
}
for slot, cmds := range cmdsMap {
node, err := c.state().slotMasterNode(slot)
if err != nil {
setCmdsErr(cmds, err)
continue continue
} }
cmd.reset() cmdsMap := map[*clusterNode][]Cmder{node: cmds}
failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd) for i := 0; i <= c.opt.MaxRedirects; i++ {
} else { failedCmds := make(map[*clusterNode][]Cmder)
setFirstErr(err)
for node, cmds := range cmdsMap {
cn, _, err := node.Client.conn()
if err != nil {
setCmdsErr(cmds, err)
continue
}
err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds)
node.Client.putConn(cn, err, false)
}
if len(failedCmds) == 0 {
break
}
cmdsMap = failedCmds
} }
} }
return failedCmds, firstErr var firstErr error
for _, cmd := range cmds {
if err := cmd.Err(); err != nil {
firstErr = err
break
}
}
return firstErr
}
func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) (map[int][]Cmder, error) {
state := c.state()
cmdsMap := make(map[int][]Cmder)
for _, cmd := range cmds {
slot, _, err := c.cmdSlotAndNode(state, cmd)
if err != nil {
return nil, err
}
cmdsMap[slot] = append(cmdsMap[slot], cmd)
}
return cmdsMap, nil
}
func (c *ClusterClient) txPipelineProcessCmds(
node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) error {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := txPipelineWriteMulti(cn, cmds); err != nil {
setCmdsErr(cmds, err)
failedCmds[node] = cmds
return err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
if err := c.txPipelineReadQueued(cn, cmds, failedCmds); err != nil {
return err
}
_, err := pipelineReadCmds(cn, cmds)
return err
}
func (c *ClusterClient) txPipelineReadQueued(
cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) error {
var firstErr error
// Parse queued replies.
var statusCmd StatusCmd
if err := statusCmd.readReply(cn); err != nil && firstErr == nil {
firstErr = err
}
for _, cmd := range cmds {
err := statusCmd.readReply(cn)
if err == nil {
continue
}
cmd.setErr(err)
if firstErr == nil {
firstErr = err
}
err = c.checkMovedErr(cmd, failedCmds)
if err != nil && firstErr == nil {
firstErr = err
}
}
// Parse number of replies.
line, err := cn.Rd.ReadLine()
if err != nil {
if err == Nil {
err = TxFailedErr
}
return err
}
switch line[0] {
case proto.ErrorReply:
return proto.ParseErrorReply(line)
case proto.ArrayReply:
// ok
default:
err := fmt.Errorf("redis: expected '*', but got line %q", line)
return err
}
return firstErr
} }

View File

@ -373,14 +373,14 @@ var _ = Describe("ClusterClient", func() {
Expect(n).To(Equal(int64(100))) Expect(n).To(Equal(int64(100)))
}) })
Describe("pipeline", func() { Describe("pipelining", func() {
var pipe *redis.Pipeline
assertPipeline := func() {
It("follows redirects", func() { It("follows redirects", func() {
slot := hashtag.Slot("A") slot := hashtag.Slot("A")
Expect(client.SwapSlotNodes(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) Expect(client.SwapSlotNodes(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"}))
pipe := client.Pipeline()
defer pipe.Close()
keys := []string{"A", "B", "C", "D", "E", "F", "G"} keys := []string{"A", "B", "C", "D", "E", "F", "G"}
for i, key := range keys { for i, key := range keys {
@ -429,6 +429,31 @@ var _ = Describe("ClusterClient", func() {
Expect(c.Err()).NotTo(HaveOccurred()) Expect(c.Err()).NotTo(HaveOccurred())
Expect(c.Val()).To(Equal("C_value")) Expect(c.Val()).To(Equal("C_value"))
}) })
}
Describe("Pipeline", func() {
BeforeEach(func() {
pipe = client.Pipeline()
})
AfterEach(func() {
Expect(pipe.Close()).NotTo(HaveOccurred())
})
assertPipeline()
})
Describe("TxPipeline", func() {
BeforeEach(func() {
pipe = client.TxPipeline()
})
AfterEach(func() {
Expect(pipe.Close()).NotTo(HaveOccurred())
})
assertPipeline()
})
}) })
It("calls fn for every master node", func() { It("calls fn for every master node", func() {
@ -624,7 +649,7 @@ var _ = Describe("ClusterClient timeout", func() {
return client.ForEachNode(func(client *redis.Client) error { return client.ForEachNode(func(client *redis.Client) error {
return client.Ping().Err() return client.Ping().Err()
}) })
}, pause).ShouldNot(HaveOccurred()) }, 2*pause).ShouldNot(HaveOccurred())
}) })
testTimeout() testTimeout()

View File

@ -36,7 +36,6 @@ type Cmder interface {
readReply(*pool.Conn) error readReply(*pool.Conn) error
setErr(error) setErr(error)
reset()
readTimeout() *time.Duration readTimeout() *time.Duration
@ -50,12 +49,6 @@ func setCmdsErr(cmds []Cmder, e error) {
} }
} }
func resetCmds(cmds []Cmder) {
for _, cmd := range cmds {
cmd.reset()
}
}
func writeCmd(cn *pool.Conn, cmds ...Cmder) error { func writeCmd(cn *pool.Conn, cmds ...Cmder) error {
cn.Wb.Reset() cn.Wb.Reset()
for _, cmd := range cmds { for _, cmd := range cmds {
@ -167,11 +160,6 @@ func NewCmd(args ...interface{}) *Cmd {
} }
} }
func (cmd *Cmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *Cmd) Val() interface{} { func (cmd *Cmd) Val() interface{} {
return cmd.val return cmd.val
} }
@ -185,16 +173,13 @@ func (cmd *Cmd) String() string {
} }
func (cmd *Cmd) readReply(cn *pool.Conn) error { func (cmd *Cmd) readReply(cn *pool.Conn) error {
val, err := cn.Rd.ReadReply(sliceParser) cmd.val, cmd.err = cn.Rd.ReadReply(sliceParser)
if err != nil { if cmd.err != nil {
cmd.err = err
return cmd.err return cmd.err
} }
if b, ok := val.([]byte); ok { if b, ok := cmd.val.([]byte); ok {
// Bytes must be copied, because underlying memory is reused. // Bytes must be copied, because underlying memory is reused.
cmd.val = string(b) cmd.val = string(b)
} else {
cmd.val = val
} }
return nil return nil
} }
@ -212,11 +197,6 @@ func NewSliceCmd(args ...interface{}) *SliceCmd {
return &SliceCmd{baseCmd: cmd} return &SliceCmd{baseCmd: cmd}
} }
func (cmd *SliceCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *SliceCmd) Val() []interface{} { func (cmd *SliceCmd) Val() []interface{} {
return cmd.val return cmd.val
} }
@ -230,10 +210,10 @@ func (cmd *SliceCmd) String() string {
} }
func (cmd *SliceCmd) readReply(cn *pool.Conn) error { func (cmd *SliceCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(sliceParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(sliceParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.([]interface{}) cmd.val = v.([]interface{})
return nil return nil
@ -252,11 +232,6 @@ func NewStatusCmd(args ...interface{}) *StatusCmd {
return &StatusCmd{baseCmd: cmd} return &StatusCmd{baseCmd: cmd}
} }
func (cmd *StatusCmd) reset() {
cmd.val = ""
cmd.err = nil
}
func (cmd *StatusCmd) Val() string { func (cmd *StatusCmd) Val() string {
return cmd.val return cmd.val
} }
@ -287,11 +262,6 @@ func NewIntCmd(args ...interface{}) *IntCmd {
return &IntCmd{baseCmd: cmd} return &IntCmd{baseCmd: cmd}
} }
func (cmd *IntCmd) reset() {
cmd.val = 0
cmd.err = nil
}
func (cmd *IntCmd) Val() int64 { func (cmd *IntCmd) Val() int64 {
return cmd.val return cmd.val
} }
@ -326,11 +296,6 @@ func NewDurationCmd(precision time.Duration, args ...interface{}) *DurationCmd {
} }
} }
func (cmd *DurationCmd) reset() {
cmd.val = 0
cmd.err = nil
}
func (cmd *DurationCmd) Val() time.Duration { func (cmd *DurationCmd) Val() time.Duration {
return cmd.val return cmd.val
} }
@ -344,10 +309,10 @@ func (cmd *DurationCmd) String() string {
} }
func (cmd *DurationCmd) readReply(cn *pool.Conn) error { func (cmd *DurationCmd) readReply(cn *pool.Conn) error {
n, err := cn.Rd.ReadIntReply() var n int64
if err != nil { n, cmd.err = cn.Rd.ReadIntReply()
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = time.Duration(n) * cmd.precision cmd.val = time.Duration(n) * cmd.precision
return nil return nil
@ -368,11 +333,6 @@ func NewTimeCmd(args ...interface{}) *TimeCmd {
} }
} }
func (cmd *TimeCmd) reset() {
cmd.val = time.Time{}
cmd.err = nil
}
func (cmd *TimeCmd) Val() time.Time { func (cmd *TimeCmd) Val() time.Time {
return cmd.val return cmd.val
} }
@ -386,10 +346,10 @@ func (cmd *TimeCmd) String() string {
} }
func (cmd *TimeCmd) readReply(cn *pool.Conn) error { func (cmd *TimeCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(timeParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(timeParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.(time.Time) cmd.val = v.(time.Time)
return nil return nil
@ -408,11 +368,6 @@ func NewBoolCmd(args ...interface{}) *BoolCmd {
return &BoolCmd{baseCmd: cmd} return &BoolCmd{baseCmd: cmd}
} }
func (cmd *BoolCmd) reset() {
cmd.val = false
cmd.err = nil
}
func (cmd *BoolCmd) Val() bool { func (cmd *BoolCmd) Val() bool {
return cmd.val return cmd.val
} }
@ -428,27 +383,29 @@ func (cmd *BoolCmd) String() string {
var ok = []byte("OK") var ok = []byte("OK")
func (cmd *BoolCmd) readReply(cn *pool.Conn) error { func (cmd *BoolCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadReply(nil) var v interface{}
v, cmd.err = cn.Rd.ReadReply(nil)
// `SET key value NX` returns nil when key already exists. But // `SET key value NX` returns nil when key already exists. But
// `SETNX key value` returns bool (0/1). So convert nil to bool. // `SETNX key value` returns bool (0/1). So convert nil to bool.
// TODO: is this okay? // TODO: is this okay?
if err == Nil { if cmd.err == Nil {
cmd.val = false cmd.val = false
cmd.err = nil
return nil return nil
} }
if err != nil { if cmd.err != nil {
cmd.err = err return cmd.err
return err
} }
switch vv := v.(type) { switch v := v.(type) {
case int64: case int64:
cmd.val = vv == 1 cmd.val = v == 1
return nil return nil
case []byte: case []byte:
cmd.val = bytes.Equal(vv, ok) cmd.val = bytes.Equal(v, ok)
return nil return nil
default: default:
return fmt.Errorf("got %T, wanted int64 or string", v) cmd.err = fmt.Errorf("got %T, wanted int64 or string", v)
return cmd.err
} }
} }
@ -465,11 +422,6 @@ func NewStringCmd(args ...interface{}) *StringCmd {
return &StringCmd{baseCmd: cmd} return &StringCmd{baseCmd: cmd}
} }
func (cmd *StringCmd) reset() {
cmd.val = ""
cmd.err = nil
}
func (cmd *StringCmd) Val() string { func (cmd *StringCmd) Val() string {
return cmd.val return cmd.val
} }
@ -515,13 +467,8 @@ func (cmd *StringCmd) String() string {
} }
func (cmd *StringCmd) readReply(cn *pool.Conn) error { func (cmd *StringCmd) readReply(cn *pool.Conn) error {
b, err := cn.Rd.ReadBytesReply() cmd.val, cmd.err = cn.Rd.ReadStringReply()
if err != nil { return cmd.err
cmd.err = err
return err
}
cmd.val = string(b)
return nil
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -537,11 +484,6 @@ func NewFloatCmd(args ...interface{}) *FloatCmd {
return &FloatCmd{baseCmd: cmd} return &FloatCmd{baseCmd: cmd}
} }
func (cmd *FloatCmd) reset() {
cmd.val = 0
cmd.err = nil
}
func (cmd *FloatCmd) Val() float64 { func (cmd *FloatCmd) Val() float64 {
return cmd.val return cmd.val
} }
@ -572,11 +514,6 @@ func NewStringSliceCmd(args ...interface{}) *StringSliceCmd {
return &StringSliceCmd{baseCmd: cmd} return &StringSliceCmd{baseCmd: cmd}
} }
func (cmd *StringSliceCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *StringSliceCmd) Val() []string { func (cmd *StringSliceCmd) Val() []string {
return cmd.val return cmd.val
} }
@ -590,10 +527,10 @@ func (cmd *StringSliceCmd) String() string {
} }
func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error { func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(stringSliceParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(stringSliceParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.([]string) cmd.val = v.([]string)
return nil return nil
@ -612,11 +549,6 @@ func NewBoolSliceCmd(args ...interface{}) *BoolSliceCmd {
return &BoolSliceCmd{baseCmd: cmd} return &BoolSliceCmd{baseCmd: cmd}
} }
func (cmd *BoolSliceCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *BoolSliceCmd) Val() []bool { func (cmd *BoolSliceCmd) Val() []bool {
return cmd.val return cmd.val
} }
@ -630,10 +562,10 @@ func (cmd *BoolSliceCmd) String() string {
} }
func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error { func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(boolSliceParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(boolSliceParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.([]bool) cmd.val = v.([]bool)
return nil return nil
@ -652,11 +584,6 @@ func NewStringStringMapCmd(args ...interface{}) *StringStringMapCmd {
return &StringStringMapCmd{baseCmd: cmd} return &StringStringMapCmd{baseCmd: cmd}
} }
func (cmd *StringStringMapCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *StringStringMapCmd) Val() map[string]string { func (cmd *StringStringMapCmd) Val() map[string]string {
return cmd.val return cmd.val
} }
@ -670,10 +597,10 @@ func (cmd *StringStringMapCmd) String() string {
} }
func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error { func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(stringStringMapParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(stringStringMapParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.(map[string]string) cmd.val = v.(map[string]string)
return nil return nil
@ -704,16 +631,11 @@ func (cmd *StringIntMapCmd) String() string {
return cmdString(cmd, cmd.val) return cmdString(cmd, cmd.val)
} }
func (cmd *StringIntMapCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error { func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(stringIntMapParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(stringIntMapParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.(map[string]int64) cmd.val = v.(map[string]int64)
return nil return nil
@ -732,11 +654,6 @@ func NewZSliceCmd(args ...interface{}) *ZSliceCmd {
return &ZSliceCmd{baseCmd: cmd} return &ZSliceCmd{baseCmd: cmd}
} }
func (cmd *ZSliceCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *ZSliceCmd) Val() []Z { func (cmd *ZSliceCmd) Val() []Z {
return cmd.val return cmd.val
} }
@ -750,10 +667,10 @@ func (cmd *ZSliceCmd) String() string {
} }
func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error { func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(zSliceParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(zSliceParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.([]Z) cmd.val = v.([]Z)
return nil return nil
@ -775,12 +692,6 @@ func NewScanCmd(args ...interface{}) *ScanCmd {
} }
} }
func (cmd *ScanCmd) reset() {
cmd.cursor = 0
cmd.page = nil
cmd.err = nil
}
func (cmd *ScanCmd) Val() (keys []string, cursor uint64) { func (cmd *ScanCmd) Val() (keys []string, cursor uint64) {
return cmd.page, cmd.cursor return cmd.page, cmd.cursor
} }
@ -794,14 +705,8 @@ func (cmd *ScanCmd) String() string {
} }
func (cmd *ScanCmd) readReply(cn *pool.Conn) error { func (cmd *ScanCmd) readReply(cn *pool.Conn) error {
page, cursor, err := cn.Rd.ReadScanReply() cmd.page, cmd.cursor, cmd.err = cn.Rd.ReadScanReply()
if err != nil {
cmd.err = err
return cmd.err return cmd.err
}
cmd.page = page
cmd.cursor = cursor
return nil
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -840,16 +745,11 @@ func (cmd *ClusterSlotsCmd) String() string {
return cmdString(cmd, cmd.val) return cmdString(cmd, cmd.val)
} }
func (cmd *ClusterSlotsCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *ClusterSlotsCmd) readReply(cn *pool.Conn) error { func (cmd *ClusterSlotsCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(clusterSlotsParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(clusterSlotsParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.([]ClusterSlot) cmd.val = v.([]ClusterSlot)
return nil return nil
@ -913,11 +813,6 @@ func NewGeoLocationCmd(q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd {
} }
} }
func (cmd *GeoLocationCmd) reset() {
cmd.locations = nil
cmd.err = nil
}
func (cmd *GeoLocationCmd) Val() []GeoLocation { func (cmd *GeoLocationCmd) Val() []GeoLocation {
return cmd.locations return cmd.locations
} }
@ -931,12 +826,12 @@ func (cmd *GeoLocationCmd) String() string {
} }
func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error { func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error {
reply, err := cn.Rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q))
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.locations = reply.([]GeoLocation) cmd.locations = v.([]GeoLocation)
return nil return nil
} }
@ -969,18 +864,13 @@ func (cmd *GeoPosCmd) String() string {
return cmdString(cmd, cmd.positions) return cmdString(cmd, cmd.positions)
} }
func (cmd *GeoPosCmd) reset() {
cmd.positions = nil
cmd.err = nil
}
func (cmd *GeoPosCmd) readReply(cn *pool.Conn) error { func (cmd *GeoPosCmd) readReply(cn *pool.Conn) error {
reply, err := cn.Rd.ReadArrayReply(geoPosSliceParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(geoPosSliceParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.positions = reply.([]*GeoPos) cmd.positions = v.([]*GeoPos)
return nil return nil
} }
@ -1019,16 +909,11 @@ func (cmd *CommandsInfoCmd) String() string {
return cmdString(cmd, cmd.val) return cmdString(cmd, cmd.val)
} }
func (cmd *CommandsInfoCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *CommandsInfoCmd) readReply(cn *pool.Conn) error { func (cmd *CommandsInfoCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(commandInfoSliceParser) var v interface{}
if err != nil { v, cmd.err = cn.Rd.ReadArrayReply(commandInfoSliceParser)
cmd.err = err if cmd.err != nil {
return err return cmd.err
} }
cmd.val = v.(map[string]*CommandInfo) cmd.val = v.(map[string]*CommandInfo)
return nil return nil

View File

@ -69,3 +69,7 @@ func IsMovedError(err error) (moved bool, ask bool, addr string) {
func IsLoadingError(err error) bool { func IsLoadingError(err error) bool {
return strings.HasPrefix(err.Error(), "LOADING") return strings.HasPrefix(err.Error(), "LOADING")
} }
func IsExecAbortError(err error) bool {
return strings.HasPrefix(err.Error(), "EXECABORT")
}

View File

@ -70,7 +70,7 @@ func (p *Reader) ReadReply(m MultiBulkParse) (interface{}, error) {
switch line[0] { switch line[0] {
case ErrorReply: case ErrorReply:
return nil, parseErrorValue(line) return nil, ParseErrorReply(line)
case StatusReply: case StatusReply:
return parseStatusValue(line) return parseStatusValue(line)
case IntReply: case IntReply:
@ -94,7 +94,7 @@ func (p *Reader) ReadIntReply() (int64, error) {
} }
switch line[0] { switch line[0] {
case ErrorReply: case ErrorReply:
return 0, parseErrorValue(line) return 0, ParseErrorReply(line)
case IntReply: case IntReply:
return parseIntValue(line) return parseIntValue(line)
default: default:
@ -109,7 +109,7 @@ func (p *Reader) ReadBytesReply() ([]byte, error) {
} }
switch line[0] { switch line[0] {
case ErrorReply: case ErrorReply:
return nil, parseErrorValue(line) return nil, ParseErrorReply(line)
case StringReply: case StringReply:
return p.readBytesValue(line) return p.readBytesValue(line)
case StatusReply: case StatusReply:
@ -142,7 +142,7 @@ func (p *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) {
} }
switch line[0] { switch line[0] {
case ErrorReply: case ErrorReply:
return nil, parseErrorValue(line) return nil, ParseErrorReply(line)
case ArrayReply: case ArrayReply:
n, err := parseArrayLen(line) n, err := parseArrayLen(line)
if err != nil { if err != nil {
@ -161,7 +161,7 @@ func (p *Reader) ReadArrayLen() (int64, error) {
} }
switch line[0] { switch line[0] {
case ErrorReply: case ErrorReply:
return 0, parseErrorValue(line) return 0, ParseErrorReply(line)
case ArrayReply: case ArrayReply:
return parseArrayLen(line) return parseArrayLen(line)
default: default:
@ -272,7 +272,7 @@ func isNilReply(b []byte) bool {
b[1] == '-' && b[2] == '1' b[1] == '-' && b[2] == '1'
} }
func parseErrorValue(line []byte) error { func ParseErrorReply(line []byte) error {
return internal.RedisError(string(line[1:])) return internal.RedisError(string(line[1:]))
} }

View File

@ -58,7 +58,6 @@ func (it *ScanIterator) Next() bool {
} else { } else {
it.ScanCmd._args[2] = it.ScanCmd.cursor it.ScanCmd._args[2] = it.ScanCmd.cursor
} }
it.ScanCmd.reset()
it.client.process(it.ScanCmd) it.client.process(it.ScanCmd)
if it.ScanCmd.Err() != nil { if it.ScanCmd.Err() != nil {
return false return false

View File

@ -7,6 +7,8 @@ import (
"gopkg.in/redis.v5/internal/pool" "gopkg.in/redis.v5/internal/pool"
) )
type pipelineExecer func([]Cmder) error
// Pipeline implements pipelining as described in // Pipeline implements pipelining as described in
// http://redis.io/topics/pipelining. It's safe for concurrent use // http://redis.io/topics/pipelining. It's safe for concurrent use
// by multiple goroutines. // by multiple goroutines.
@ -14,7 +16,7 @@ type Pipeline struct {
cmdable cmdable
statefulCmdable statefulCmdable
exec func([]Cmder) error exec pipelineExecer
mu sync.Mutex mu sync.Mutex
cmds []Cmder cmds []Cmder

View File

@ -1,17 +1,15 @@
package redis_test package redis_test
import ( import (
"strconv"
"sync"
"gopkg.in/redis.v5" "gopkg.in/redis.v5"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("Pipeline", func() { var _ = Describe("pipelining", func() {
var client *redis.Client var client *redis.Client
var pipe *redis.Pipeline
BeforeEach(func() { BeforeEach(func() {
client = redis.NewClient(redisOptions()) client = redis.NewClient(redisOptions())
@ -22,44 +20,7 @@ var _ = Describe("Pipeline", func() {
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
}) })
It("should pipeline", func() { It("supports block style", func() {
set := client.Set("key2", "hello2", 0)
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
pipeline := client.Pipeline()
set = pipeline.Set("key1", "hello1", 0)
get := pipeline.Get("key2")
incr := pipeline.Incr("key3")
getNil := pipeline.Get("key4")
cmds, err := pipeline.Exec()
Expect(err).To(Equal(redis.Nil))
Expect(cmds).To(HaveLen(4))
Expect(pipeline.Close()).NotTo(HaveOccurred())
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello2"))
Expect(incr.Err()).NotTo(HaveOccurred())
Expect(incr.Val()).To(Equal(int64(1)))
Expect(getNil.Err()).To(Equal(redis.Nil))
Expect(getNil.Val()).To(Equal(""))
})
It("discards queued commands", func() {
pipeline := client.Pipeline()
pipeline.Get("key")
pipeline.Discard()
_, err := pipeline.Exec()
Expect(err).To(MatchError("redis: pipeline is empty"))
})
It("should support block style", func() {
var get *redis.StringCmd var get *redis.StringCmd
cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error { cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error {
get = pipe.Get("foo") get = pipe.Get("foo")
@ -72,98 +33,47 @@ var _ = Describe("Pipeline", func() {
Expect(get.Val()).To(Equal("")) Expect(get.Val()).To(Equal(""))
}) })
It("should handle vals/err", func() { assertPipeline := func() {
pipeline := client.Pipeline()
get := pipeline.Get("key")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal(""))
Expect(pipeline.Close()).NotTo(HaveOccurred())
})
It("returns an error when there are no commands", func() { It("returns an error when there are no commands", func() {
pipeline := client.Pipeline() _, err := pipe.Exec()
_, err := pipeline.Exec()
Expect(err).To(MatchError("redis: pipeline is empty")) Expect(err).To(MatchError("redis: pipeline is empty"))
}) })
It("should increment correctly", func() { It("discards queued commands", func() {
const N = 20000 pipe.Get("key")
key := "TestPipelineIncr" pipe.Discard()
pipeline := client.Pipeline() _, err := pipe.Exec()
for i := 0; i < N; i++ { Expect(err).To(MatchError("redis: pipeline is empty"))
pipeline.Incr(key)
}
cmds, err := pipeline.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(pipeline.Close()).NotTo(HaveOccurred())
Expect(len(cmds)).To(Equal(20000))
for _, cmd := range cmds {
Expect(cmd.Err()).NotTo(HaveOccurred())
}
get := client.Get(key)
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal(strconv.Itoa(N)))
}) })
It("should PipelineEcho", func() { It("handles val/err", func() {
const N = 1000 err := client.Set("key", "value", 0).Err()
wg := &sync.WaitGroup{}
wg.Add(N)
for i := 0; i < N; i++ {
go func(i int) {
defer GinkgoRecover()
defer wg.Done()
pipeline := client.Pipeline()
msg1 := "echo" + strconv.Itoa(i)
msg2 := "echo" + strconv.Itoa(i+1)
echo1 := pipeline.Echo(msg1)
echo2 := pipeline.Echo(msg2)
cmds, err := pipeline.Exec()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(2))
Expect(echo1.Err()).NotTo(HaveOccurred()) get := pipe.Get("key")
Expect(echo1.Val()).To(Equal(msg1)) cmds, err := pipe.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1))
Expect(echo2.Err()).NotTo(HaveOccurred()) val, err := get.Result()
Expect(echo2.Val()).To(Equal(msg2)) Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("value"))
Expect(pipeline.Close()).NotTo(HaveOccurred()) })
}(i)
} }
wg.Wait()
Describe("Pipeline", func() {
BeforeEach(func() {
pipe = client.Pipeline()
}) })
It("should be thread-safe", func() { assertPipeline()
const N = 1000
pipeline := client.Pipeline()
var wg sync.WaitGroup
wg.Add(N)
for i := 0; i < N; i++ {
go func() {
defer GinkgoRecover()
pipeline.Ping()
wg.Done()
}()
}
wg.Wait()
cmds, err := pipeline.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
Expect(pipeline.Close()).NotTo(HaveOccurred())
}) })
Describe("TxPipeline", func() {
BeforeEach(func() {
pipe = client.TxPipeline()
})
assertPipeline()
})
}) })

View File

@ -245,4 +245,35 @@ var _ = Describe("races", func() {
Expect(val).To(Equal(int64(C * N))) Expect(val).To(Equal(int64(C * N)))
}) })
It("should Pipeline", func() {
perform(C, func(id int) {
pipe := client.Pipeline()
for i := 0; i < N; i++ {
pipe.Echo(fmt.Sprint(i))
}
cmds, err := pipe.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
for i := 0; i < N; i++ {
Expect(cmds[i].(*redis.StringCmd).Val()).To(Equal(fmt.Sprint(i)))
}
})
})
It("should Pipeline", func() {
pipe := client.Pipeline()
perform(N, func(id int) {
pipe.Incr("key")
})
cmds, err := pipe.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
n, err := client.Get("key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(N)))
})
}) })

194
redis.go
View File

@ -7,6 +7,7 @@ import (
"gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool" "gopkg.in/redis.v5/internal/pool"
"gopkg.in/redis.v5/internal/proto"
) )
// Redis nil reply, .e.g. when key does not exist. // Redis nil reply, .e.g. when key does not exist.
@ -96,10 +97,6 @@ func (c *baseClient) WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(
func (c *baseClient) defaultProcess(cmd Cmder) error { func (c *baseClient) defaultProcess(cmd Cmder) error {
for i := 0; i <= c.opt.MaxRetries; i++ { for i := 0; i <= c.opt.MaxRetries; i++ {
if i > 0 {
cmd.reset()
}
cn, _, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
cmd.setErr(err) cmd.setErr(err)
@ -162,6 +159,129 @@ func (c *baseClient) getAddr() string {
return c.opt.Addr return c.opt.Addr
} }
type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error)
func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer {
return func(cmds []Cmder) error {
var firstErr error
for i := 0; i <= c.opt.MaxRetries; i++ {
cn, _, err := c.conn()
if err != nil {
setCmdsErr(cmds, err)
return err
}
canRetry, err := p(cn, cmds)
c.putConn(cn, err, false)
if err == nil {
return nil
}
if firstErr == nil {
firstErr = err
}
if !canRetry || !internal.IsRetryableError(err) {
break
}
}
return firstErr
}
}
func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return true, err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
return pipelineReadCmds(cn, cmds)
}
func pipelineReadCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 {
retry = true
}
if firstErr == nil {
firstErr = err
}
}
return false, firstErr
}
func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := txPipelineWriteMulti(cn, cmds); err != nil {
setCmdsErr(cmds, err)
return true, err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
if err := c.txPipelineReadQueued(cn, cmds); err != nil {
return false, err
}
_, err := pipelineReadCmds(cn, cmds)
return false, err
}
func txPipelineWriteMulti(cn *pool.Conn, cmds []Cmder) error {
multiExec := make([]Cmder, 0, len(cmds)+2)
multiExec = append(multiExec, NewStatusCmd("MULTI"))
multiExec = append(multiExec, cmds...)
multiExec = append(multiExec, NewSliceCmd("EXEC"))
return writeCmd(cn, multiExec...)
}
func (c *baseClient) txPipelineReadQueued(cn *pool.Conn, cmds []Cmder) error {
var firstErr error
// Parse queued replies.
var statusCmd StatusCmd
if err := statusCmd.readReply(cn); err != nil && firstErr == nil {
firstErr = err
}
for _, cmd := range cmds {
err := statusCmd.readReply(cn)
if err != nil {
cmd.setErr(err)
if firstErr == nil {
firstErr = err
}
}
}
// Parse number of replies.
line, err := cn.Rd.ReadLine()
if err != nil {
if err == Nil {
err = TxFailedErr
}
return err
}
switch line[0] {
case proto.ErrorReply:
return proto.ParseErrorReply(line)
case proto.ArrayReply:
// ok
default:
err := fmt.Errorf("redis: expected '*', but got line %q", line)
return err
}
return nil
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// Client is a Redis client representing a pool of zero or more // Client is a Redis client representing a pool of zero or more
@ -200,70 +320,30 @@ func (c *Client) PoolStats() *PoolStats {
} }
} }
func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
}
func (c *Client) Pipeline() *Pipeline { func (c *Client) Pipeline() *Pipeline {
pipe := Pipeline{ pipe := Pipeline{
exec: c.pipelineExec, exec: c.pipelineExecer(c.pipelineProcessCmds),
} }
pipe.cmdable.process = pipe.Process pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process pipe.statefulCmdable.process = pipe.Process
return &pipe return &pipe
} }
func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { func (c *Client) TxPipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn) return c.TxPipeline().pipelined(fn)
} }
func (c *Client) pipelineExec(cmds []Cmder) error { func (c *Client) TxPipeline() *Pipeline {
var firstErr error pipe := Pipeline{
for i := 0; i <= c.opt.MaxRetries; i++ { exec: c.pipelineExecer(c.txPipelineProcessCmds),
if i > 0 {
resetCmds(cmds)
} }
pipe.cmdable.process = pipe.Process
cn, _, err := c.conn() pipe.statefulCmdable.process = pipe.Process
if err != nil { return &pipe
setCmdsErr(cmds, err)
return err
}
retry, err := c.execCmds(cn, cmds)
c.putConn(cn, err, false)
if err == nil {
return nil
}
if firstErr == nil {
firstErr = err
}
if !retry {
break
}
}
return firstErr
}
func (c *Client) execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return true, err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 && internal.IsNetworkError(err) {
return true, err
}
if firstErr == nil {
firstErr = err
}
}
return false, firstErr
} }
func (c *Client) pubSub() *PubSub { func (c *Client) pubSub() *PubSub {

View File

@ -379,10 +379,6 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
var failedCmdsMap map[string][]Cmder var failedCmdsMap map[string][]Cmder
for name, cmds := range cmdsMap { for name, cmds := range cmdsMap {
if i > 0 {
resetCmds(cmds)
}
shard, err := c.shardByName(name) shard, err := c.shardByName(name)
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
@ -401,7 +397,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
continue continue
} }
retry, err := shard.Client.execCmds(cn, cmds) canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds)
shard.Client.putConn(cn, err, false) shard.Client.putConn(cn, err, false)
if err == nil { if err == nil {
continue continue
@ -409,7 +405,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
if firstErr == nil { if firstErr == nil {
firstErr = err firstErr = err
} }
if retry { if canRetry && internal.IsRetryableError(err) {
if failedCmdsMap == nil { if failedCmdsMap == nil {
failedCmdsMap = make(map[string][]Cmder) failedCmdsMap = make(map[string][]Cmder)
} }

94
tx.go
View File

@ -1,11 +1,8 @@
package redis package redis
import ( import (
"fmt"
"gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool" "gopkg.in/redis.v5/internal/pool"
"gopkg.in/redis.v5/internal/proto"
) )
// Redis transaction failed. // Redis transaction failed.
@ -19,8 +16,6 @@ type Tx struct {
cmdable cmdable
statefulCmdable statefulCmdable
baseClient baseClient
closed bool
} }
func (c *Client) newTx() *Tx { func (c *Client) newTx() *Tx {
@ -39,26 +34,20 @@ func (c *Client) Watch(fn func(*Tx) error, keys ...string) error {
tx := c.newTx() tx := c.newTx()
if len(keys) > 0 { if len(keys) > 0 {
if err := tx.Watch(keys...).Err(); err != nil { if err := tx.Watch(keys...).Err(); err != nil {
_ = tx.close() _ = tx.Close()
return err return err
} }
} }
firstErr := fn(tx) firstErr := fn(tx)
if err := tx.close(); err != nil && firstErr == nil { if err := tx.Close(); err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
return firstErr return firstErr
} }
// close closes the transaction, releasing any open resources. // close closes the transaction, releasing any open resources.
func (c *Tx) close() error { func (c *Tx) Close() error {
if c.closed { _ = c.Unwatch().Err()
return nil
}
c.closed = true
if err := c.Unwatch().Err(); err != nil {
internal.Logf("Unwatch failed: %s", err)
}
return c.baseClient.Close() return c.baseClient.Close()
} }
@ -89,7 +78,7 @@ func (c *Tx) Unwatch(keys ...string) *StatusCmd {
func (c *Tx) Pipeline() *Pipeline { func (c *Tx) Pipeline() *Pipeline {
pipe := Pipeline{ pipe := Pipeline{
exec: c.exec, exec: c.pipelineExecer(c.txPipelineProcessCmds),
} }
pipe.cmdable.process = pipe.Process pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process pipe.statefulCmdable.process = pipe.Process
@ -108,76 +97,3 @@ func (c *Tx) Pipeline() *Pipeline {
func (c *Tx) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { func (c *Tx) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn) return c.Pipeline().pipelined(fn)
} }
func (c *Tx) exec(cmds []Cmder) error {
if c.closed {
return pool.ErrClosed
}
cn, _, err := c.conn()
if err != nil {
setCmdsErr(cmds, err)
return err
}
multiExec := make([]Cmder, 0, len(cmds)+2)
multiExec = append(multiExec, NewStatusCmd("MULTI"))
multiExec = append(multiExec, cmds...)
multiExec = append(multiExec, NewSliceCmd("EXEC"))
err = c.execCmds(cn, multiExec)
c.putConn(cn, err, false)
return err
}
func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error {
cn.SetWriteTimeout(c.opt.WriteTimeout)
err := writeCmd(cn, cmds...)
if err != nil {
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
// Omit last command (EXEC).
cmdsLen := len(cmds) - 1
// Parse queued replies.
statusCmd := cmds[0]
for i := 0; i < cmdsLen; i++ {
if err := statusCmd.readReply(cn); err != nil {
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
}
// Parse number of replies.
line, err := cn.Rd.ReadLine()
if err != nil {
if err == Nil {
err = TxFailedErr
}
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
if line[0] != proto.ArrayReply {
err := fmt.Errorf("redis: expected '*', but got line %q", line)
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
var firstErr error
// Parse replies.
// Loop starts from 1 to omit MULTI cmd.
for i := 1; i < cmdsLen; i++ {
cmd := cmds[i]
if err := cmd.readReply(cn); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}