diff --git a/cluster.go b/cluster.go index a2c18b38..628f513e 100644 --- a/cluster.go +++ b/cluster.go @@ -2,6 +2,7 @@ package redis import ( "fmt" + "math" "math/rand" "net" "sync" @@ -118,11 +119,11 @@ func (opt *ClusterOptions) clientOptions() *Options { //------------------------------------------------------------------------------ type clusterNode struct { - Client *Client - Latency time.Duration + Client *Client - loading time.Time + latency uint32 // atomic generation uint32 + loading int64 // atomic } func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode { @@ -132,8 +133,9 @@ func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode { Client: NewClient(opt), } + node.latency = math.MaxUint32 if clOpt.RouteByLatency { - node.updateLatency() + go node.updateLatency() } return &node @@ -141,16 +143,46 @@ func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode { func (n *clusterNode) updateLatency() { const probes = 10 + + var latency uint32 for i := 0; i < probes; i++ { start := time.Now() n.Client.Ping() - n.Latency += time.Since(start) + probe := uint32(time.Since(start) / time.Microsecond) + latency = (latency + probe) / 2 } - n.Latency = n.Latency / probes + atomic.StoreUint32(&n.latency, latency) +} + +func (n *clusterNode) Close() error { + return n.Client.Close() +} + +func (n *clusterNode) Test() error { + return n.Client.ClusterInfo().Err() +} + +func (n *clusterNode) Latency() time.Duration { + latency := atomic.LoadUint32(&n.latency) + return time.Duration(latency) * time.Microsecond +} + +func (n *clusterNode) MarkAsLoading() { + atomic.StoreInt64(&n.loading, time.Now().Unix()) } func (n *clusterNode) Loading() bool { - return !n.loading.IsZero() && time.Since(n.loading) < time.Minute + const minute = int64(time.Minute / time.Second) + + loading := atomic.LoadInt64(&n.loading) + if loading == 0 { + return false + } + if time.Now().Unix()-loading < minute { + return true + } + atomic.StoreInt64(&n.loading, 0) + return false } func (n *clusterNode) Generation() uint32 { @@ -310,7 +342,7 @@ func (c *clusterNodes) Random() (*clusterNode, error) { return nil, err } - nodeErr = node.Client.ClusterInfo().Err() + nodeErr = node.Test() if nodeErr == nil { return node, nil } @@ -416,7 +448,7 @@ func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) { if n.Loading() { continue } - if node == nil || node.Latency-n.Latency > threshold { + if node == nil || node.Latency()-n.Latency() > threshold { node = n } } @@ -687,8 +719,7 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error { // If slave is loading - read from master. if c.opt.ReadOnly && internal.IsLoadingError(err) { - // TODO: race - node.loading = time.Now() + node.MarkAsLoading() continue } diff --git a/cluster_test.go b/cluster_test.go index 43f3261b..a142f8c0 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -320,6 +320,14 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) Expect(cmds).To(HaveLen(14)) + _ = client.ForEachNode(func(node *redis.Client) error { + defer GinkgoRecover() + Eventually(func() int64 { + return node.DBSize().Val() + }, 30*time.Second).ShouldNot(BeZero()) + return nil + }) + for _, key := range keys { slot := hashtag.Slot(key) client.SwapSlotNodes(slot) @@ -576,7 +584,7 @@ var _ = Describe("ClusterClient", func() { _ = client.ForEachSlave(func(slave *redis.Client) error { Eventually(func() int64 { - return client.DBSize().Val() + return slave.DBSize().Val() }, 30*time.Second).Should(Equal(int64(0))) return slave.ClusterFailover().Err() }) @@ -717,7 +725,7 @@ var _ = Describe("ClusterClient timeout", func() { }) } - const pause = time.Second + const pause = 2 * time.Second Context("read/write timeout", func() { BeforeEach(func() { diff --git a/commands_test.go b/commands_test.go index 71537955..067ffccd 100644 --- a/commands_test.go +++ b/commands_test.go @@ -447,7 +447,7 @@ var _ = Describe("Commands", func() { pttl := client.PTTL("key") Expect(pttl.Err()).NotTo(HaveOccurred()) - Expect(pttl.Val()).To(BeNumerically("~", expiration, 10*time.Millisecond)) + Expect(pttl.Val()).To(BeNumerically("~", expiration, 100*time.Millisecond)) }) It("should PExpireAt", func() { @@ -466,7 +466,7 @@ var _ = Describe("Commands", func() { pttl := client.PTTL("key") Expect(pttl.Err()).NotTo(HaveOccurred()) - Expect(pttl.Val()).To(BeNumerically("~", expiration, 10*time.Millisecond)) + Expect(pttl.Val()).To(BeNumerically("~", expiration, 100*time.Millisecond)) }) It("should PTTL", func() { @@ -481,7 +481,7 @@ var _ = Describe("Commands", func() { pttl := client.PTTL("key") Expect(pttl.Err()).NotTo(HaveOccurred()) - Expect(pttl.Val()).To(BeNumerically("~", expiration, 10*time.Millisecond)) + Expect(pttl.Val()).To(BeNumerically("~", expiration, 100*time.Millisecond)) }) It("should RandomKey", func() {