Add TxPipeline.

Vladimir Mihailenco 2016-12-13 17:28:39 +02:00
parent c6acf2ed15
commit 865d501d07
13 changed files with 577 additions and 590 deletions
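TxPipeline is the transactional counterpart of Pipeline: commands are queued client-side exactly as before, but on Exec they are sent wrapped in MULTI/EXEC. A minimal usage sketch against the API added in this commit (the address and key names are illustrative, and a running Redis server is assumed):

package main

import (
	"fmt"

	"gopkg.in/redis.v5"
)

func main() {
	client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

	// TxPipeline queues commands and sends them in one MULTI/EXEC block.
	pipe := client.TxPipeline()
	incr := pipe.Incr("tx_counter")
	get := pipe.Get("tx_counter")
	_, err := pipe.Exec()
	fmt.Println(incr.Val(), get.Val(), err)

	// TxPipelined is the block-style equivalent.
	_, err = client.TxPipelined(func(pipe *redis.Pipeline) error {
		pipe.Set("key1", "value1", 0)
		pipe.Get("key1")
		return nil
	})
	fmt.Println(err)
}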


@ -1,6 +1,7 @@
package redis
import (
"fmt"
"math/rand"
"sync"
"sync/atomic"
@ -9,6 +10,7 @@ import (
"gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/hashtag"
"gopkg.in/redis.v5/internal/pool"
"gopkg.in/redis.v5/internal/proto"
)
var errClusterNoNodes = internal.RedisError("redis: cluster has no nodes")
@ -417,10 +419,6 @@ func (c *ClusterClient) Process(cmd Cmder) error {
var ask bool
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
if attempt > 0 {
cmd.reset()
}
if ask {
pipe := node.Client.Pipeline()
pipe.Process(NewCmd("ASKING"))
@ -655,111 +653,252 @@ func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
}
func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
var firstErr error
setFirstErr := func(err error) {
if firstErr == nil {
firstErr = err
}
}
state := c.state()
cmdsMap := make(map[*clusterNode][]Cmder)
for _, cmd := range cmds {
_, node, err := c.cmdSlotAndNode(state, cmd)
cmdsMap, err := c.mapCmdsByNode(cmds)
if err != nil {
cmd.setErr(err)
setFirstErr(err)
continue
}
cmdsMap[node] = append(cmdsMap[node], cmd)
return err
}
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
for i := 0; i <= c.opt.MaxRedirects; i++ {
failedCmds := make(map[*clusterNode][]Cmder)
for node, cmds := range cmdsMap {
cn, _, err := node.Client.conn()
if err != nil {
setCmdsErr(cmds, err)
setFirstErr(err)
continue
}
failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds)
err = c.pipelineProcessCmds(cn, cmds, failedCmds)
node.Client.putConn(cn, err, false)
if err != nil {
setFirstErr(err)
}
}
if len(failedCmds) == 0 {
break
}
cmdsMap = failedCmds
}
var firstErr error
for _, cmd := range cmds {
if err := cmd.Err(); err != nil {
firstErr = err
break
}
}
return firstErr
}
func (c *ClusterClient) execClusterCmds(
func (c *ClusterClient) mapCmdsByNode(cmds []Cmder) (map[*clusterNode][]Cmder, error) {
state := c.state()
cmdsMap := make(map[*clusterNode][]Cmder)
for _, cmd := range cmds {
_, node, err := c.cmdSlotAndNode(state, cmd)
if err != nil {
return nil, err
}
cmdsMap[node] = append(cmdsMap[node], cmd)
}
return cmdsMap, nil
}
func (c *ClusterClient) pipelineProcessCmds(
cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) (map[*clusterNode][]Cmder, error) {
) error {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return failedCmds, err
}
var firstErr error
setFirstErr := func(err error) {
if firstErr == nil {
firstErr = err
}
return err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
for i, cmd := range cmds {
return c.pipelineReadCmds(cn, cmds, failedCmds)
}
func (c *ClusterClient) pipelineReadCmds(
cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) error {
var firstErr error
for _, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 && internal.IsRetryableError(err) {
node, err := c.nodes.Random()
if err != nil {
setFirstErr(err)
continue
if firstErr == nil {
firstErr = err
}
cmd.reset()
failedCmds[node] = append(failedCmds[node], cmds...)
break
err = c.checkMovedErr(cmd, failedCmds)
if err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
moved, ask, addr := internal.IsMovedError(err)
func (c *ClusterClient) checkMovedErr(cmd Cmder, failedCmds map[*clusterNode][]Cmder) error {
moved, ask, addr := internal.IsMovedError(cmd.Err())
if moved {
c.lazyReloadSlots()
node, err := c.nodes.Get(addr)
if err != nil {
setFirstErr(err)
continue
return err
}
cmd.reset()
failedCmds[node] = append(failedCmds[node], cmd)
} else if ask {
}
if ask {
node, err := c.nodes.Get(addr)
if err != nil {
setFirstErr(err)
return err
}
failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd)
}
return nil
}
func (c *ClusterClient) TxPipeline() *Pipeline {
pipe := Pipeline{
exec: c.txPipelineExec,
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
return &pipe
}
func (c *ClusterClient) TxPipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
}
func (c *ClusterClient) txPipelineExec(cmds []Cmder) error {
cmdsMap, err := c.mapCmdsBySlot(cmds)
if err != nil {
return err
}
for slot, cmds := range cmdsMap {
node, err := c.state().slotMasterNode(slot)
if err != nil {
setCmdsErr(cmds, err)
continue
}
cmd.reset()
failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd)
} else {
setFirstErr(err)
cmdsMap := map[*clusterNode][]Cmder{node: cmds}
for i := 0; i <= c.opt.MaxRedirects; i++ {
failedCmds := make(map[*clusterNode][]Cmder)
for node, cmds := range cmdsMap {
cn, _, err := node.Client.conn()
if err != nil {
setCmdsErr(cmds, err)
continue
}
err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds)
node.Client.putConn(cn, err, false)
}
if len(failedCmds) == 0 {
break
}
cmdsMap = failedCmds
}
}
return failedCmds, firstErr
var firstErr error
for _, cmd := range cmds {
if err := cmd.Err(); err != nil {
firstErr = err
break
}
}
return firstErr
}
func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) (map[int][]Cmder, error) {
state := c.state()
cmdsMap := make(map[int][]Cmder)
for _, cmd := range cmds {
slot, _, err := c.cmdSlotAndNode(state, cmd)
if err != nil {
return nil, err
}
cmdsMap[slot] = append(cmdsMap[slot], cmd)
}
return cmdsMap, nil
}
func (c *ClusterClient) txPipelineProcessCmds(
node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) error {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := txPipelineWriteMulti(cn, cmds); err != nil {
setCmdsErr(cmds, err)
failedCmds[node] = cmds
return err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
if err := c.txPipelineReadQueued(cn, cmds, failedCmds); err != nil {
return err
}
_, err := pipelineReadCmds(cn, cmds)
return err
}
func (c *ClusterClient) txPipelineReadQueued(
cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder,
) error {
var firstErr error
// Parse queued replies.
var statusCmd StatusCmd
if err := statusCmd.readReply(cn); err != nil && firstErr == nil {
firstErr = err
}
for _, cmd := range cmds {
err := statusCmd.readReply(cn)
if err == nil {
continue
}
cmd.setErr(err)
if firstErr == nil {
firstErr = err
}
err = c.checkMovedErr(cmd, failedCmds)
if err != nil && firstErr == nil {
firstErr = err
}
}
// Parse number of replies.
line, err := cn.Rd.ReadLine()
if err != nil {
if err == Nil {
err = TxFailedErr
}
return err
}
switch line[0] {
case proto.ErrorReply:
return proto.ParseErrorReply(line)
case proto.ArrayReply:
// ok
default:
err := fmt.Errorf("redis: expected '*', but got line %q", line)
return err
}
return firstErr
}
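For the cluster client, txPipelineExec groups the queued commands by hash slot (mapCmdsBySlot), sends each group as one MULTI/EXEC block to that slot's master, and retries failed groups for up to MaxRedirects attempts, following MOVED/ASK replies via checkMovedErr. A transaction therefore never spans more than one node, so related keys should share a hash tag to land in the same slot. A usage sketch (addresses and key names are illustrative):

package main

import (
	"fmt"

	"gopkg.in/redis.v5"
)

func main() {
	cluster := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001"},
	})

	// The {user:42} hash tag keeps both keys in one slot, so the whole
	// MULTI/EXEC block goes to a single master node.
	_, err := cluster.TxPipelined(func(pipe *redis.Pipeline) error {
		pipe.Set("{user:42}:name", "alice", 0)
		pipe.Incr("{user:42}:visits")
		return nil
	})
	fmt.Println(err)
}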


@ -373,14 +373,14 @@ var _ = Describe("ClusterClient", func() {
Expect(n).To(Equal(int64(100)))
})
Describe("pipeline", func() {
Describe("pipelining", func() {
var pipe *redis.Pipeline
assertPipeline := func() {
It("follows redirects", func() {
slot := hashtag.Slot("A")
Expect(client.SwapSlotNodes(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"}))
pipe := client.Pipeline()
defer pipe.Close()
keys := []string{"A", "B", "C", "D", "E", "F", "G"}
for i, key := range keys {
@ -429,6 +429,31 @@ var _ = Describe("ClusterClient", func() {
Expect(c.Err()).NotTo(HaveOccurred())
Expect(c.Val()).To(Equal("C_value"))
})
}
Describe("Pipeline", func() {
BeforeEach(func() {
pipe = client.Pipeline()
})
AfterEach(func() {
Expect(pipe.Close()).NotTo(HaveOccurred())
})
assertPipeline()
})
Describe("TxPipeline", func() {
BeforeEach(func() {
pipe = client.TxPipeline()
})
AfterEach(func() {
Expect(pipe.Close()).NotTo(HaveOccurred())
})
assertPipeline()
})
})
It("calls fn for every master node", func() {
@ -624,7 +649,7 @@ var _ = Describe("ClusterClient timeout", func() {
return client.ForEachNode(func(client *redis.Client) error {
return client.Ping().Err()
})
}, pause).ShouldNot(HaveOccurred())
}, 2*pause).ShouldNot(HaveOccurred())
})
testTimeout()


@ -36,7 +36,6 @@ type Cmder interface {
readReply(*pool.Conn) error
setErr(error)
reset()
readTimeout() *time.Duration
@ -50,12 +49,6 @@ func setCmdsErr(cmds []Cmder, e error) {
}
}
func resetCmds(cmds []Cmder) {
for _, cmd := range cmds {
cmd.reset()
}
}
func writeCmd(cn *pool.Conn, cmds ...Cmder) error {
cn.Wb.Reset()
for _, cmd := range cmds {
@ -167,11 +160,6 @@ func NewCmd(args ...interface{}) *Cmd {
}
}
func (cmd *Cmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *Cmd) Val() interface{} {
return cmd.val
}
@ -185,16 +173,13 @@ func (cmd *Cmd) String() string {
}
func (cmd *Cmd) readReply(cn *pool.Conn) error {
val, err := cn.Rd.ReadReply(sliceParser)
if err != nil {
cmd.err = err
cmd.val, cmd.err = cn.Rd.ReadReply(sliceParser)
if cmd.err != nil {
return cmd.err
}
if b, ok := val.([]byte); ok {
if b, ok := cmd.val.([]byte); ok {
// Bytes must be copied, because underlying memory is reused.
cmd.val = string(b)
} else {
cmd.val = val
}
return nil
}
@ -212,11 +197,6 @@ func NewSliceCmd(args ...interface{}) *SliceCmd {
return &SliceCmd{baseCmd: cmd}
}
func (cmd *SliceCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *SliceCmd) Val() []interface{} {
return cmd.val
}
@ -230,10 +210,10 @@ func (cmd *SliceCmd) String() string {
}
func (cmd *SliceCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(sliceParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(sliceParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.([]interface{})
return nil
@ -252,11 +232,6 @@ func NewStatusCmd(args ...interface{}) *StatusCmd {
return &StatusCmd{baseCmd: cmd}
}
func (cmd *StatusCmd) reset() {
cmd.val = ""
cmd.err = nil
}
func (cmd *StatusCmd) Val() string {
return cmd.val
}
@ -287,11 +262,6 @@ func NewIntCmd(args ...interface{}) *IntCmd {
return &IntCmd{baseCmd: cmd}
}
func (cmd *IntCmd) reset() {
cmd.val = 0
cmd.err = nil
}
func (cmd *IntCmd) Val() int64 {
return cmd.val
}
@ -326,11 +296,6 @@ func NewDurationCmd(precision time.Duration, args ...interface{}) *DurationCmd {
}
}
func (cmd *DurationCmd) reset() {
cmd.val = 0
cmd.err = nil
}
func (cmd *DurationCmd) Val() time.Duration {
return cmd.val
}
@ -344,10 +309,10 @@ func (cmd *DurationCmd) String() string {
}
func (cmd *DurationCmd) readReply(cn *pool.Conn) error {
n, err := cn.Rd.ReadIntReply()
if err != nil {
cmd.err = err
return err
var n int64
n, cmd.err = cn.Rd.ReadIntReply()
if cmd.err != nil {
return cmd.err
}
cmd.val = time.Duration(n) * cmd.precision
return nil
@ -368,11 +333,6 @@ func NewTimeCmd(args ...interface{}) *TimeCmd {
}
}
func (cmd *TimeCmd) reset() {
cmd.val = time.Time{}
cmd.err = nil
}
func (cmd *TimeCmd) Val() time.Time {
return cmd.val
}
@ -386,10 +346,10 @@ func (cmd *TimeCmd) String() string {
}
func (cmd *TimeCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(timeParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(timeParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.(time.Time)
return nil
@ -408,11 +368,6 @@ func NewBoolCmd(args ...interface{}) *BoolCmd {
return &BoolCmd{baseCmd: cmd}
}
func (cmd *BoolCmd) reset() {
cmd.val = false
cmd.err = nil
}
func (cmd *BoolCmd) Val() bool {
return cmd.val
}
@ -428,27 +383,29 @@ func (cmd *BoolCmd) String() string {
var ok = []byte("OK")
func (cmd *BoolCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadReply(nil)
var v interface{}
v, cmd.err = cn.Rd.ReadReply(nil)
// `SET key value NX` returns nil when key already exists. But
// `SETNX key value` returns bool (0/1). So convert nil to bool.
// TODO: is this okay?
if err == Nil {
if cmd.err == Nil {
cmd.val = false
cmd.err = nil
return nil
}
if err != nil {
cmd.err = err
return err
if cmd.err != nil {
return cmd.err
}
switch vv := v.(type) {
switch v := v.(type) {
case int64:
cmd.val = vv == 1
cmd.val = v == 1
return nil
case []byte:
cmd.val = bytes.Equal(vv, ok)
cmd.val = bytes.Equal(v, ok)
return nil
default:
return fmt.Errorf("got %T, wanted int64 or string", v)
cmd.err = fmt.Errorf("got %T, wanted int64 or string", v)
return cmd.err
}
}
@ -465,11 +422,6 @@ func NewStringCmd(args ...interface{}) *StringCmd {
return &StringCmd{baseCmd: cmd}
}
func (cmd *StringCmd) reset() {
cmd.val = ""
cmd.err = nil
}
func (cmd *StringCmd) Val() string {
return cmd.val
}
@ -515,13 +467,8 @@ func (cmd *StringCmd) String() string {
}
func (cmd *StringCmd) readReply(cn *pool.Conn) error {
b, err := cn.Rd.ReadBytesReply()
if err != nil {
cmd.err = err
return err
}
cmd.val = string(b)
return nil
cmd.val, cmd.err = cn.Rd.ReadStringReply()
return cmd.err
}
//------------------------------------------------------------------------------
@ -537,11 +484,6 @@ func NewFloatCmd(args ...interface{}) *FloatCmd {
return &FloatCmd{baseCmd: cmd}
}
func (cmd *FloatCmd) reset() {
cmd.val = 0
cmd.err = nil
}
func (cmd *FloatCmd) Val() float64 {
return cmd.val
}
@ -572,11 +514,6 @@ func NewStringSliceCmd(args ...interface{}) *StringSliceCmd {
return &StringSliceCmd{baseCmd: cmd}
}
func (cmd *StringSliceCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *StringSliceCmd) Val() []string {
return cmd.val
}
@ -590,10 +527,10 @@ func (cmd *StringSliceCmd) String() string {
}
func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(stringSliceParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(stringSliceParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.([]string)
return nil
@ -612,11 +549,6 @@ func NewBoolSliceCmd(args ...interface{}) *BoolSliceCmd {
return &BoolSliceCmd{baseCmd: cmd}
}
func (cmd *BoolSliceCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *BoolSliceCmd) Val() []bool {
return cmd.val
}
@ -630,10 +562,10 @@ func (cmd *BoolSliceCmd) String() string {
}
func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(boolSliceParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(boolSliceParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.([]bool)
return nil
@ -652,11 +584,6 @@ func NewStringStringMapCmd(args ...interface{}) *StringStringMapCmd {
return &StringStringMapCmd{baseCmd: cmd}
}
func (cmd *StringStringMapCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *StringStringMapCmd) Val() map[string]string {
return cmd.val
}
@ -670,10 +597,10 @@ func (cmd *StringStringMapCmd) String() string {
}
func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(stringStringMapParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(stringStringMapParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.(map[string]string)
return nil
@ -704,16 +631,11 @@ func (cmd *StringIntMapCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *StringIntMapCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(stringIntMapParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(stringIntMapParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.(map[string]int64)
return nil
@ -732,11 +654,6 @@ func NewZSliceCmd(args ...interface{}) *ZSliceCmd {
return &ZSliceCmd{baseCmd: cmd}
}
func (cmd *ZSliceCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *ZSliceCmd) Val() []Z {
return cmd.val
}
@ -750,10 +667,10 @@ func (cmd *ZSliceCmd) String() string {
}
func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(zSliceParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(zSliceParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.([]Z)
return nil
@ -775,12 +692,6 @@ func NewScanCmd(args ...interface{}) *ScanCmd {
}
}
func (cmd *ScanCmd) reset() {
cmd.cursor = 0
cmd.page = nil
cmd.err = nil
}
func (cmd *ScanCmd) Val() (keys []string, cursor uint64) {
return cmd.page, cmd.cursor
}
@ -794,15 +705,9 @@ func (cmd *ScanCmd) String() string {
}
func (cmd *ScanCmd) readReply(cn *pool.Conn) error {
page, cursor, err := cn.Rd.ReadScanReply()
if err != nil {
cmd.err = err
cmd.page, cmd.cursor, cmd.err = cn.Rd.ReadScanReply()
return cmd.err
}
cmd.page = page
cmd.cursor = cursor
return nil
}
//------------------------------------------------------------------------------
@ -840,16 +745,11 @@ func (cmd *ClusterSlotsCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *ClusterSlotsCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *ClusterSlotsCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(clusterSlotsParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(clusterSlotsParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.([]ClusterSlot)
return nil
@ -913,11 +813,6 @@ func NewGeoLocationCmd(q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd {
}
}
func (cmd *GeoLocationCmd) reset() {
cmd.locations = nil
cmd.err = nil
}
func (cmd *GeoLocationCmd) Val() []GeoLocation {
return cmd.locations
}
@ -931,12 +826,12 @@ func (cmd *GeoLocationCmd) String() string {
}
func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error {
reply, err := cn.Rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q))
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q))
if cmd.err != nil {
return cmd.err
}
cmd.locations = reply.([]GeoLocation)
cmd.locations = v.([]GeoLocation)
return nil
}
@ -969,18 +864,13 @@ func (cmd *GeoPosCmd) String() string {
return cmdString(cmd, cmd.positions)
}
func (cmd *GeoPosCmd) reset() {
cmd.positions = nil
cmd.err = nil
}
func (cmd *GeoPosCmd) readReply(cn *pool.Conn) error {
reply, err := cn.Rd.ReadArrayReply(geoPosSliceParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(geoPosSliceParser)
if cmd.err != nil {
return cmd.err
}
cmd.positions = reply.([]*GeoPos)
cmd.positions = v.([]*GeoPos)
return nil
}
@ -1019,16 +909,11 @@ func (cmd *CommandsInfoCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *CommandsInfoCmd) reset() {
cmd.val = nil
cmd.err = nil
}
func (cmd *CommandsInfoCmd) readReply(cn *pool.Conn) error {
v, err := cn.Rd.ReadArrayReply(commandInfoSliceParser)
if err != nil {
cmd.err = err
return err
var v interface{}
v, cmd.err = cn.Rd.ReadArrayReply(commandInfoSliceParser)
if cmd.err != nil {
return cmd.err
}
cmd.val = v.(map[string]*CommandInfo)
return nil


@ -69,3 +69,7 @@ func IsMovedError(err error) (moved bool, ask bool, addr string) {
func IsLoadingError(err error) bool {
return strings.HasPrefix(err.Error(), "LOADING")
}
func IsExecAbortError(err error) bool {
return strings.HasPrefix(err.Error(), "EXECABORT")
}


@ -70,7 +70,7 @@ func (p *Reader) ReadReply(m MultiBulkParse) (interface{}, error) {
switch line[0] {
case ErrorReply:
return nil, parseErrorValue(line)
return nil, ParseErrorReply(line)
case StatusReply:
return parseStatusValue(line)
case IntReply:
@ -94,7 +94,7 @@ func (p *Reader) ReadIntReply() (int64, error) {
}
switch line[0] {
case ErrorReply:
return 0, parseErrorValue(line)
return 0, ParseErrorReply(line)
case IntReply:
return parseIntValue(line)
default:
@ -109,7 +109,7 @@ func (p *Reader) ReadBytesReply() ([]byte, error) {
}
switch line[0] {
case ErrorReply:
return nil, parseErrorValue(line)
return nil, ParseErrorReply(line)
case StringReply:
return p.readBytesValue(line)
case StatusReply:
@ -142,7 +142,7 @@ func (p *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) {
}
switch line[0] {
case ErrorReply:
return nil, parseErrorValue(line)
return nil, ParseErrorReply(line)
case ArrayReply:
n, err := parseArrayLen(line)
if err != nil {
@ -161,7 +161,7 @@ func (p *Reader) ReadArrayLen() (int64, error) {
}
switch line[0] {
case ErrorReply:
return 0, parseErrorValue(line)
return 0, ParseErrorReply(line)
case ArrayReply:
return parseArrayLen(line)
default:
@ -272,7 +272,7 @@ func isNilReply(b []byte) bool {
b[1] == '-' && b[2] == '1'
}
func parseErrorValue(line []byte) error {
func ParseErrorReply(line []byte) error {
return internal.RedisError(string(line[1:]))
}


@ -58,7 +58,6 @@ func (it *ScanIterator) Next() bool {
} else {
it.ScanCmd._args[2] = it.ScanCmd.cursor
}
it.ScanCmd.reset()
it.client.process(it.ScanCmd)
if it.ScanCmd.Err() != nil {
return false

View File

@ -7,6 +7,8 @@ import (
"gopkg.in/redis.v5/internal/pool"
)
type pipelineExecer func([]Cmder) error
// Pipeline implements pipelining as described in
// http://redis.io/topics/pipelining. It's safe for concurrent use
// by multiple goroutines.
@ -14,7 +16,7 @@ type Pipeline struct {
cmdable
statefulCmdable
exec func([]Cmder) error
exec pipelineExecer
mu sync.Mutex
cmds []Cmder
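Giving the executor a named type (pipelineExecer) is what lets Client, ClusterClient and Tx share one Pipeline implementation while injecting different execution strategies (plain pipelining vs MULTI/EXEC). A stripped-down sketch of the pattern, with strings standing in for Cmders and every name except pipelineExecer hypothetical:

package main

import "fmt"

// pipelineExecer mirrors the shape of the type introduced in this commit,
// simplified here to operate on plain strings instead of Cmders.
type pipelineExecer func(cmds []string) error

// pipeline is a toy stand-in for redis.Pipeline: it only queues commands and
// delegates execution to whatever execer it was constructed with.
type pipeline struct {
	exec pipelineExecer
	cmds []string
}

func (p *pipeline) Process(cmd string) { p.cmds = append(p.cmds, cmd) }
func (p *pipeline) Exec() error        { return p.exec(p.cmds) }

func main() {
	plain := &pipeline{exec: func(cmds []string) error {
		fmt.Println("send:", cmds) // one round trip, no transaction
		return nil
	}}
	tx := &pipeline{exec: func(cmds []string) error {
		fmt.Println("send: MULTI", cmds, "EXEC") // wrapped in a transaction
		return nil
	}}
	plain.Process("GET key")
	tx.Process("GET key")
	_ = plain.Exec()
	_ = tx.Exec()
}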


@ -1,17 +1,15 @@
package redis_test
import (
"strconv"
"sync"
"gopkg.in/redis.v5"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Pipeline", func() {
var _ = Describe("pipelining", func() {
var client *redis.Client
var pipe *redis.Pipeline
BeforeEach(func() {
client = redis.NewClient(redisOptions())
@ -22,44 +20,7 @@ var _ = Describe("Pipeline", func() {
Expect(client.Close()).NotTo(HaveOccurred())
})
It("should pipeline", func() {
set := client.Set("key2", "hello2", 0)
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
pipeline := client.Pipeline()
set = pipeline.Set("key1", "hello1", 0)
get := pipeline.Get("key2")
incr := pipeline.Incr("key3")
getNil := pipeline.Get("key4")
cmds, err := pipeline.Exec()
Expect(err).To(Equal(redis.Nil))
Expect(cmds).To(HaveLen(4))
Expect(pipeline.Close()).NotTo(HaveOccurred())
Expect(set.Err()).NotTo(HaveOccurred())
Expect(set.Val()).To(Equal("OK"))
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal("hello2"))
Expect(incr.Err()).NotTo(HaveOccurred())
Expect(incr.Val()).To(Equal(int64(1)))
Expect(getNil.Err()).To(Equal(redis.Nil))
Expect(getNil.Val()).To(Equal(""))
})
It("discards queued commands", func() {
pipeline := client.Pipeline()
pipeline.Get("key")
pipeline.Discard()
_, err := pipeline.Exec()
Expect(err).To(MatchError("redis: pipeline is empty"))
})
It("should support block style", func() {
It("supports block style", func() {
var get *redis.StringCmd
cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error {
get = pipe.Get("foo")
@ -72,98 +33,47 @@ var _ = Describe("Pipeline", func() {
Expect(get.Val()).To(Equal(""))
})
It("should handle vals/err", func() {
pipeline := client.Pipeline()
get := pipeline.Get("key")
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal(""))
Expect(pipeline.Close()).NotTo(HaveOccurred())
})
assertPipeline := func() {
It("returns an error when there are no commands", func() {
pipeline := client.Pipeline()
_, err := pipeline.Exec()
_, err := pipe.Exec()
Expect(err).To(MatchError("redis: pipeline is empty"))
})
It("should increment correctly", func() {
const N = 20000
key := "TestPipelineIncr"
pipeline := client.Pipeline()
for i := 0; i < N; i++ {
pipeline.Incr(key)
}
It("discards queued commands", func() {
pipe.Get("key")
pipe.Discard()
_, err := pipe.Exec()
Expect(err).To(MatchError("redis: pipeline is empty"))
})
cmds, err := pipeline.Exec()
It("handles val/err", func() {
err := client.Set("key", "value", 0).Err()
Expect(err).NotTo(HaveOccurred())
Expect(pipeline.Close()).NotTo(HaveOccurred())
Expect(len(cmds)).To(Equal(20000))
for _, cmd := range cmds {
Expect(cmd.Err()).NotTo(HaveOccurred())
}
get := client.Get(key)
Expect(get.Err()).NotTo(HaveOccurred())
Expect(get.Val()).To(Equal(strconv.Itoa(N)))
})
It("should PipelineEcho", func() {
const N = 1000
wg := &sync.WaitGroup{}
wg.Add(N)
for i := 0; i < N; i++ {
go func(i int) {
defer GinkgoRecover()
defer wg.Done()
pipeline := client.Pipeline()
msg1 := "echo" + strconv.Itoa(i)
msg2 := "echo" + strconv.Itoa(i+1)
echo1 := pipeline.Echo(msg1)
echo2 := pipeline.Echo(msg2)
cmds, err := pipeline.Exec()
get := pipe.Get("key")
cmds, err := pipe.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(2))
Expect(cmds).To(HaveLen(1))
Expect(echo1.Err()).NotTo(HaveOccurred())
Expect(echo1.Val()).To(Equal(msg1))
Expect(echo2.Err()).NotTo(HaveOccurred())
Expect(echo2.Val()).To(Equal(msg2))
Expect(pipeline.Close()).NotTo(HaveOccurred())
}(i)
}
wg.Wait()
})
It("should be thread-safe", func() {
const N = 1000
pipeline := client.Pipeline()
var wg sync.WaitGroup
wg.Add(N)
for i := 0; i < N; i++ {
go func() {
defer GinkgoRecover()
pipeline.Ping()
wg.Done()
}()
}
wg.Wait()
cmds, err := pipeline.Exec()
val, err := get.Result()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
Expect(val).To(Equal("value"))
})
}
Expect(pipeline.Close()).NotTo(HaveOccurred())
Describe("Pipeline", func() {
BeforeEach(func() {
pipe = client.Pipeline()
})
assertPipeline()
})
Describe("TxPipeline", func() {
BeforeEach(func() {
pipe = client.TxPipeline()
})
assertPipeline()
})
})


@ -245,4 +245,35 @@ var _ = Describe("races", func() {
Expect(val).To(Equal(int64(C * N)))
})
It("should Pipeline", func() {
perform(C, func(id int) {
pipe := client.Pipeline()
for i := 0; i < N; i++ {
pipe.Echo(fmt.Sprint(i))
}
cmds, err := pipe.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
for i := 0; i < N; i++ {
Expect(cmds[i].(*redis.StringCmd).Val()).To(Equal(fmt.Sprint(i)))
}
})
})
It("should Pipeline", func() {
pipe := client.Pipeline()
perform(N, func(id int) {
pipe.Incr("key")
})
cmds, err := pipe.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(N))
n, err := client.Get("key").Int64()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(N)))
})
})

redis.go

@ -7,6 +7,7 @@ import (
"gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool"
"gopkg.in/redis.v5/internal/proto"
)
// Redis nil reply, e.g. when key does not exist.
@ -96,10 +97,6 @@ func (c *baseClient) WrapProcess(fn func(oldProcess func(cmd Cmder) error) func(
func (c *baseClient) defaultProcess(cmd Cmder) error {
for i := 0; i <= c.opt.MaxRetries; i++ {
if i > 0 {
cmd.reset()
}
cn, _, err := c.conn()
if err != nil {
cmd.setErr(err)
@ -162,6 +159,129 @@ func (c *baseClient) getAddr() string {
return c.opt.Addr
}
type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error)
func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer {
return func(cmds []Cmder) error {
var firstErr error
for i := 0; i <= c.opt.MaxRetries; i++ {
cn, _, err := c.conn()
if err != nil {
setCmdsErr(cmds, err)
return err
}
canRetry, err := p(cn, cmds)
c.putConn(cn, err, false)
if err == nil {
return nil
}
if firstErr == nil {
firstErr = err
}
if !canRetry || !internal.IsRetryableError(err) {
break
}
}
return firstErr
}
}
func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return true, err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
return pipelineReadCmds(cn, cmds)
}
func pipelineReadCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 {
retry = true
}
if firstErr == nil {
firstErr = err
}
}
return false, firstErr
}
func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := txPipelineWriteMulti(cn, cmds); err != nil {
setCmdsErr(cmds, err)
return true, err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
if err := c.txPipelineReadQueued(cn, cmds); err != nil {
return false, err
}
_, err := pipelineReadCmds(cn, cmds)
return false, err
}
func txPipelineWriteMulti(cn *pool.Conn, cmds []Cmder) error {
multiExec := make([]Cmder, 0, len(cmds)+2)
multiExec = append(multiExec, NewStatusCmd("MULTI"))
multiExec = append(multiExec, cmds...)
multiExec = append(multiExec, NewSliceCmd("EXEC"))
return writeCmd(cn, multiExec...)
}
func (c *baseClient) txPipelineReadQueued(cn *pool.Conn, cmds []Cmder) error {
var firstErr error
// Parse queued replies.
var statusCmd StatusCmd
if err := statusCmd.readReply(cn); err != nil && firstErr == nil {
firstErr = err
}
for _, cmd := range cmds {
err := statusCmd.readReply(cn)
if err != nil {
cmd.setErr(err)
if firstErr == nil {
firstErr = err
}
}
}
// Parse number of replies.
line, err := cn.Rd.ReadLine()
if err != nil {
if err == Nil {
err = TxFailedErr
}
return err
}
switch line[0] {
case proto.ErrorReply:
return proto.ParseErrorReply(line)
case proto.ArrayReply:
// ok
default:
err := fmt.Errorf("redis: expected '*', but got line %q", line)
return err
}
return nil
}
//------------------------------------------------------------------------------
// Client is a Redis client representing a pool of zero or more
@ -202,70 +322,30 @@ func (c *Client) PoolStats() *PoolStats {
}
}
func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
}
func (c *Client) Pipeline() *Pipeline {
pipe := Pipeline{
exec: c.pipelineExec,
exec: c.pipelineExecer(c.pipelineProcessCmds),
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
return &pipe
}
func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
func (c *Client) TxPipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.TxPipeline().pipelined(fn)
}
func (c *Client) pipelineExec(cmds []Cmder) error {
var firstErr error
for i := 0; i <= c.opt.MaxRetries; i++ {
if i > 0 {
resetCmds(cmds)
func (c *Client) TxPipeline() *Pipeline {
pipe := Pipeline{
exec: c.pipelineExecer(c.txPipelineProcessCmds),
}
cn, _, err := c.conn()
if err != nil {
setCmdsErr(cmds, err)
return err
}
retry, err := c.execCmds(cn, cmds)
c.putConn(cn, err, false)
if err == nil {
return nil
}
if firstErr == nil {
firstErr = err
}
if !retry {
break
}
}
return firstErr
}
func (c *Client) execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
return true, err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
for i, cmd := range cmds {
err := cmd.readReply(cn)
if err == nil {
continue
}
if i == 0 && internal.IsNetworkError(err) {
return true, err
}
if firstErr == nil {
firstErr = err
}
}
return false, firstErr
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
return &pipe
}
func (c *Client) pubSub() *PubSub {
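On the standalone client the transactional path is txPipelineProcessCmds: txPipelineWriteMulti brackets the queued commands with MULTI and EXEC in a single buffered write, txPipelineReadQueued consumes the +OK/+QUEUED acknowledgements plus the EXEC array header, and pipelineReadCmds then parses the real replies. A toy sketch of the write-side framing, using strings in place of Cmders:

package main

import "fmt"

// wrapInMulti mirrors what txPipelineWriteMulti does with Cmders: the queued
// commands are bracketed with MULTI and EXEC and sent as one write.
func wrapInMulti(cmds []string) []string {
	framed := make([]string, 0, len(cmds)+2)
	framed = append(framed, "MULTI")
	framed = append(framed, cmds...)
	framed = append(framed, "EXEC")
	return framed
}

func main() {
	fmt.Println(wrapInMulti([]string{"INCR counter", "GET counter"}))
	// Prints: [MULTI INCR counter GET counter EXEC]
	//
	// The server answers +OK (for MULTI), +QUEUED for each command, then a
	// *2 array carrying the two real replies. txPipelineReadQueued reads
	// everything up to and including the *2 header; a nil reply in its
	// place is turned into TxFailedErr.
}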


@ -381,10 +381,6 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
var failedCmdsMap map[string][]Cmder
for name, cmds := range cmdsMap {
if i > 0 {
resetCmds(cmds)
}
shard, err := c.shardByName(name)
if err != nil {
setCmdsErr(cmds, err)
@ -403,7 +399,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
continue
}
retry, err := shard.Client.execCmds(cn, cmds)
canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds)
shard.Client.putConn(cn, err, false)
if err == nil {
continue
@ -411,7 +407,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) {
if firstErr == nil {
firstErr = err
}
if retry {
if canRetry && internal.IsRetryableError(err) {
if failedCmdsMap == nil {
failedCmdsMap = make(map[string][]Cmder)
}

tx.go

@ -1,11 +1,8 @@
package redis
import (
"fmt"
"gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool"
"gopkg.in/redis.v5/internal/proto"
)
// Redis transaction failed.
@ -19,8 +16,6 @@ type Tx struct {
cmdable
statefulCmdable
baseClient
closed bool
}
var _ Cmdable = (*Tx)(nil)
@ -41,26 +36,20 @@ func (c *Client) Watch(fn func(*Tx) error, keys ...string) error {
tx := c.newTx()
if len(keys) > 0 {
if err := tx.Watch(keys...).Err(); err != nil {
_ = tx.close()
_ = tx.Close()
return err
}
}
firstErr := fn(tx)
if err := tx.close(); err != nil && firstErr == nil {
if err := tx.Close(); err != nil && firstErr == nil {
firstErr = err
}
return firstErr
}
// close closes the transaction, releasing any open resources.
func (c *Tx) close() error {
if c.closed {
return nil
}
c.closed = true
if err := c.Unwatch().Err(); err != nil {
internal.Logf("Unwatch failed: %s", err)
}
func (c *Tx) Close() error {
_ = c.Unwatch().Err()
return c.baseClient.Close()
}
@ -91,7 +80,7 @@ func (c *Tx) Unwatch(keys ...string) *StatusCmd {
func (c *Tx) Pipeline() *Pipeline {
pipe := Pipeline{
exec: c.exec,
exec: c.pipelineExecer(c.txPipelineProcessCmds),
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
@ -110,76 +99,3 @@ func (c *Tx) Pipeline() *Pipeline {
func (c *Tx) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
}
func (c *Tx) exec(cmds []Cmder) error {
if c.closed {
return pool.ErrClosed
}
cn, _, err := c.conn()
if err != nil {
setCmdsErr(cmds, err)
return err
}
multiExec := make([]Cmder, 0, len(cmds)+2)
multiExec = append(multiExec, NewStatusCmd("MULTI"))
multiExec = append(multiExec, cmds...)
multiExec = append(multiExec, NewSliceCmd("EXEC"))
err = c.execCmds(cn, multiExec)
c.putConn(cn, err, false)
return err
}
func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error {
cn.SetWriteTimeout(c.opt.WriteTimeout)
err := writeCmd(cn, cmds...)
if err != nil {
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
// Set read timeout for all commands.
cn.SetReadTimeout(c.opt.ReadTimeout)
// Omit last command (EXEC).
cmdsLen := len(cmds) - 1
// Parse queued replies.
statusCmd := cmds[0]
for i := 0; i < cmdsLen; i++ {
if err := statusCmd.readReply(cn); err != nil {
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
}
// Parse number of replies.
line, err := cn.Rd.ReadLine()
if err != nil {
if err == Nil {
err = TxFailedErr
}
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
if line[0] != proto.ArrayReply {
err := fmt.Errorf("redis: expected '*', but got line %q", line)
setCmdsErr(cmds[1:len(cmds)-1], err)
return err
}
var firstErr error
// Parse replies.
// Loop starts from 1 to omit MULTI cmd.
for i := 1; i < cmdsLen; i++ {
cmd := cmds[i]
if err := cmd.readReply(cn); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
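With Tx.Pipeline now routed through the shared txPipelineProcessCmds, the usual optimistic-locking loop needs no Tx-specific exec code. A sketch of the WATCH-based increment, assuming a local Redis server:

package main

import (
	"fmt"

	"gopkg.in/redis.v5"
)

func main() {
	client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

	err := client.Watch(func(tx *redis.Tx) error {
		// Read the current value under WATCH.
		n, err := tx.Get("counter").Int64()
		if err != nil && err != redis.Nil {
			return err
		}
		// Write it back inside a MULTI/EXEC block.
		_, err = tx.Pipelined(func(pipe *redis.Pipeline) error {
			pipe.Set("counter", n+1, 0)
			return nil
		})
		return err
	}, "counter")
	fmt.Println(err)
}

If another client modifies the watched key between WATCH and EXEC, Exec returns TxFailedErr and the whole block can simply be retried.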