diff --git a/cluster.go b/cluster.go index f2832a4..a5a16db 100644 --- a/cluster.go +++ b/cluster.go @@ -14,6 +14,7 @@ import ( ) var errClusterNoNodes = internal.RedisError("redis: cluster has no nodes") +var errNilClusterState = internal.RedisError("redis: cannot load cluster slots") // ClusterOptions are used to configure a cluster client and should be // passed to NewClusterClient. @@ -355,7 +356,14 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { _, _ = c.nodes.Get(addr) } - c.reloadSlots() + // Preload cluster slots. + for i := 0; i < 10; i++ { + state, err := c.reloadSlots() + if err == nil { + c._state.Store(state) + break + } + } if opt.IdleCheckFrequency > 0 { go c.reaper(opt.IdleCheckFrequency) @@ -366,10 +374,11 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { func (c *ClusterClient) state() *clusterState { v := c._state.Load() - if v == nil { - return nil + if v != nil { + return v.(*clusterState) } - return v.(*clusterState) + c.lazyReloadSlots() + return nil } func (c *ClusterClient) cmdSlotAndNode(state *clusterState, cmd Cmder) (int, *clusterNode, error) { @@ -397,10 +406,12 @@ func (c *ClusterClient) cmdSlotAndNode(state *clusterState, cmd Cmder) (int, *cl } func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error { + state := c.state() + var node *clusterNode var err error - if len(keys) > 0 { - node, err = c.state().slotMasterNode(hashtag.Slot(keys[0])) + if state != nil && len(keys) > 0 { + node, err = state.slotMasterNode(hashtag.Slot(keys[0])) } else { node, err = c.nodes.Random() } @@ -463,8 +474,9 @@ func (c *ClusterClient) Process(cmd Cmder) error { var addr string moved, ask, addr = internal.IsMovedError(err) if moved || ask { - if slot >= 0 { - master, _ := c.state().slotMasterNode(slot) + state := c.state() + if state != nil && slot >= 0 { + master, _ := state.slotMasterNode(slot) if moved && (master == nil || master.Client.getAddr() != addr) { c.lazyReloadSlots() } @@ -523,7 +535,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { state := c.state() if state == nil { - return nil + return errNilClusterState } var wg sync.WaitGroup @@ -564,12 +576,13 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { // PoolStats returns accumulated connection pool stats. func (c *ClusterClient) PoolStats() *PoolStats { + var acc PoolStats + nodes, err := c.nodes.All() if err != nil { - return nil + return &acc } - var acc PoolStats for _, node := range nodes { s := node.Client.connPool.Stats() acc.Requests += s.Requests @@ -585,37 +598,46 @@ func (c *ClusterClient) lazyReloadSlots() { if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) { return } + go func() { - c.reloadSlots() + for i := 0; i < 1000; i++ { + state, err := c.reloadSlots() + if err == pool.ErrClosed { + break + } + if err == nil { + c._state.Store(state) + break + } + time.Sleep(time.Millisecond) + } + + time.Sleep(3 * time.Second) atomic.StoreUint32(&c.reloading, 0) }() } -func (c *ClusterClient) reloadSlots() { - for i := 0; i < 10; i++ { - node, err := c.nodes.Random() - if err != nil { - return - } - - if c.cmds == nil { - cmds, err := node.Client.Command().Result() - if err == nil { - c.cmds = cmds - } - } - - slots, err := node.Client.ClusterSlots().Result() - if err != nil { - continue - } - - state, err := newClusterState(c.nodes, slots) - if err != nil { - return - } - c._state.Store(state) +func (c *ClusterClient) reloadSlots() (*clusterState, error) { + node, err := c.nodes.Random() + if err != nil { + return nil, err } + + // TODO: fix race + if c.cmds == nil { + cmds, err := node.Client.Command().Result() + if err != nil { + return nil, err + } + c.cmds = cmds + } + + slots, err := node.Client.ClusterSlots().Result() + if err != nil { + return nil, err + } + + return newClusterState(c.nodes, slots) } // reaper closes idle connections to the cluster. @@ -789,8 +811,13 @@ func (c *ClusterClient) txPipelineExec(cmds []Cmder) error { return err } + state := c.state() + if state == nil { + return errNilClusterState + } + for slot, cmds := range cmdsMap { - node, err := c.state().slotMasterNode(slot) + node, err := state.slotMasterNode(slot) if err != nil { setCmdsErr(cmds, err) continue diff --git a/cluster_test.go b/cluster_test.go index 589ef98..28314bc 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -578,7 +578,7 @@ var _ = Describe("ClusterClient timeout", func() { var client *redis.ClusterClient AfterEach(func() { - Expect(client.Close()).NotTo(HaveOccurred()) + _ = client.Close() }) testTimeout := func() {