diff --git a/cluster.go b/cluster.go index 67e07e35..d24438fc 100644 --- a/cluster.go +++ b/cluster.go @@ -8,6 +8,7 @@ import ( "math" "math/rand" "net" + "strings" "sync" "sync/atomic" "time" @@ -35,6 +36,7 @@ type ClusterOptions struct { // Enables read-only commands on slave nodes. ReadOnly bool // Allows routing read-only commands to the closest master or slave node. + // It automatically enables ReadOnly. RouteByLatency bool // Allows routing read-only commands to the random master or slave node. RouteRandomly bool @@ -150,6 +152,10 @@ func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode { return &node } +func (n *clusterNode) String() string { + return n.Client.String() +} + func (n *clusterNode) Close() error { return n.Client.Close() } @@ -379,15 +385,17 @@ func (c *clusterNodes) Random() (*clusterNode, error) { type clusterState struct { nodes *clusterNodes - masters []*clusterNode - slaves []*clusterNode + Masters []*clusterNode + Slaves []*clusterNode slots [][]*clusterNode generation uint32 } -func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*clusterState, error) { +func newClusterState( + nodes *clusterNodes, slots []ClusterSlot, origin string, +) (*clusterState, error) { c := clusterState{ nodes: nodes, generation: nodes.NextGeneration(), @@ -413,9 +421,9 @@ func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (* nodes = append(nodes, node) if i == 0 { - c.masters = appendNode(c.masters, node) + c.Masters = appendUniqueNode(c.Masters, node) } else { - c.slaves = appendNode(c.slaves, node) + c.Slaves = appendUniqueNode(c.Slaves, node) } } @@ -497,6 +505,28 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode { return nil } +func (c *clusterState) IsConsistent() bool { + if len(c.Masters) > len(c.Slaves) { + return false + } + + for _, master := range c.Masters { + s := master.Client.Info("replication").Val() + if !strings.Contains(s, "role:master") { + return false + } + } + + for _, slave := range c.Slaves { + s := slave.Client.Info("replication").Val() + if !strings.Contains(s, "role:slave") { + return false + } + } + + return true +} + //------------------------------------------------------------------------------ type clusterStateHolder struct { @@ -516,7 +546,18 @@ func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder } } -func (c *clusterStateHolder) Load() (*clusterState, error) { +func (c *clusterStateHolder) Reload() (*clusterState, error) { + state, err := c.reload() + if err != nil { + return nil, err + } + if !state.IsConsistent() { + c.LazyReload() + } + return state, nil +} + +func (c *clusterStateHolder) reload() (*clusterState, error) { state, err := c.load() if err != nil { c.lastErrMu.Lock() @@ -535,9 +576,15 @@ func (c *clusterStateHolder) LazyReload() { go func() { defer atomic.StoreUint32(&c.reloading, 0) - _, err := c.Load() - if err == nil { - time.Sleep(time.Second) + for { + state, err := c.reload() + if err != nil { + return + } + time.Sleep(100 * time.Millisecond) + if state.IsConsistent() { + return + } } }() } @@ -596,7 +643,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.cmdable.setProcessor(c.Process) - _, _ = c.state.Load() + _, _ = c.state.Reload() if opt.IdleCheckFrequency > 0 { go c.reaper(opt.IdleCheckFrequency) } @@ -890,7 +937,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) - for _, master := range state.masters { + for _, master := range state.Masters { wg.Add(1) go func(node *clusterNode) { defer wg.Done() @@ -923,7 +970,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { var wg sync.WaitGroup errCh := make(chan error, 1) - for _, slave := range state.slaves { + for _, slave := range state.Slaves { wg.Add(1) go func(node *clusterNode) { defer wg.Done() @@ -967,11 +1014,11 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { } } - for _, node := range state.masters { + for _, node := range state.Masters { wg.Add(1) go worker(node) } - for _, node := range state.slaves { + for _, node := range state.Slaves { wg.Add(1) go worker(node) } @@ -994,7 +1041,7 @@ func (c *ClusterClient) PoolStats() *PoolStats { return &acc } - for _, node := range state.masters { + for _, node := range state.Masters { s := node.Client.connPool.Stats() acc.Hits += s.Hits acc.Misses += s.Misses @@ -1005,7 +1052,7 @@ func (c *ClusterClient) PoolStats() *PoolStats { acc.StaleConns += s.StaleConns } - for _, node := range state.slaves { + for _, node := range state.Slaves { s := node.Client.connPool.Stats() acc.Hits += s.Hits acc.Misses += s.Misses @@ -1438,7 +1485,7 @@ func isLoopbackAddr(addr string) bool { return ip.IsLoopback() } -func appendNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { +func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { for _, n := range nodes { if n == node { return nodes diff --git a/cluster_test.go b/cluster_test.go index 24ea4e13..db0728d7 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -50,7 +50,15 @@ func (s *clusterScenario) addrs() []string { func (s *clusterScenario) clusterClient(opt *redis.ClusterOptions) *redis.ClusterClient { opt.Addrs = s.addrs() - return redis.NewClusterClient(opt) + client := redis.NewClusterClient(opt) + Eventually(func() bool { + state, err := client.GetState() + if err != nil { + return false + } + return state.IsConsistent() + }, 30*time.Second).Should(BeTrue()) + return client } func startCluster(scenario *clusterScenario) error { @@ -116,45 +124,43 @@ func startCluster(scenario *clusterScenario) error { } // Wait until all nodes have consistent info. + wanted := []redis.ClusterSlot{{ + Start: 0, + End: 4999, + Nodes: []redis.ClusterNode{{ + Id: "", + Addr: "127.0.0.1:8220", + }, { + Id: "", + Addr: "127.0.0.1:8223", + }}, + }, { + Start: 5000, + End: 9999, + Nodes: []redis.ClusterNode{{ + Id: "", + Addr: "127.0.0.1:8221", + }, { + Id: "", + Addr: "127.0.0.1:8224", + }}, + }, { + Start: 10000, + End: 16383, + Nodes: []redis.ClusterNode{{ + Id: "", + Addr: "127.0.0.1:8222", + }, { + Id: "", + Addr: "127.0.0.1:8225", + }}, + }} for _, client := range scenario.clients { err := eventually(func() error { res, err := client.ClusterSlots().Result() if err != nil { return err } - wanted := []redis.ClusterSlot{ - { - Start: 0, - End: 4999, - Nodes: []redis.ClusterNode{{ - Id: "", - Addr: "127.0.0.1:8220", - }, { - Id: "", - Addr: "127.0.0.1:8223", - }}, - }, { - Start: 5000, - End: 9999, - Nodes: []redis.ClusterNode{{ - Id: "", - Addr: "127.0.0.1:8221", - }, { - Id: "", - Addr: "127.0.0.1:8224", - }}, - }, { - Start: 10000, - End: 16383, - Nodes: []redis.ClusterNode{{ - Id: "", - Addr: "127.0.0.1:8222", - }, { - Id: "", - Addr: "127.0.0.1:8225", - }}, - }, - } return assertSlotsEqual(res, wanted) }, 30*time.Second) if err != nil { @@ -213,6 +219,7 @@ func stopCluster(scenario *clusterScenario) error { //------------------------------------------------------------------------------ var _ = Describe("ClusterClient", func() { + var failover bool var opt *redis.ClusterOptions var client *redis.ClusterClient @@ -233,15 +240,42 @@ var _ = Describe("ClusterClient", func() { Expect(cnt).To(Equal(int64(1))) }) - It("follows redirects", func() { - Expect(client.Set("A", "VALUE", 0).Err()).NotTo(HaveOccurred()) + It("GET follows redirects", func() { + err := client.Set("A", "VALUE", 0).Err() + Expect(err).NotTo(HaveOccurred()) - slot := hashtag.Slot("A") - client.SwapSlotNodes(slot) + if !failover { + Eventually(func() int64 { + nodes, err := client.Nodes("A") + if err != nil { + return 0 + } + return nodes[1].Client.DBSize().Val() + }, 30*time.Second).Should(Equal(int64(1))) - Eventually(func() string { - return client.Get("A").Val() - }, 30*time.Second).Should(Equal("VALUE")) + Eventually(func() error { + return client.SwapNodes("A") + }, 30*time.Second).ShouldNot(HaveOccurred()) + } + + v, err := client.Get("A").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal("VALUE")) + }) + + It("SET follows redirects", func() { + if !failover { + Eventually(func() error { + return client.SwapNodes("A") + }, 30*time.Second).ShouldNot(HaveOccurred()) + } + + err := client.Set("A", "VALUE", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + v, err := client.Get("A").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal("VALUE")) }) It("distributes keys", func() { @@ -250,7 +284,8 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) } - for _, master := range cluster.masters() { + client.ForEachMaster(func(master *redis.Client) error { + defer GinkgoRecover() Eventually(func() string { return master.Info("keyspace").Val() }, 30*time.Second).Should(Or( @@ -258,7 +293,8 @@ var _ = Describe("ClusterClient", func() { ContainSubstring("keys=29"), ContainSubstring("keys=40"), )) - } + return nil + }) }) It("distributes keys when using EVAL", func() { @@ -333,9 +369,12 @@ var _ = Describe("ClusterClient", func() { keys := []string{"A", "B", "C", "D", "E", "F", "G"} It("follows redirects", func() { - for _, key := range keys { - slot := hashtag.Slot(key) - client.SwapSlotNodes(slot) + if !failover { + for _, key := range keys { + Eventually(func() error { + return client.SwapNodes(key) + }, 30*time.Second).ShouldNot(HaveOccurred()) + } } for i, key := range keys { @@ -354,9 +393,12 @@ var _ = Describe("ClusterClient", func() { return nil }) - for _, key := range keys { - slot := hashtag.Slot(key) - client.SwapSlotNodes(slot) + if !failover { + for _, key := range keys { + Eventually(func() error { + return client.SwapNodes(key) + }, 30*time.Second).ShouldNot(HaveOccurred()) + } } for _, key := range keys { @@ -456,9 +498,10 @@ var _ = Describe("ClusterClient", func() { opt = redisClusterOptions() client = cluster.clusterClient(opt) - _ = client.ForEachMaster(func(master *redis.Client) error { + err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() }) + Expect(err).NotTo(HaveOccurred()) }) AfterEach(func() { @@ -469,7 +512,8 @@ var _ = Describe("ClusterClient", func() { }) It("returns pool stats", func() { - Expect(client.PoolStats()).To(BeAssignableToTypeOf(&redis.PoolStats{})) + stats := client.PoolStats() + Expect(stats).To(BeAssignableToTypeOf(&redis.PoolStats{})) }) It("removes idle connections", func() { @@ -489,8 +533,9 @@ var _ = Describe("ClusterClient", func() { opt.MaxRedirects = -1 client := cluster.clusterClient(opt) - slot := hashtag.Slot("A") - client.SwapSlotNodes(slot) + Eventually(func() error { + return client.SwapNodes("A") + }, 30*time.Second).ShouldNot(HaveOccurred()) err := client.Get("A").Err() Expect(err).To(HaveOccurred()) @@ -627,6 +672,8 @@ var _ = Describe("ClusterClient", func() { Describe("ClusterClient failover", func() { BeforeEach(func() { + failover = true + opt = redisClusterOptions() opt.MinRetryBackoff = 250 * time.Millisecond opt.MaxRetryBackoff = time.Second @@ -637,21 +684,34 @@ var _ = Describe("ClusterClient", func() { }) Expect(err).NotTo(HaveOccurred()) - _ = client.ForEachSlave(func(slave *redis.Client) error { + err = client.ForEachSlave(func(slave *redis.Client) error { defer GinkgoRecover() Eventually(func() int64 { return slave.DBSize().Val() }, 30*time.Second).Should(Equal(int64(0))) - return slave.ClusterFailover().Err() + return nil }) + Expect(err).NotTo(HaveOccurred()) + + state, err := client.GetState() + Expect(err).NotTo(HaveOccurred()) + Expect(state.IsConsistent()).To(BeTrue()) + + for _, slave := range state.Slaves { + err = slave.Client.ClusterFailover().Err() + Expect(err).NotTo(HaveOccurred()) + + Eventually(func() bool { + state, _ := client.LoadState() + return state.IsConsistent() + }, 30*time.Second).Should(BeTrue()) + } }) AfterEach(func() { - _ = client.ForEachMaster(func(master *redis.Client) error { - return master.FlushDB().Err() - }) + failover = false Expect(client.Close()).NotTo(HaveOccurred()) }) @@ -664,23 +724,28 @@ var _ = Describe("ClusterClient", func() { opt.RouteByLatency = true client = cluster.clusterClient(opt) - _ = client.ForEachMaster(func(master *redis.Client) error { + err := client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() }) + Expect(err).NotTo(HaveOccurred()) - _ = client.ForEachSlave(func(slave *redis.Client) error { + err = client.ForEachSlave(func(slave *redis.Client) error { Eventually(func() int64 { return client.DBSize().Val() }, 30*time.Second).Should(Equal(int64(0))) return nil }) + Expect(err).NotTo(HaveOccurred()) }) AfterEach(func() { - _ = client.ForEachMaster(func(master *redis.Client) error { - return master.FlushDB().Err() + err := client.ForEachSlave(func(slave *redis.Client) error { + return slave.ReadWrite().Err() }) - Expect(client.Close()).NotTo(HaveOccurred()) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) }) assertClusterClient() diff --git a/commands.go b/commands.go index a3dacacd..c6a88154 100644 --- a/commands.go +++ b/commands.go @@ -266,6 +266,8 @@ type Cmdable interface { GeoDist(key string, member1, member2, unit string) *FloatCmd GeoHash(key string, members ...string) *StringSliceCmd Command() *CommandsInfoCmd + ReadOnly() *StatusCmd + ReadWrite() *StatusCmd } type StatefulCmdable interface { @@ -274,8 +276,6 @@ type StatefulCmdable interface { Select(index int) *StatusCmd SwapDB(index1, index2 int) *StatusCmd ClientSetName(name string) *BoolCmd - ReadOnly() *StatusCmd - ReadWrite() *StatusCmd } var _ Cmdable = (*Client)(nil) @@ -2054,13 +2054,13 @@ func (c *cmdable) ClusterSlaves(nodeID string) *StringSliceCmd { return cmd } -func (c *statefulCmdable) ReadOnly() *StatusCmd { +func (c *cmdable) ReadOnly() *StatusCmd { cmd := NewStatusCmd("readonly") c.process(cmd) return cmd } -func (c *statefulCmdable) ReadWrite() *StatusCmd { +func (c *cmdable) ReadWrite() *StatusCmd { cmd := NewStatusCmd("readwrite") c.process(cmd) return cmd diff --git a/export_test.go b/export_test.go index 288e86f0..fcb7fa0d 100644 --- a/export_test.go +++ b/export_test.go @@ -1,9 +1,11 @@ package redis import ( + "fmt" "net" "time" + "github.com/go-redis/redis/internal/hashtag" "github.com/go-redis/redis/internal/pool" ) @@ -19,6 +21,14 @@ func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error) return c.receiveMessage(timeout) } +func (c *ClusterClient) GetState() (*clusterState, error) { + return c.state.Get() +} + +func (c *ClusterClient) LoadState() (*clusterState, error) { + return c.loadState() +} + func (c *ClusterClient) SlotAddrs(slot int) []string { state, err := c.state.Get() if err != nil { @@ -32,15 +42,25 @@ func (c *ClusterClient) SlotAddrs(slot int) []string { return addrs } -// SwapSlot swaps a slot's master/slave address for testing MOVED redirects. -func (c *ClusterClient) SwapSlotNodes(slot int) { - state, err := c.state.Get() +func (c *ClusterClient) Nodes(key string) ([]*clusterNode, error) { + state, err := c.state.Reload() if err != nil { - panic(err) + return nil, err } + slot := hashtag.Slot(key) nodes := state.slots[slot] - if len(nodes) == 2 { - nodes[0], nodes[1] = nodes[1], nodes[0] + if len(nodes) != 2 { + return nil, fmt.Errorf("slot=%d does not have enough nodes: %v", slot, nodes) } + return nodes, nil +} + +func (c *ClusterClient) SwapNodes(key string) error { + nodes, err := c.Nodes(key) + if err != nil { + return err + } + nodes[0], nodes[1] = nodes[1], nodes[0] + return nil }