diff --git a/cluster.go b/cluster.go index 983480b5..b0c4659b 100644 --- a/cluster.go +++ b/cluster.go @@ -51,7 +51,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { client.cmdable.process = client.Process for _, addr := range opt.Addrs { - _ = client.nodeByAddr(addr) + _, _ = client.nodeByAddr(addr) } client.reloadSlots() @@ -86,7 +86,10 @@ func (c *ClusterClient) getNodes() map[string]*clusterNode { } func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { - node := c.slotMasterNode(hashtag.Slot(keys[0])) + node, err := c.slotMasterNode(hashtag.Slot(keys[0])) + if err != nil { + return err + } return node.Client.Watch(fn, keys...) } @@ -122,26 +125,29 @@ func (c *ClusterClient) Close() error { return nil } -func (c *ClusterClient) nodeByAddr(addr string) *clusterNode { +func (c *ClusterClient) nodeByAddr(addr string) (*clusterNode, error) { c.mu.RLock() node, ok := c.nodes[addr] c.mu.RUnlock() if ok { - return node + return node, nil } + defer c.mu.Unlock() c.mu.Lock() - if !c.closed { - node, ok = c.nodes[addr] - if !ok { - node = c.newNode(addr) - c.nodes[addr] = node - c.addrs = append(c.addrs, node.Addr) - } - } - c.mu.Unlock() - return node + if c.closed { + return nil, pool.ErrClosed + } + + node, ok = c.nodes[addr] + if !ok { + node = c.newNode(addr) + c.nodes[addr] = node + c.addrs = append(c.addrs, node.Addr) + } + + return node, nil } func (c *ClusterClient) newNode(addr string) *clusterNode { @@ -161,70 +167,81 @@ func (c *ClusterClient) slotNodes(slot int) []*clusterNode { } // randomNode returns random live node. -func (c *ClusterClient) randomNode() *clusterNode { +func (c *ClusterClient) randomNode() (*clusterNode, error) { var node *clusterNode + var err error for i := 0; i < 10; i++ { c.mu.RLock() + closed := c.closed addrs := c.addrs c.mu.RUnlock() - if len(addrs) == 0 { - return nil + if closed { + return nil, pool.ErrClosed } n := rand.Intn(len(addrs)) - node = c.nodeByAddr(addrs[n]) + node, err = c.nodeByAddr(addrs[n]) + if err != nil { + return nil, err + } if node.Client.ClusterInfo().Err() == nil { - return node + break } } - return node + return node, nil } -func (c *ClusterClient) slotMasterNode(slot int) *clusterNode { +func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) { nodes := c.slotNodes(slot) if len(nodes) == 0 { return c.randomNode() } - return nodes[0] + return nodes[0], nil } -func (c *ClusterClient) slotSlaveNode(slot int) *clusterNode { +func (c *ClusterClient) slotSlaveNode(slot int) (*clusterNode, error) { nodes := c.slotNodes(slot) switch len(nodes) { case 0: return c.randomNode() case 1: - return nodes[0] + return nodes[0], nil case 2: - return nodes[1] + return nodes[1], nil default: n := rand.Intn(len(nodes)-1) + 1 - return nodes[n] + return nodes[n], nil } } -func (c *ClusterClient) slotClosestNode(slot int) *clusterNode { +func (c *ClusterClient) slotClosestNode(slot int) (*clusterNode, error) { nodes := c.slotNodes(slot) + if len(nodes) == 0 { + return c.randomNode() + } + var node *clusterNode for _, n := range nodes { if node == nil || n.Latency < node.Latency { node = n } } - return node + return node, nil } -func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode) { +func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) { cmdInfo := c.cmdInfo(cmd.arg(0)) if cmdInfo == nil { - return 0, c.randomNode() + node, err := c.randomNode() + return 0, node, err } if cmdInfo.FirstKeyPos == -1 { - return 0, c.randomNode() + node, err := c.randomNode() + return 0, node, err } firstKey := cmd.arg(int(cmdInfo.FirstKeyPos)) @@ -232,23 +249,28 @@ func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode) { if cmdInfo.ReadOnly && c.opt.ReadOnly { if c.opt.RouteByLatency { - return slot, c.slotClosestNode(slot) + node, err := c.slotClosestNode(slot) + return slot, node, err } - return slot, c.slotSlaveNode(slot) + + node, err := c.slotSlaveNode(slot) + return slot, node, err } - return slot, c.slotMasterNode(slot) + + node, err := c.slotMasterNode(slot) + return slot, node, err } func (c *ClusterClient) Process(cmd Cmder) { var ask bool - slot, node := c.cmdSlotAndNode(cmd) - + slot, node, err := c.cmdSlotAndNode(cmd) for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { if attempt > 0 { cmd.reset() } - if node == nil { - cmd.setErr(pool.ErrClosed) + + if err != nil { + cmd.setErr(err) return } @@ -271,7 +293,7 @@ func (c *ClusterClient) Process(cmd Cmder) { // On network errors try random node. if shouldRetry(err) { - node = c.randomNode() + node, err = c.randomNode() continue } @@ -279,11 +301,12 @@ func (c *ClusterClient) Process(cmd Cmder) { var addr string moved, ask, addr = isMovedError(err) if moved || ask { - if moved && c.slotMasterNode(slot).Addr != addr { + master, _ := c.slotMasterNode(slot) + if moved && (master == nil || master.Addr != addr) { c.lazyReloadSlots() } - node = c.nodeByAddr(addr) + node, err = c.nodeByAddr(addr) continue } @@ -310,7 +333,10 @@ func (c *ClusterClient) setSlots(cs []ClusterSlot) { for _, s := range cs { var nodes []*clusterNode for _, n := range s.Nodes { - nodes = append(nodes, c.nodeByAddr(n.Addr)) + node, err := c.nodeByAddr(n.Addr) + if err == nil { + nodes = append(nodes, node) + } } for i := s.Start; i <= s.End; i++ { @@ -341,8 +367,8 @@ func (c *ClusterClient) setNodesLatency() { func (c *ClusterClient) reloadSlots() { defer atomic.StoreUint32(&c.reloading, 0) - node := c.randomNode() - if node == nil { + node, err := c.randomNode() + if err != nil { return } @@ -409,10 +435,20 @@ func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { func (c *ClusterClient) pipelineExec(cmds []Cmder) error { var retErr error + returnError := func(err error) { + if retErr == nil { + retErr = err + } + } cmdsMap := make(map[*clusterNode][]Cmder) for _, cmd := range cmds { - _, node := c.cmdSlotAndNode(cmd) + _, node, err := c.cmdSlotAndNode(cmd) + if err != nil { + cmd.setErr(err) + returnError(err) + continue + } cmdsMap[node] = append(cmdsMap[node], cmd) } @@ -421,19 +457,25 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { for node, cmds := range cmdsMap { if node == nil { - node = c.randomNode() + var err error + node, err = c.randomNode() + if err != nil { + setCmdsErr(cmds, err) + returnError(err) + continue + } } cn, err := node.Client.conn() if err != nil { - setCmdsErr(cmds, err) - retErr = err + setCmdsErr(cmds, retErr) + returnError(err) continue } failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) if err != nil { - retErr = err + returnError(err) } node.Client.putConn(cn, err, false) } @@ -452,7 +494,13 @@ func (c *ClusterClient) execClusterCmds( return failedCmds, err } - var firstCmdErr error + var retErr error + returnError := func(err error) { + if retErr == nil { + retErr = err + } + } + for i, cmd := range cmds { err := cmd.readReply(cn) if err == nil { @@ -465,18 +513,26 @@ func (c *ClusterClient) execClusterCmds( } else if moved, ask, addr := isMovedError(err); moved { c.lazyReloadSlots() cmd.reset() - node := c.nodeByAddr(addr) + node, err := c.nodeByAddr(addr) + if err != nil { + returnError(err) + continue + } failedCmds[node] = append(failedCmds[node], cmd) } else if ask { cmd.reset() - node := c.nodeByAddr(addr) + node, err := c.nodeByAddr(addr) + if err != nil { + returnError(err) + continue + } failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd) - } else if firstCmdErr == nil { - firstCmdErr = err + } else { + returnError(err) } } - return failedCmds, firstCmdErr + return failedCmds, retErr } //------------------------------------------------------------------------------ diff --git a/cluster_test.go b/cluster_test.go index 9cdbc004..446e20df 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -312,21 +312,12 @@ var _ = Describe("Cluster", func() { // Expect(res).To(Equal("OK")) // }) }) +}) - Describe("Client", func() { - var client *redis.ClusterClient - - BeforeEach(func() { - client = cluster.clusterClient(nil) - }) - - AfterEach(func() { - for _, client := range cluster.masters() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) - } - Expect(client.Close()).NotTo(HaveOccurred()) - }) +var _ = Describe("ClusterClient", func() { + var client *redis.ClusterClient + describeClusterClient := func() { It("should GET/SET/DEL", func() { val, err := client.Get("A").Result() Expect(err).To(Equal(redis.Nil)) @@ -358,15 +349,10 @@ var _ = Describe("Cluster", func() { val, err := client.Get("A").Result() Expect(err).NotTo(HaveOccurred()) Expect(val).To(Equal("VALUE")) - - Eventually(func() []string { - return client.SlotAddrs(slot) - }, "10s").Should(Equal([]string{"127.0.0.1:8221", "127.0.0.1:8224"})) }) It("should return error when there are no attempts left", func() { - Expect(client.Close()).NotTo(HaveOccurred()) - client = cluster.clusterClient(&redis.ClusterOptions{ + client := cluster.clusterClient(&redis.ClusterOptions{ MaxRedirects: -1, }) @@ -376,6 +362,8 @@ var _ = Describe("Cluster", func() { err := client.Get("A").Err() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("MOVED")) + + Expect(client.Close()).NotTo(HaveOccurred()) }) It("should Watch", func() { @@ -417,23 +405,8 @@ var _ = Describe("Cluster", func() { Expect(err).NotTo(HaveOccurred()) Expect(n).To(Equal(int64(100))) }) - }) - Describe("pipeline", func() { - var client *redis.ClusterClient - - BeforeEach(func() { - client = cluster.clusterClient(nil) - }) - - AfterEach(func() { - for _, client := range cluster.masters() { - Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) - } - Expect(client.Close()).NotTo(HaveOccurred()) - }) - - It("performs multi-pipelines", func() { + It("supports pipeline", func() { slot := hashtag.Slot("A") Expect(client.SwapSlotNodes(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) @@ -441,27 +414,31 @@ var _ = Describe("Cluster", func() { defer pipe.Close() keys := []string{"A", "B", "C", "D", "E", "F", "G"} + for i, key := range keys { pipe.Set(key, key+"_value", 0) pipe.Expire(key, time.Duration(i+1)*time.Hour) } + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(14)) + for _, key := range keys { pipe.Get(key) pipe.TTL(key) } - - cmds, err := pipe.Exec() + cmds, err = pipe.Exec() Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(28)) - Expect(cmds[14].(*redis.StringCmd).Val()).To(Equal("A_value")) - Expect(cmds[15].(*redis.DurationCmd).Val()).To(BeNumerically("~", 1*time.Hour, time.Second)) - Expect(cmds[20].(*redis.StringCmd).Val()).To(Equal("D_value")) - Expect(cmds[21].(*redis.DurationCmd).Val()).To(BeNumerically("~", 4*time.Hour, time.Second)) - Expect(cmds[26].(*redis.StringCmd).Val()).To(Equal("G_value")) - Expect(cmds[27].(*redis.DurationCmd).Val()).To(BeNumerically("~", 7*time.Hour, time.Second)) + Expect(cmds).To(HaveLen(14)) + Expect(cmds[0].(*redis.StringCmd).Val()).To(Equal("A_value")) + Expect(cmds[1].(*redis.DurationCmd).Val()).To(BeNumerically("~", 1*time.Hour, time.Second)) + Expect(cmds[6].(*redis.StringCmd).Val()).To(Equal("D_value")) + Expect(cmds[7].(*redis.DurationCmd).Val()).To(BeNumerically("~", 4*time.Hour, time.Second)) + Expect(cmds[12].(*redis.StringCmd).Val()).To(Equal("G_value")) + Expect(cmds[13].(*redis.DurationCmd).Val()).To(BeNumerically("~", 7*time.Hour, time.Second)) }) - It("works with missing keys", func() { + It("supports pipeline with missing keys", func() { Expect(client.Set("A", "A_value", 0).Err()).NotTo(HaveOccurred()) Expect(client.Set("C", "C_value", 0).Err()).NotTo(HaveOccurred()) @@ -484,6 +461,38 @@ var _ = Describe("Cluster", func() { Expect(c.Err()).NotTo(HaveOccurred()) Expect(c.Val()).To(Equal("C_value")) }) + } + + Describe("default ClusterClient", func() { + BeforeEach(func() { + client = cluster.clusterClient(nil) + }) + + AfterEach(func() { + for _, client := range cluster.masters() { + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + } + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + describeClusterClient() + }) + + Describe("ClusterClient with RouteByLatency", func() { + BeforeEach(func() { + client = cluster.clusterClient(&redis.ClusterOptions{ + RouteByLatency: true, + }) + }) + + AfterEach(func() { + for _, client := range cluster.masters() { + Expect(client.FlushDb().Err()).NotTo(HaveOccurred()) + } + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + describeClusterClient() }) })