From 425f2fc69bcc46098688039179feb6b3a1bfe71f Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 29 Sep 2021 14:09:51 +0300 Subject: [PATCH] fix: use slot node id to detect node re-configuration --- cluster.go | 60 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/cluster.go b/cluster.go index 8d93b36..d7fe319 100644 --- a/cluster.go +++ b/cluster.go @@ -170,6 +170,7 @@ func (opt *ClusterOptions) clientOptions() *Options { //------------------------------------------------------------------------------ type clusterNode struct { + id string Client *Client latency uint32 // atomic @@ -177,10 +178,11 @@ type clusterNode struct { failing uint32 // atomic } -func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode { +func newClusterNode(clOpt *ClusterOptions, id, addr string) *clusterNode { opt := clOpt.clientOptions() opt.Addr = addr node := clusterNode{ + id: id, Client: clOpt.NewClient(opt), } @@ -352,33 +354,51 @@ func (c *clusterNodes) GC(generation uint32) { } } -func (c *clusterNodes) Get(addr string) (*clusterNode, error) { +func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { + return c.GetOrCreateWithID(addr, "") +} + +func (c *clusterNodes) GetOrCreateWithID(addr, id string) (*clusterNode, error) { node, err := c.get(addr) if err != nil { return nil, err } - if node != nil { + if node != nil && (id == "" || node.id == id) { return node, nil } c.mu.Lock() - defer c.mu.Unlock() + node, oldNode, err := c.getOrCreate(addr, id) + c.mu.Unlock() + if err != nil { + return nil, err + } + if oldNode != nil { + _ = oldNode.Client.Close() + } + return node, nil +} + +func (c *clusterNodes) getOrCreate(addr, id string) (node, oldNode *clusterNode, _ error) { if c.closed { - return nil, pool.ErrClosed + return nil, nil, pool.ErrClosed } - node, ok := c.nodes[addr] + oldNode, ok := c.nodes[addr] if ok { - return node, nil + // The id is changed when node is re-configured, for example, IP addr is changed. + if id == "" || oldNode.id == id { + return oldNode, nil, nil + } + } else { + c.addrs = appendIfNotExists(c.addrs, addr) } - node = newClusterNode(c.opt, addr) - - c.addrs = appendIfNotExists(c.addrs, addr) + node = newClusterNode(c.opt, id, addr) c.nodes[addr] = node - return node, nil + return node, oldNode, nil } func (c *clusterNodes) get(addr string) (*clusterNode, error) { @@ -416,7 +436,7 @@ func (c *clusterNodes) Random() (*clusterNode, error) { } n := rand.Intn(len(addrs)) - return c.Get(addrs[n]) + return c.GetOrCreate(addrs[n]) } //------------------------------------------------------------------------------ @@ -474,7 +494,7 @@ func newClusterState( addr = replaceLoopbackHost(addr, originHost) } - node, err := c.nodes.Get(addr) + node, err := c.nodes.GetOrCreateWithID(addr, slotNode.ID) if err != nil { return nil, err } @@ -824,8 +844,10 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { var addr string moved, ask, addr = isMovedError(lastErr) if moved || ask { + c.state.LazyReload() + var err error - node, err = c.nodes.Get(addr) + node, err = c.nodes.GetOrCreate(addr) if err != nil { return err } @@ -1022,7 +1044,7 @@ func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) { for _, idx := range rand.Perm(len(addrs)) { addr := addrs[idx] - node, err := c.nodes.Get(addr) + node, err := c.nodes.GetOrCreate(addr) if err != nil { if firstErr == nil { firstErr = err @@ -1236,7 +1258,7 @@ func (c *ClusterClient) checkMovedErr( return false } - node, err := c.nodes.Get(addr) + node, err := c.nodes.GetOrCreate(addr) if err != nil { return false } @@ -1422,7 +1444,7 @@ func (c *ClusterClient) cmdsMoved( addr string, failedCmds *cmdsMap, ) error { - node, err := c.nodes.Get(addr) + node, err := c.nodes.GetOrCreate(addr) if err != nil { return err } @@ -1477,7 +1499,7 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s moved, ask, addr := isMovedError(err) if moved || ask { - node, err = c.nodes.Get(addr) + node, err = c.nodes.GetOrCreate(addr) if err != nil { return err } @@ -1589,7 +1611,7 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, for _, idx := range perm { addr := addrs[idx] - node, err := c.nodes.Get(addr) + node, err := c.nodes.GetOrCreate(addr) if err != nil { if firstErr == nil { firstErr = err