diff --git a/cluster.go b/cluster.go
index 7a4c2c84..ba562cf1 100644
--- a/cluster.go
+++ b/cluster.go
@@ -1,6 +1,7 @@
 package redis

 import (
+	"errors"
 	"fmt"
 	"math"
 	"math/rand"
@@ -292,21 +293,6 @@ func (c *clusterNodes) GC(generation uint32) {
 	}
 }

-func (c *clusterNodes) All() ([]*clusterNode, error) {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-
-	if c.closed {
-		return nil, pool.ErrClosed
-	}
-
-	cp := make([]*clusterNode, 0, len(c.allNodes))
-	for _, node := range c.allNodes {
-		cp = append(cp, node)
-	}
-	return cp, nil
-}
-
 func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) {
 	var node *clusterNode
 	var err error
@@ -353,6 +339,21 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) {
 	return node, err
 }

+func (c *clusterNodes) All() ([]*clusterNode, error) {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	if c.closed {
+		return nil, pool.ErrClosed
+	}
+
+	cp := make([]*clusterNode, 0, len(c.allNodes))
+	for _, node := range c.allNodes {
+		cp = append(cp, node)
+	}
+	return cp, nil
+}
+
 func (c *clusterNodes) Random() (*clusterNode, error) {
 	addrs, err := c.Addrs()
 	if err != nil {
@@ -412,6 +413,10 @@ func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*
 		}
 	}

+	time.AfterFunc(time.Minute, func() {
+		nodes.GC(c.generation)
+	})
+
 	return &c, nil
 }

@@ -477,6 +482,66 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode {

 //------------------------------------------------------------------------------

+type clusterStateHolder struct {
+	load      func() (*clusterState, error)
+	reloading uint32 // atomic
+
+	state atomic.Value
+
+	lastErrMu sync.RWMutex
+	lastErr   error
+}
+
+func newClusterStateHolder(fn func() (*clusterState, error)) *clusterStateHolder {
+	return &clusterStateHolder{
+		load: fn,
+	}
+}
+
+func (c *clusterStateHolder) Load() (*clusterState, error) {
+	state, err := c.load()
+	if err != nil {
+		c.lastErrMu.Lock()
+		c.lastErr = err
+		c.lastErrMu.Unlock()
+		return nil, err
+	}
+	c.state.Store(state)
+	return state, nil
+}
+
+func (c *clusterStateHolder) LazyReload() {
+	if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) {
+		return
+	}
+	go func() {
+		defer atomic.StoreUint32(&c.reloading, 0)
+
+		_, err := c.Load()
+		if err == nil {
+			time.Sleep(time.Second)
+		}
+	}()
+}
+
+func (c *clusterStateHolder) Get() (*clusterState, error) {
+	v := c.state.Load()
+	if v != nil {
+		return v.(*clusterState), nil
+	}
+
+	c.lastErrMu.RLock()
+	err := c.lastErr
+	c.lastErrMu.RUnlock()
+	if err != nil {
+		return nil, err
+	}
+
+	return nil, errors.New("redis: cluster has no state")
+}
+
+//------------------------------------------------------------------------------
+
 // ClusterClient is a Redis Cluster client representing a pool of zero
 // or more underlying connections. It's safe for concurrent use by
 // multiple goroutines.
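The new `clusterStateHolder` pairs an `atomic.Value` that caches the last good state with a CAS-guarded flag that coalesces concurrent reload requests into a single background goroutine. Below is a minimal, self-contained sketch of that pattern; the `stateHolder` name, the `string` state, and the stub loader are illustrative stand-ins, not go-redis types:

```go
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

// stateHolder is a stripped-down sketch of the clusterStateHolder pattern:
// the last good state lives in an atomic.Value, and a CAS-guarded flag
// ensures at most one background reload runs at a time.
type stateHolder struct {
	load      func() (string, error)
	reloading uint32 // atomic
	state     atomic.Value
}

func (h *stateHolder) Load() (string, error) {
	s, err := h.load()
	if err != nil {
		return "", err
	}
	h.state.Store(s)
	return s, nil
}

func (h *stateHolder) LazyReload() {
	// Only the first caller wins; concurrent callers return immediately.
	if !atomic.CompareAndSwapUint32(&h.reloading, 0, 1) {
		return
	}
	go func() {
		defer atomic.StoreUint32(&h.reloading, 0)
		if _, err := h.Load(); err == nil {
			// Keep the flag set briefly after a successful reload, so a
			// burst of redirect errors triggers one reload, as in the diff.
			time.Sleep(time.Second)
		}
	}()
}

func (h *stateHolder) Get() (string, error) {
	if v := h.state.Load(); v != nil {
		return v.(string), nil
	}
	return "", errors.New("no state loaded yet")
}

func main() {
	h := &stateHolder{load: func() (string, error) {
		return "state@" + time.Now().Format(time.RFC3339), nil
	}}
	if _, err := h.Load(); err != nil {
		panic(err)
	}
	h.LazyReload() // concurrent calls would be coalesced into this one
	s, _ := h.Get()
	fmt.Println(s)
}
```

Note that the one-second sleep runs before the deferred flag reset, so the flag stays set during it: a storm of MOVED errors causes at most one reload per second.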
@@ -485,18 +550,12 @@ type ClusterClient struct {
 	opt *ClusterOptions

 	nodes *clusterNodes
+	state *clusterStateHolder

 	cmdsInfoCache *cmdsInfoCache

-	_state     atomic.Value
-	stateErrMu sync.RWMutex
-	stateErr   error
-
 	process           func(Cmder) error
 	processPipeline   func([]Cmder) error
 	processTxPipeline func([]Cmder) error
-
-	// Reports whether slots reloading is in progress.
-	reloading uint32 // atomic
 }

 // NewClusterClient returns a Redis Cluster client as described in
@@ -509,6 +568,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
 		nodes:         newClusterNodes(opt),
 		cmdsInfoCache: newCmdsInfoCache(),
 	}
+	c.state = newClusterStateHolder(c.loadState)

 	c.process = c.defaultProcess
 	c.processPipeline = c.defaultProcessPipeline
@@ -516,7 +576,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {

 	c.cmdable.setProcessor(c.Process)

-	c.reloadState()
+	_, _ = c.state.Load()
 	if opt.IdleCheckFrequency > 0 {
 		go c.reaper(opt.IdleCheckFrequency)
 	}
@@ -565,7 +625,7 @@ func (c *ClusterClient) cmdSlot(cmd Cmder) int {
 }

 func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		return 0, nil, err
 	}
@@ -588,7 +648,7 @@ func (c *ClusterClient) cmdSlotAndNode(cmd Cmder) (int, *clusterNode, error) {
 }

 func (c *ClusterClient) slotMasterNode(slot int) (*clusterNode, error) {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		return nil, err
 	}
@@ -633,7 +693,7 @@ func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
 		moved, ask, addr := internal.IsMovedError(err)
 		if moved || ask {
-			c.lazyReloadState()
+			c.state.LazyReload()
 			node, err = c.nodes.GetOrCreate(addr)
 			if err != nil {
 				return err
@@ -725,7 +785,7 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error {
 		var addr string
 		moved, ask, addr = internal.IsMovedError(err)
 		if moved || ask {
-			c.lazyReloadState()
+			c.state.LazyReload()

 			node, err = c.nodes.GetOrCreate(addr)
 			if err != nil {
@@ -748,7 +808,7 @@ func (c *ClusterClient) defaultProcess(cmd Cmder) error {
 // ForEachMaster concurrently calls the fn on each master node in the cluster.
 // It returns the first error if any.
 func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		return err
 	}
@@ -781,7 +841,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
 // ForEachSlave concurrently calls the fn on each slave node in the cluster.
 // It returns the first error if any.
 func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		return err
 	}
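These call sites lean on `internal.IsMovedError` to classify redirection errors; that helper is not part of this diff. As a rough sketch of what such a check does, based only on the documented Redis Cluster redirection replies `MOVED <slot> <host:port>` and `ASK <slot> <host:port>` (the `isMovedError` function below is hypothetical, not the go-redis implementation):

```go
package main

import (
	"fmt"
	"strings"
)

// isMovedError reports whether err is a Redis Cluster redirection error
// and, if so, extracts the address of the node to redirect to.
func isMovedError(err error) (moved, ask bool, addr string) {
	if err == nil {
		return
	}
	s := err.Error()
	switch {
	case strings.HasPrefix(s, "MOVED "):
		moved = true
	case strings.HasPrefix(s, "ASK "):
		ask = true
	default:
		return
	}
	// The target address is the last space-separated field.
	if i := strings.LastIndexByte(s, ' '); i >= 0 {
		addr = s[i+1:]
	}
	return
}

func main() {
	moved, ask, addr := isMovedError(fmt.Errorf("MOVED 3999 127.0.0.1:6381"))
	fmt.Println(moved, ask, addr) // true false 127.0.0.1:6381
}
```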
@@ -814,7 +874,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
 // ForEachNode concurrently calls the fn on each known node in the cluster.
 // It returns the first error if any.
 func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		return err
 	}
@@ -854,7 +914,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
 func (c *ClusterClient) PoolStats() *PoolStats {
 	var acc PoolStats

-	state, _ := c.state()
+	state, _ := c.state.Get()
 	if state == nil {
 		return &acc
 	}
@@ -884,75 +944,34 @@ func (c *ClusterClient) PoolStats() *PoolStats {
 	return &acc
 }

-func (c *ClusterClient) lazyReloadState() {
-	if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) {
-		return
-	}
-	go func() {
-		if c.reloadState() {
-			time.Sleep(time.Second)
-		}
-		atomic.StoreUint32(&c.reloading, 0)
-	}()
-}
-
-func (c *ClusterClient) reloadState() bool {
-	for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
-		if attempt > 0 {
-			time.Sleep(c.retryBackoff(attempt))
-		}
-
-		state, err := c.loadState()
-		if err == nil {
-			c._state.Store(state)
-			time.AfterFunc(time.Minute, func() {
-				c.nodes.GC(state.generation)
-			})
-			return true
-		}
-
-		c.setStateErr(err)
-		switch err {
-		case pool.ErrClosed, errClusterNoNodes:
-			return false
-		}
-	}
-	return false
-}
-
 func (c *ClusterClient) loadState() (*clusterState, error) {
-	node, err := c.nodes.Random()
+	addrs, err := c.nodes.Addrs()
 	if err != nil {
 		return nil, err
 	}

-	slots, err := node.Client.ClusterSlots().Result()
-	if err != nil {
-		return nil, err
+	var firstErr error
+	for _, addr := range addrs {
+		node, err := c.nodes.GetOrCreate(addr)
+		if err != nil {
+			if firstErr == nil {
+				firstErr = err
+			}
+			continue
+		}
+
+		slots, err := node.Client.ClusterSlots().Result()
+		if err != nil {
+			if firstErr == nil {
+				firstErr = err
+			}
+			continue
+		}
+
+		return newClusterState(c.nodes, slots, node.Client.opt.Addr)
 	}

-	return newClusterState(c.nodes, slots, node.Client.opt.Addr)
-}
-
-func (c *ClusterClient) state() (*clusterState, error) {
-	v := c._state.Load()
-	if v != nil {
-		return v.(*clusterState), nil
-	}
-	return nil, c.getStateErr()
-}
-
-func (c *ClusterClient) setStateErr(err error) {
-	c.stateErrMu.Lock()
-	c.stateErr = err
-	c.stateErrMu.Unlock()
-}
-
-func (c *ClusterClient) getStateErr() error {
-	c.stateErrMu.RLock()
-	err := c.stateErr
-	c.stateErrMu.RUnlock()
-	return err
+	return nil, firstErr
 }

 // reaper closes idle connections to the cluster.
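`loadState` previously asked a single random node for `CLUSTER SLOTS`; it now walks every known address in order, returns on the first node that answers, and otherwise surfaces the first error it saw rather than the last. A stripped-down sketch of that control flow, with the hypothetical names `tryEach` and `fetch` standing in for the node lookup and the `CLUSTER SLOTS` call:

```go
package main

import (
	"errors"
	"fmt"
)

// tryEach mirrors the rewritten loadState: try every candidate in order,
// return the first success, and otherwise report the first error seen.
func tryEach(addrs []string, fetch func(addr string) (string, error)) (string, error) {
	var firstErr error
	for _, addr := range addrs {
		res, err := fetch(addr)
		if err != nil {
			if firstErr == nil {
				firstErr = err
			}
			continue
		}
		return res, nil
	}
	return "", firstErr
}

func main() {
	// Simulate a cluster where only the third node is reachable.
	fetch := func(addr string) (string, error) {
		if addr != ":7002" {
			return "", errors.New("connection refused: " + addr)
		}
		return "slots from " + addr, nil
	}
	res, err := tryEach([]string{":7000", ":7001", ":7002"}, fetch)
	fmt.Println(res, err) // slots from :7002 <nil>
}
```

Keeping the first error rather than the last preserves the most informative failure: later addresses often fail for the same downstream reason once the first one does.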
@@ -1036,7 +1055,7 @@ func (c *ClusterClient) defaultProcessPipeline(cmds []Cmder) error {
 }

 func (c *ClusterClient) mapCmdsByNode(cmds []Cmder) (map[*clusterNode][]Cmder, error) {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		setCmdsErr(cmds, err)
 		return nil, err
 	}
@@ -1112,7 +1131,7 @@ func (c *ClusterClient) checkMovedErr(
 	moved, ask, addr := internal.IsMovedError(err)

 	if moved {
-		c.lazyReloadState()
+		c.state.LazyReload()

 		node, err := c.nodes.GetOrCreate(addr)
 		if err != nil {
@@ -1150,7 +1169,7 @@ func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
 }

 func (c *ClusterClient) defaultProcessTxPipeline(cmds []Cmder) error {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		return err
 	}
diff --git a/export_test.go b/export_test.go
index bcc18c45..288e86f0 100644
--- a/export_test.go
+++ b/export_test.go
@@ -20,7 +20,7 @@ func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error)
 }

 func (c *ClusterClient) SlotAddrs(slot int) []string {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		panic(err)
 	}
@@ -34,7 +34,7 @@ func (c *ClusterClient) SlotAddrs(slot int) []string {

 // SwapSlot swaps a slot's master/slave address for testing MOVED redirects.
 func (c *ClusterClient) SwapSlotNodes(slot int) {
-	state, err := c.state()
+	state, err := c.state.Get()
 	if err != nil {
 		panic(err)
 	}
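All of the changes above touch unexported internals (plus the test-only helpers in export_test.go), so the public API is unchanged. Assuming the go-redis v6 import path this diff appears to belong to, and a cluster listening on the addresses shown (both are assumptions for illustration), existing caller code keeps working as-is:

```go
package main

import (
	"fmt"

	"github.com/go-redis/redis"
)

func main() {
	// The ClusterClient constructor and commands are untouched by the
	// refactoring; state loading now happens via the internal holder.
	client := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs: []string{":7000", ":7001", ":7002"},
	})
	defer client.Close()

	pong, err := client.Ping().Result()
	fmt.Println(pong, err)
}
```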