diff --git a/cluster.go b/cluster.go index 52f357a4..32e406e4 100644 --- a/cluster.go +++ b/cluster.go @@ -14,6 +14,7 @@ import ( ) var errClusterNoNodes = internal.RedisError("redis: cluster has no nodes") +var errNilClusterState = internal.RedisError("redis: cannot load cluster slots") // ClusterOptions are used to configure a cluster client and should be // passed to NewClusterClient. @@ -357,7 +358,14 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { _, _ = c.nodes.Get(addr) } - c.reloadSlots() + // Preload cluster slots. + for i := 0; i < 10; i++ { + state, err := c.reloadSlots() + if err == nil { + c._state.Store(state) + break + } + } if opt.IdleCheckFrequency > 0 { go c.reaper(opt.IdleCheckFrequency) @@ -368,10 +376,11 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { func (c *ClusterClient) state() *clusterState { v := c._state.Load() - if v == nil { - return nil + if v != nil { + return v.(*clusterState) } - return v.(*clusterState) + c.lazyReloadSlots() + return nil } func (c *ClusterClient) cmdSlotAndNode(state *clusterState, cmd Cmder) (int, *clusterNode, error) { @@ -399,10 +408,12 @@ func (c *ClusterClient) cmdSlotAndNode(state *clusterState, cmd Cmder) (int, *cl } func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { + state := c.state() + var node *clusterNode var err error - if len(keys) > 0 { - node, err = c.state().slotMasterNode(hashtag.Slot(keys[0])) + if state != nil && len(keys) > 0 { + node, err = state.slotMasterNode(hashtag.Slot(keys[0])) } else { node, err = c.nodes.Random() } @@ -465,8 +476,9 @@ func (c *ClusterClient) Process(cmd Cmder) error { var addr string moved, ask, addr = internal.IsMovedError(err) if moved || ask { - if slot >= 0 { - master, _ := c.state().slotMasterNode(slot) + state := c.state() + if state != nil && slot >= 0 { + master, _ := state.slotMasterNode(slot) if moved && (master == nil || master.Client.getAddr() != addr) { c.lazyReloadSlots() } @@ -525,7 +537,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { state := c.state() if state == nil { - return nil + return errNilClusterState } var wg sync.WaitGroup @@ -566,12 +578,13 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { // PoolStats returns accumulated connection pool stats. func (c *ClusterClient) PoolStats() *PoolStats { + var acc PoolStats + nodes, err := c.nodes.All() if err != nil { - return nil + return &acc } - var acc PoolStats for _, node := range nodes { s := node.Client.connPool.Stats() acc.Requests += s.Requests @@ -587,37 +600,46 @@ func (c *ClusterClient) lazyReloadSlots() { if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) { return } + go func() { - c.reloadSlots() + for i := 0; i < 1000; i++ { + state, err := c.reloadSlots() + if err == pool.ErrClosed { + break + } + if err == nil { + c._state.Store(state) + break + } + time.Sleep(time.Millisecond) + } + + time.Sleep(3 * time.Second) atomic.StoreUint32(&c.reloading, 0) }() } -func (c *ClusterClient) reloadSlots() { - for i := 0; i < 10; i++ { - node, err := c.nodes.Random() - if err != nil { - return - } - - if c.cmds == nil { - cmds, err := node.Client.Command().Result() - if err == nil { - c.cmds = cmds - } - } - - slots, err := node.Client.ClusterSlots().Result() - if err != nil { - continue - } - - state, err := newClusterState(c.nodes, slots) - if err != nil { - return - } - c._state.Store(state) +func (c *ClusterClient) reloadSlots() (*clusterState, error) { + node, err := c.nodes.Random() + if err != nil { + return nil, err } + + // TODO: fix race + if c.cmds == nil { + cmds, err := node.Client.Command().Result() + if err != nil { + return nil, err + } + c.cmds = cmds + } + + slots, err := node.Client.ClusterSlots().Result() + if err != nil { + return nil, err + } + + return newClusterState(c.nodes, slots) } // reaper closes idle connections to the cluster. @@ -791,8 +813,13 @@ func (c *ClusterClient) txPipelineExec(cmds []Cmder) error { return err } + state := c.state() + if state == nil { + return errNilClusterState + } + for slot, cmds := range cmdsMap { - node, err := c.state().slotMasterNode(slot) + node, err := state.slotMasterNode(slot) if err != nil { setCmdsErr(cmds, err) continue diff --git a/cluster_test.go b/cluster_test.go index 49cb13ca..53a33631 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -582,7 +582,7 @@ var _ = Describe("ClusterClient timeout", func() { var client *redis.ClusterClient AfterEach(func() { - Expect(client.Close()).NotTo(HaveOccurred()) + _ = client.Close() }) testTimeout := func() {