diff --git a/cluster.go b/cluster.go index ea677ae2..783e27de 100644 --- a/cluster.go +++ b/cluster.go @@ -203,11 +203,11 @@ func (n *clusterNode) SetGeneration(gen uint32) { type clusterNodes struct { opt *ClusterOptions - mu sync.RWMutex - allAddrs []string - addrs []string - nodes map[string]*clusterNode - closed bool + mu sync.RWMutex + allAddrs []string + allNodes map[string]*clusterNode + clusterAddrs []string + closed bool nodeCreateGroup singleflight.Group @@ -219,7 +219,7 @@ func newClusterNodes(opt *ClusterOptions) *clusterNodes { opt: opt, allAddrs: opt.Addrs, - nodes: make(map[string]*clusterNode), + allNodes: make(map[string]*clusterNode), } } @@ -233,14 +233,14 @@ func (c *clusterNodes) Close() error { c.closed = true var firstErr error - for _, node := range c.nodes { + for _, node := range c.allNodes { if err := node.Client.Close(); err != nil && firstErr == nil { firstErr = err } } - c.addrs = nil - c.nodes = nil + c.allNodes = nil + c.clusterAddrs = nil return firstErr } @@ -250,8 +250,8 @@ func (c *clusterNodes) Addrs() ([]string, error) { c.mu.RLock() closed := c.closed if !closed { - if len(c.addrs) > 0 { - addrs = c.addrs + if len(c.clusterAddrs) > 0 { + addrs = c.clusterAddrs } else { addrs = c.allAddrs } @@ -276,25 +276,20 @@ func (c *clusterNodes) NextGeneration() uint32 { func (c *clusterNodes) GC(generation uint32) { var collected []*clusterNode c.mu.Lock() - for i := 0; i < len(c.addrs); { - addr := c.addrs[i] - node := c.nodes[addr] + for addr, node := range c.allNodes { if node.Generation() >= generation { - i++ continue } - c.addrs = append(c.addrs[:i], c.addrs[i+1:]...) - delete(c.nodes, addr) + c.clusterAddrs = remove(c.clusterAddrs, addr) + delete(c.allNodes, addr) collected = append(collected, node) } c.mu.Unlock() - time.AfterFunc(time.Minute, func() { - for _, node := range collected { - _ = node.Client.Close() - } - }) + for _, node := range collected { + _ = node.Client.Close() + } } func (c *clusterNodes) All() ([]*clusterNode, error) { @@ -305,23 +300,28 @@ func (c *clusterNodes) All() ([]*clusterNode, error) { return nil, pool.ErrClosed } - nodes := make([]*clusterNode, 0, len(c.nodes)) - for _, node := range c.nodes { - nodes = append(nodes, node) + cp := make([]*clusterNode, 0, len(c.allNodes)) + for _, node := range c.allNodes { + cp = append(cp, node) } - return nodes, nil + return cp, nil } func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { var node *clusterNode - var ok bool + var err error c.mu.RLock() - if !c.closed { - node, ok = c.nodes[addr] + if c.closed { + err = pool.ErrClosed + } else { + node = c.allNodes[addr] } c.mu.RUnlock() - if ok { + if err != nil { + return nil, err + } + if node != nil { return node, nil } @@ -329,9 +329,6 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { node := newClusterNode(c.opt, addr) return node, node.Test() }) - if err != nil { - return nil, err - } c.mu.Lock() defer c.mu.Unlock() @@ -340,18 +337,20 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { return nil, pool.ErrClosed } - node, ok = c.nodes[addr] + node, ok := c.allNodes[addr] if ok { _ = v.(*clusterNode).Close() - return node, nil + return node, err } node = v.(*clusterNode) c.allAddrs = appendIfNotExists(c.allAddrs, addr) - c.addrs = append(c.addrs, addr) - c.nodes[addr] = node + if err == nil { + c.clusterAddrs = append(c.clusterAddrs, addr) + } + c.allNodes[addr] = node - return node, nil + return node, err } func (c *clusterNodes) Random() (*clusterNode, error) { @@ -679,10 +678,7 @@ func (c *ClusterClient) WrapProcess( } func (c *ClusterClient) Process(cmd Cmder) error { - if c.process != nil { - return c.process(cmd) - } - return c.defaultProcess(cmd) + return c.process(cmd) } func (c *ClusterClient) defaultProcess(cmd Cmder) error { @@ -918,7 +914,9 @@ func (c *ClusterClient) reloadState() bool { state, err := c.loadState() if err == nil { c._state.Store(state) - c.nodes.GC(state.generation) + time.AfterFunc(time.Minute, func() { + c.nodes.GC(state.generation) + }) return true }