diff --git a/cluster.go b/cluster.go index 29b0cfb..2922238 100644 --- a/cluster.go +++ b/cluster.go @@ -316,6 +316,92 @@ func (c *ClusterClient) reaper(frequency time.Duration) { } } +func (c *ClusterClient) Pipeline() *Pipeline { + pipe := &Pipeline{ + exec: c.pipelineExec, + } + pipe.commandable.process = pipe.process + return pipe +} + +func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { + return c.Pipeline().pipelined(fn) +} + +func (c *ClusterClient) pipelineExec(cmds []Cmder) error { + var retErr error + + cmdsMap := make(map[string][]Cmder) + for _, cmd := range cmds { + slot := hashtag.Slot(cmd.clusterKey()) + addr := c.slotMasterAddr(slot) + cmdsMap[addr] = append(cmdsMap[addr], cmd) + } + + for attempt := 0; attempt <= c.opt.getMaxRedirects(); attempt++ { + failedCmds := make(map[string][]Cmder) + + for addr, cmds := range cmdsMap { + client, err := c.getClient(addr) + if err != nil { + setCmdsErr(cmds, err) + retErr = err + continue + } + + cn, err := client.conn() + if err != nil { + setCmdsErr(cmds, err) + retErr = err + continue + } + + failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) + if err != nil { + retErr = err + } + client.putConn(cn, err, false) + } + + cmdsMap = failedCmds + } + + return retErr +} + +func (c *ClusterClient) execClusterCmds( + cn *pool.Conn, cmds []Cmder, failedCmds map[string][]Cmder, +) (map[string][]Cmder, error) { + if err := writeCmd(cn, cmds...); err != nil { + setCmdsErr(cmds, err) + return failedCmds, err + } + + var firstCmdErr error + for i, cmd := range cmds { + err := cmd.readReply(cn) + if err == nil { + continue + } + if isNetworkError(err) { + cmd.reset() + failedCmds[""] = append(failedCmds[""], cmds[i:]...) + break + } else if moved, ask, addr := isMovedError(err); moved { + c.lazyReloadSlots() + cmd.reset() + failedCmds[addr] = append(failedCmds[addr], cmd) + } else if ask { + cmd.reset() + failedCmds[addr] = append(failedCmds[addr], NewCmd("ASKING"), cmd) + } else if firstCmdErr == nil { + firstCmdErr = err + } + } + + return failedCmds, firstCmdErr +} + //------------------------------------------------------------------------------ // ClusterOptions are used to configure a cluster client and should be diff --git a/cluster_pipeline.go b/cluster_pipeline.go deleted file mode 100644 index 883c770..0000000 --- a/cluster_pipeline.go +++ /dev/null @@ -1,140 +0,0 @@ -package redis - -import ( - "gopkg.in/redis.v3/internal/hashtag" - "gopkg.in/redis.v3/internal/pool" -) - -// ClusterPipeline is not thread-safe. -type ClusterPipeline struct { - commandable - - cluster *ClusterClient - - cmds []Cmder - closed bool -} - -// Pipeline creates a new pipeline which is able to execute commands -// against multiple shards. It's NOT safe for concurrent use by -// multiple goroutines. -func (c *ClusterClient) Pipeline() *ClusterPipeline { - pipe := &ClusterPipeline{ - cluster: c, - cmds: make([]Cmder, 0, 10), - } - pipe.commandable.process = pipe.process - return pipe -} - -func (c *ClusterClient) Pipelined(fn func(*ClusterPipeline) error) ([]Cmder, error) { - pipe := c.Pipeline() - if err := fn(pipe); err != nil { - return nil, err - } - cmds, err := pipe.Exec() - _ = pipe.Close() - return cmds, err -} - -func (pipe *ClusterPipeline) process(cmd Cmder) { - pipe.cmds = append(pipe.cmds, cmd) -} - -// Discard resets the pipeline and discards queued commands. -func (pipe *ClusterPipeline) Discard() error { - if pipe.closed { - return pool.ErrClosed - } - pipe.cmds = pipe.cmds[:0] - return nil -} - -func (pipe *ClusterPipeline) Exec() (cmds []Cmder, retErr error) { - if pipe.closed { - return nil, pool.ErrClosed - } - if len(pipe.cmds) == 0 { - return []Cmder{}, nil - } - - cmds = pipe.cmds - pipe.cmds = make([]Cmder, 0, 10) - - cmdsMap := make(map[string][]Cmder) - for _, cmd := range cmds { - slot := hashtag.Slot(cmd.clusterKey()) - addr := pipe.cluster.slotMasterAddr(slot) - cmdsMap[addr] = append(cmdsMap[addr], cmd) - } - - for attempt := 0; attempt <= pipe.cluster.opt.getMaxRedirects(); attempt++ { - failedCmds := make(map[string][]Cmder) - - for addr, cmds := range cmdsMap { - client, err := pipe.cluster.getClient(addr) - if err != nil { - setCmdsErr(cmds, err) - retErr = err - continue - } - - cn, err := client.conn() - if err != nil { - setCmdsErr(cmds, err) - retErr = err - continue - } - - failedCmds, err = pipe.execClusterCmds(cn, cmds, failedCmds) - if err != nil { - retErr = err - } - client.putConn(cn, err, false) - } - - cmdsMap = failedCmds - } - - return cmds, retErr -} - -// Close closes the pipeline, releasing any open resources. -func (pipe *ClusterPipeline) Close() error { - pipe.Discard() - pipe.closed = true - return nil -} - -func (pipe *ClusterPipeline) execClusterCmds( - cn *pool.Conn, cmds []Cmder, failedCmds map[string][]Cmder, -) (map[string][]Cmder, error) { - if err := writeCmd(cn, cmds...); err != nil { - setCmdsErr(cmds, err) - return failedCmds, err - } - - var firstCmdErr error - for i, cmd := range cmds { - err := cmd.readReply(cn) - if err == nil { - continue - } - if isNetworkError(err) { - cmd.reset() - failedCmds[""] = append(failedCmds[""], cmds[i:]...) - break - } else if moved, ask, addr := isMovedError(err); moved { - pipe.cluster.lazyReloadSlots() - cmd.reset() - failedCmds[addr] = append(failedCmds[addr], cmd) - } else if ask { - cmd.reset() - failedCmds[addr] = append(failedCmds[addr], NewCmd("ASKING"), cmd) - } else if firstCmdErr == nil { - firstCmdErr = err - } - } - - return failedCmds, firstCmdErr -} diff --git a/cluster_test.go b/cluster_test.go index ff091a2..e7aa9b9 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -449,7 +449,7 @@ var _ = Describe("Cluster", func() { Expect(client.Set("C", "C_value", 0).Err()).NotTo(HaveOccurred()) var a, b, c *redis.StringCmd - cmds, err := client.Pipelined(func(pipe *redis.ClusterPipeline) error { + cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error { a = pipe.Get("A") b = pipe.Get("B") c = pipe.Get("C") diff --git a/pipeline.go b/pipeline.go index 888d8c4..95d389f 100644 --- a/pipeline.go +++ b/pipeline.go @@ -13,7 +13,7 @@ import ( type Pipeline struct { commandable - client baseClient + exec func([]Cmder) error mu sync.Mutex // protects cmds cmds []Cmder @@ -21,25 +21,6 @@ type Pipeline struct { closed int32 } -func (c *Client) Pipeline() *Pipeline { - pipe := &Pipeline{ - client: c.baseClient, - cmds: make([]Cmder, 0, 10), - } - pipe.commandable.process = pipe.process - return pipe -} - -func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { - pipe := c.Pipeline() - if err := fn(pipe); err != nil { - return nil, err - } - cmds, err := pipe.Exec() - _ = pipe.Close() - return cmds, err -} - func (pipe *Pipeline) process(cmd Cmder) { pipe.mu.Lock() pipe.cmds = append(pipe.cmds, cmd) @@ -73,7 +54,7 @@ func (pipe *Pipeline) Discard() error { // // Exec always returns list of commands and error of the first failed // command if any. -func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { +func (pipe *Pipeline) Exec() ([]Cmder, error) { if pipe.isClosed() { return nil, pool.ErrClosed } @@ -85,31 +66,19 @@ func (pipe *Pipeline) Exec() (cmds []Cmder, retErr error) { return pipe.cmds, nil } - cmds = pipe.cmds - pipe.cmds = make([]Cmder, 0, 10) + cmds := pipe.cmds + pipe.cmds = nil - failedCmds := cmds - for i := 0; i <= pipe.client.opt.MaxRetries; i++ { - cn, err := pipe.client.conn() - if err != nil { - setCmdsErr(failedCmds, err) - return cmds, err - } + return cmds, pipe.exec(cmds) +} - if i > 0 { - resetCmds(failedCmds) - } - failedCmds, err = execCmds(cn, failedCmds) - pipe.client.putConn(cn, err, false) - if err != nil && retErr == nil { - retErr = err - } - if len(failedCmds) == 0 { - break - } +func (pipe *Pipeline) pipelined(fn func(*Pipeline) error) ([]Cmder, error) { + if err := fn(pipe); err != nil { + return nil, err } - - return cmds, retErr + cmds, err := pipe.Exec() + _ = pipe.Close() + return cmds, err } func execCmds(cn *pool.Conn, cmds []Cmder) ([]Cmder, error) { diff --git a/redis.go b/redis.go index aed713c..ee4e69a 100644 --- a/redis.go +++ b/redis.go @@ -174,3 +174,40 @@ func (c *Client) PoolStats() *PoolStats { FreeConns: s.FreeConns, } } + +func (c *Client) Pipeline() *Pipeline { + pipe := &Pipeline{ + exec: c.pipelineExec, + } + pipe.commandable.process = pipe.process + return pipe +} + +func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { + return c.Pipeline().pipelined(fn) +} + +func (c *Client) pipelineExec(cmds []Cmder) error { + var retErr error + failedCmds := cmds + for i := 0; i <= c.opt.MaxRetries; i++ { + cn, err := c.conn() + if err != nil { + setCmdsErr(failedCmds, err) + return err + } + + if i > 0 { + resetCmds(failedCmds) + } + failedCmds, err = execCmds(cn, failedCmds) + c.putConn(cn, err, false) + if err != nil && retErr == nil { + retErr = err + } + if len(failedCmds) == 0 { + break + } + } + return retErr +} diff --git a/ring.go b/ring.go index fd74b3b..02be83c 100644 --- a/ring.go +++ b/ring.go @@ -241,66 +241,24 @@ func (ring *Ring) Close() (retErr error) { return retErr } -// RingPipeline creates a new pipeline which is able to execute commands -// against multiple shards. It's NOT safe for concurrent use by -// multiple goroutines. -type RingPipeline struct { - commandable - - ring *Ring - - cmds []Cmder - closed bool -} - -func (ring *Ring) Pipeline() *RingPipeline { - pipe := &RingPipeline{ - ring: ring, - cmds: make([]Cmder, 0, 10), +func (ring *Ring) Pipeline() *Pipeline { + pipe := &Pipeline{ + exec: ring.pipelineExec, } pipe.commandable.process = pipe.process return pipe } -func (ring *Ring) Pipelined(fn func(*RingPipeline) error) ([]Cmder, error) { - pipe := ring.Pipeline() - if err := fn(pipe); err != nil { - return nil, err - } - cmds, err := pipe.Exec() - pipe.Close() - return cmds, err +func (ring *Ring) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { + return ring.Pipeline().pipelined(fn) } -func (pipe *RingPipeline) process(cmd Cmder) { - pipe.cmds = append(pipe.cmds, cmd) -} - -// Discard resets the pipeline and discards queued commands. -func (pipe *RingPipeline) Discard() error { - if pipe.closed { - return pool.ErrClosed - } - pipe.cmds = pipe.cmds[:0] - return nil -} - -// Exec always returns list of commands and error of the first failed -// command if any. -func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { - if pipe.closed { - return nil, pool.ErrClosed - } - if len(pipe.cmds) == 0 { - return pipe.cmds, nil - } - - cmds = pipe.cmds - pipe.cmds = make([]Cmder, 0, 10) +func (ring *Ring) pipelineExec(cmds []Cmder) error { + var retErr error cmdsMap := make(map[string][]Cmder) for _, cmd := range cmds { - name := pipe.ring.hash.Get(hashtag.Key(cmd.clusterKey())) + name := ring.hash.Get(hashtag.Key(cmd.clusterKey())) if name == "" { cmd.setErr(errRingShardsDown) if retErr == nil { @@ -311,11 +269,11 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { cmdsMap[name] = append(cmdsMap[name], cmd) } - for i := 0; i <= pipe.ring.opt.MaxRetries; i++ { + for i := 0; i <= ring.opt.MaxRetries; i++ { failedCmdsMap := make(map[string][]Cmder) for name, cmds := range cmdsMap { - client := pipe.ring.shards[name].Client + client := ring.shards[name].Client cn, err := client.conn() if err != nil { setCmdsErr(cmds, err) @@ -344,12 +302,5 @@ func (pipe *RingPipeline) Exec() (cmds []Cmder, retErr error) { cmdsMap = failedCmdsMap } - return cmds, retErr -} - -// Close closes the pipeline, releasing any open resources. -func (pipe *RingPipeline) Close() error { - pipe.Discard() - pipe.closed = true - return nil + return retErr } diff --git a/ring_test.go b/ring_test.go index 7d5cd91..06e6c30 100644 --- a/ring_test.go +++ b/ring_test.go @@ -96,7 +96,7 @@ var _ = Describe("Redis ring", func() { Describe("pipelining", func() { It("returns an error when all shards are down", func() { ring := redis.NewRing(&redis.RingOptions{}) - _, err := ring.Pipelined(func(pipe *redis.RingPipeline) error { + _, err := ring.Pipelined(func(pipe *redis.Pipeline) error { pipe.Ping() return nil }) @@ -133,7 +133,7 @@ var _ = Describe("Redis ring", func() { keys = append(keys, string(key)) } - _, err := ring.Pipelined(func(pipe *redis.RingPipeline) error { + _, err := ring.Pipelined(func(pipe *redis.Pipeline) error { for _, key := range keys { pipe.Set(key, "value", 0).Err() } @@ -149,7 +149,7 @@ var _ = Describe("Redis ring", func() { }) It("supports hash tags", func() { - _, err := ring.Pipelined(func(pipe *redis.RingPipeline) error { + _, err := ring.Pipelined(func(pipe *redis.Pipeline) error { for i := 0; i < 100; i++ { pipe.Set(fmt.Sprintf("key%d{tag}", i), "value", 0).Err() }