diff --git a/cluster.go b/cluster.go index b0c4659b..1ae6082b 100644 --- a/cluster.go +++ b/cluster.go @@ -13,8 +13,8 @@ import ( type clusterNode struct { Addr string - Latency int Client *Client + Latency time.Duration } // ClusterClient is a Redis Cluster client representing a pool of zero @@ -73,8 +73,8 @@ func (c *ClusterClient) cmdInfo(name string) *CommandInfo { } func (c *ClusterClient) getNodes() map[string]*clusterNode { - c.mu.RLock() var nodes map[string]*clusterNode + c.mu.RLock() if !c.closed { nodes = make(map[string]*clusterNode, len(c.nodes)) for addr, node := range c.nodes { @@ -95,7 +95,7 @@ func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { // PoolStats returns accumulated connection pool stats. func (c *ClusterClient) PoolStats() *PoolStats { - acc := PoolStats{} + var acc PoolStats for _, node := range c.getNodes() { s := node.Client.connPool.Stats() acc.Requests += s.Requests @@ -214,7 +214,6 @@ func (c *ClusterClient) slotSlaveNode(slot int) (*clusterNode, error) { n := rand.Intn(len(nodes)-1) + 1 return nodes[n], nil } - } func (c *ClusterClient) slotClosestNode(slot int) (*clusterNode, error) { @@ -261,19 +260,19 @@ func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) { return slot, node, err } -func (c *ClusterClient) Process(cmd Cmder) { - var ask bool +func (c *ClusterClient) Process(cmd Cmder) error { slot, node, err := c.cmdSlotAndNode(cmd) + if err != nil { + cmd.setErr(err) + return err + } + + var ask bool for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { if attempt > 0 { cmd.reset() } - if err != nil { - cmd.setErr(err) - return - } - if ask { pipe := node.Client.Pipeline() pipe.Process(NewCmd("ASKING")) @@ -288,7 +287,7 @@ func (c *ClusterClient) Process(cmd Cmder) { // If there is no (real) error, we are done! err := cmd.Err() if err == nil { - return + return nil } // On network errors try random node. @@ -307,11 +306,58 @@ func (c *ClusterClient) Process(cmd Cmder) { } node, err = c.nodeByAddr(addr) + if err != nil { + cmd.setErr(err) + return err + } continue } break } + + return cmd.Err() +} + +// ForEachMaster concurrently calls the fn on each master node in the cluster. +// It returns the first error if any. +func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { + c.mu.RLock() + slots := c.slots + c.mu.RUnlock() + + var retErr error + var mu sync.Mutex + + var wg sync.WaitGroup + visited := make(map[*clusterNode]struct{}) + for _, nodes := range slots { + if len(nodes) == 0 { + continue + } + + master := nodes[0] + if _, ok := visited[master]; ok { + continue + } + visited[master] = struct{}{} + + wg.Add(1) + go func(node *clusterNode) { + err := fn(node.Client) + if err != nil { + mu.Lock() + if retErr == nil { + retErr = err + } + mu.Unlock() + } + wg.Done() + }(master) + } + wg.Wait() + + return retErr } // closeClients closes all clients and returns the first error if there are any. @@ -327,9 +373,6 @@ func (c *ClusterClient) closeClients() error { func (c *ClusterClient) setSlots(cs []ClusterSlot) { slots := make([][]*clusterNode, hashtag.SlotNumber) - for i := 0; i < hashtag.SlotNumber; i++ { - slots[i] = nil - } for _, s := range cs { var nodes []*clusterNode for _, n := range s.Nodes { @@ -351,17 +394,11 @@ func (c *ClusterClient) setSlots(cs []ClusterSlot) { c.mu.Unlock() } -func (c *ClusterClient) setNodesLatency() { - nodes := c.getNodes() - for _, node := range nodes { - var latency int - for i := 0; i < 10; i++ { - t1 := time.Now() - node.Client.Ping() - latency += int(time.Since(t1) / time.Millisecond) - } - node.Latency = latency +func (c *ClusterClient) lazyReloadSlots() { + if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) { + return } + go c.reloadSlots() } func (c *ClusterClient) reloadSlots() { @@ -384,11 +421,17 @@ func (c *ClusterClient) reloadSlots() { } } -func (c *ClusterClient) lazyReloadSlots() { - if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) { - return +func (c *ClusterClient) setNodesLatency() { + const n = 10 + for _, node := range c.getNodes() { + var latency time.Duration + for i := 0; i < n; i++ { + t1 := time.Now() + node.Client.Ping() + latency += time.Since(t1) + } + node.Latency = latency / n } - go c.reloadSlots() } // reaper closes idle connections to the cluster. @@ -435,7 +478,7 @@ func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { func (c *ClusterClient) pipelineExec(cmds []Cmder) error { var retErr error - returnError := func(err error) { + setRetErr := func(err error) { if retErr == nil { retErr = err } @@ -446,7 +489,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { _, node, err := c.cmdSlotAndNode(cmd) if err != nil { cmd.setErr(err) - returnError(err) + setRetErr(err) continue } cmdsMap[node] = append(cmdsMap[node], cmd) @@ -461,21 +504,21 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { node, err = c.randomNode() if err != nil { setCmdsErr(cmds, err) - returnError(err) + setRetErr(err) continue } } cn, err := node.Client.conn() if err != nil { - setCmdsErr(cmds, retErr) - returnError(err) + setCmdsErr(cmds, err) + setRetErr(err) continue } failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) if err != nil { - returnError(err) + setRetErr(err) } node.Client.putConn(cn, err, false) } @@ -495,7 +538,7 @@ func (c *ClusterClient) execClusterCmds( } var retErr error - returnError := func(err error) { + setRetErr := func(err error) { if retErr == nil { retErr = err } @@ -515,7 +558,7 @@ func (c *ClusterClient) execClusterCmds( cmd.reset() node, err := c.nodeByAddr(addr) if err != nil { - returnError(err) + setRetErr(err) continue } failedCmds[node] = append(failedCmds[node], cmd) @@ -523,12 +566,12 @@ func (c *ClusterClient) execClusterCmds( cmd.reset() node, err := c.nodeByAddr(addr) if err != nil { - returnError(err) + setRetErr(err) continue } failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd) } else { - returnError(err) + setRetErr(err) } } diff --git a/cluster_test.go b/cluster_test.go index 446e20df..6889d4e0 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -336,11 +336,11 @@ var _ = Describe("ClusterClient", func() { Expect(cnt).To(Equal(int64(1))) }) - It("should return pool stats", func() { + It("returns pool stats", func() { Expect(client.PoolStats()).To(BeAssignableToTypeOf(&redis.PoolStats{})) }) - It("should follow redirects", func() { + It("follows redirects", func() { Expect(client.Set("A", "VALUE", 0).Err()).NotTo(HaveOccurred()) slot := hashtag.Slot("A") @@ -351,7 +351,7 @@ var _ = Describe("ClusterClient", func() { Expect(val).To(Equal("VALUE")) }) - It("should return error when there are no attempts left", func() { + It("returns an error when there are no attempts left", func() { client := cluster.clusterClient(&redis.ClusterOptions{ MaxRedirects: -1, }) @@ -366,7 +366,7 @@ var _ = Describe("ClusterClient", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("should Watch", func() { + It("supports Watch", func() { var incr func(string) error // Transactionally increments key using GET and SET commands. @@ -461,17 +461,35 @@ var _ = Describe("ClusterClient", func() { Expect(c.Err()).NotTo(HaveOccurred()) Expect(c.Val()).To(Equal("C_value")) }) + + It("calls fn for every master node", func() { + for i := 0; i < 10; i++ { + Expect(client.Set(strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred()) + } + + err := client.ForEachMaster(func(master *redis.Client) error { + return master.FlushDb().Err() + }) + Expect(err).NotTo(HaveOccurred()) + + for _, client := range cluster.masters() { + keys, err := client.Keys("*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(keys).To(HaveLen(0)) + } + }) } Describe("default ClusterClient", func() { BeforeEach(func() { client = cluster.clusterClient(nil) + + _ = client.ForEachMaster(func(master *redis.Client) error { + return master.FlushDb().Err() + }) }) AfterEach(func() { - for _, client := range cluster.masters() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) - } Expect(client.Close()).NotTo(HaveOccurred()) }) @@ -483,12 +501,14 @@ var _ = Describe("ClusterClient", func() { client = cluster.clusterClient(&redis.ClusterOptions{ RouteByLatency: true, }) + + _ = client.ForEachMaster(func(master *redis.Client) error { + return master.FlushDb().Err() + }) }) AfterEach(func() { - for _, client := range cluster.masters() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) - } + client.FlushDb() Expect(client.Close()).NotTo(HaveOccurred()) }) diff --git a/command.go b/command.go index 2db8fa2f..c5625828 100644 --- a/command.go +++ b/command.go @@ -31,6 +31,7 @@ var ( type Cmder interface { args() []interface{} arg(int) string + readReply(*pool.Conn) error setErr(error) reset() @@ -142,7 +143,9 @@ type Cmd struct { } func NewCmd(args ...interface{}) *Cmd { - return &Cmd{baseCmd: newBaseCmd(args)} + return &Cmd{ + baseCmd: newBaseCmd(args), + } } func (cmd *Cmd) reset() { diff --git a/commands.go b/commands.go index b9c7d25d..b21a8eec 100644 --- a/commands.go +++ b/commands.go @@ -52,11 +52,11 @@ func formatSec(dur time.Duration) string { } type cmdable struct { - process func(cmd Cmder) + process func(cmd Cmder) error } type statefulCmdable struct { - process func(cmd Cmder) + process func(cmd Cmder) error } //------------------------------------------------------------------------------ diff --git a/pipeline.go b/pipeline.go index 8d2d884a..e946b9e9 100644 --- a/pipeline.go +++ b/pipeline.go @@ -22,10 +22,11 @@ type Pipeline struct { closed int32 } -func (pipe *Pipeline) Process(cmd Cmder) { +func (pipe *Pipeline) Process(cmd Cmder) error { pipe.mu.Lock() pipe.cmds = append(pipe.cmds, cmd) pipe.mu.Unlock() + return nil } // Close closes the pipeline, releasing any open resources. diff --git a/redis.go b/redis.go index 121c5b24..5c93db00 100644 --- a/redis.go +++ b/redis.go @@ -74,7 +74,7 @@ func (c *baseClient) initConn(cn *pool.Conn) error { return err } -func (c *baseClient) Process(cmd Cmder) { +func (c *baseClient) Process(cmd Cmder) error { for i := 0; i <= c.opt.MaxRetries; i++ { if i > 0 { cmd.reset() @@ -83,7 +83,7 @@ func (c *baseClient) Process(cmd Cmder) { cn, err := c.conn() if err != nil { cmd.setErr(err) - return + return err } readTimeout := cmd.readTimeout() @@ -100,7 +100,7 @@ func (c *baseClient) Process(cmd Cmder) { if err != nil && shouldRetry(err) { continue } - return + return err } err = cmd.readReply(cn) @@ -109,8 +109,10 @@ func (c *baseClient) Process(cmd Cmder) { continue } - return + return err } + + return cmd.Err() } func (c *baseClient) closed() bool { diff --git a/ring.go b/ring.go index bc749c1d..50476fa8 100644 --- a/ring.go +++ b/ring.go @@ -199,13 +199,13 @@ func (ring *Ring) getClient(key string) (*Client, error) { return cl, nil } -func (ring *Ring) Process(cmd Cmder) { +func (ring *Ring) Process(cmd Cmder) error { cl, err := ring.getClient(ring.cmdFirstKey(cmd)) if err != nil { cmd.setErr(err) - return + return err } - cl.baseClient.Process(cmd) + return cl.baseClient.Process(cmd) } // rebalance removes dead shards from the ring. diff --git a/tx.go b/tx.go index 43271faa..4f85cc8d 100644 --- a/tx.go +++ b/tx.go @@ -50,12 +50,12 @@ func (c *Client) Watch(fn func(*Tx) error, keys ...string) error { return retErr } -func (tx *Tx) Process(cmd Cmder) { +func (tx *Tx) Process(cmd Cmder) error { if tx.cmds == nil { - tx.baseClient.Process(cmd) - } else { - tx.cmds = append(tx.cmds, cmd) + return tx.baseClient.Process(cmd) } + tx.cmds = append(tx.cmds, cmd) + return nil } // close closes the transaction, releasing any open resources.