diff --git a/ring.go b/ring.go index de1b0a3..0a1069d 100644 --- a/ring.go +++ b/ring.go @@ -219,6 +219,10 @@ type ringSharding struct { hash ConsistentHash numShard int onNewNode []func(rdb *Client) + + // ensures exclusive access to SetAddrs so there is no need + // to hold mu for the duration of potentially long shard creation + setAddrsMu sync.Mutex } type ringShards struct { @@ -245,46 +249,62 @@ func (c *ringSharding) OnNewNode(fn func(rdb *Client)) { // decrease number of shards, that you use. It will reuse shards that // existed before and close the ones that will not be used anymore. func (c *ringSharding) SetAddrs(addrs map[string]string) { - c.mu.Lock() + c.setAddrsMu.Lock() + defer c.setAddrsMu.Unlock() + cleanup := func(shards map[string]*ringShard) { + for addr, shard := range shards { + if err := shard.Client.Close(); err != nil { + internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err) + } + } + } + + c.mu.RLock() if c.closed { + c.mu.RUnlock() + return + } + existing := c.shards + c.mu.RUnlock() + + shards, created, unused := c.newRingShards(addrs, existing) + + c.mu.Lock() + if c.closed { + cleanup(created) c.mu.Unlock() return } - - shards, cleanup := c.newRingShards(addrs, c.shards) c.shards = shards c.rebalanceLocked() c.mu.Unlock() - cleanup() + cleanup(unused) } func (c *ringSharding) newRingShards( - addrs map[string]string, existingShards *ringShards, -) (*ringShards, func()) { - shardMap := make(map[string]*ringShard) // indexed by addr - unusedShards := make(map[string]*ringShard) // indexed by addr + addrs map[string]string, existing *ringShards, +) (shards *ringShards, created, unused map[string]*ringShard) { - if existingShards != nil { - for _, shard := range existingShards.list { - addr := shard.Client.opt.Addr - shardMap[addr] = shard - unusedShards[addr] = shard + shards = &ringShards{m: make(map[string]*ringShard, len(addrs))} + created = make(map[string]*ringShard) // indexed by addr + unused = make(map[string]*ringShard) // indexed by addr + + if existing != nil { + for _, shard := range existing.list { + unused[shard.addr] = shard } } - shards := &ringShards{ - m: make(map[string]*ringShard), - } - for name, addr := range addrs { - if shard, ok := shardMap[addr]; ok { + if shard, ok := unused[addr]; ok { shards.m[name] = shard - delete(unusedShards, addr) + delete(unused, addr) } else { shard := newRingShard(c.opt, addr) shards.m[name] = shard + created[addr] = shard for _, fn := range c.onNewNode { fn(shard.Client) @@ -296,13 +316,7 @@ func (c *ringSharding) newRingShards( shards.list = append(shards.list, shard) } - return shards, func() { - for addr, shard := range unusedShards { - if err := shard.Client.Close(); err != nil { - internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err) - } - } - } + return } func (c *ringSharding) List() []*ringShard { diff --git a/ring_test.go b/ring_test.go index eb03c32..c76a783 100644 --- a/ring_test.go +++ b/ring_test.go @@ -124,7 +124,7 @@ var _ = Describe("Redis Ring", func() { }) Expect(ring.Len(), 1) gotShard := ring.ShardByName("ringShardOne") - Expect(gotShard).To(Equal(wantShard)) + Expect(gotShard).To(BeIdenticalTo(wantShard)) ring.SetAddrs(map[string]string{ "ringShardOne": ":" + ringShard1Port, @@ -132,7 +132,7 @@ var _ = Describe("Redis Ring", func() { }) Expect(ring.Len(), 2) gotShard = ring.ShardByName("ringShardOne") - Expect(gotShard).To(Equal(wantShard)) + Expect(gotShard).To(BeIdenticalTo(wantShard)) }) It("uses 3 shards after setting it to 3 shards", func() { @@ -156,8 +156,8 @@ var _ = Describe("Redis Ring", func() { gotShard1 := ring.ShardByName(shardName1) gotShard2 := ring.ShardByName(shardName2) gotShard3 := ring.ShardByName(shardName3) - Expect(gotShard1).To(Equal(wantShard1)) - Expect(gotShard2).To(Equal(wantShard2)) + Expect(gotShard1).To(BeIdenticalTo(wantShard1)) + Expect(gotShard2).To(BeIdenticalTo(wantShard2)) Expect(gotShard3).ToNot(BeNil()) ring.SetAddrs(map[string]string{ @@ -168,8 +168,8 @@ var _ = Describe("Redis Ring", func() { gotShard1 = ring.ShardByName(shardName1) gotShard2 = ring.ShardByName(shardName2) gotShard3 = ring.ShardByName(shardName3) - Expect(gotShard1).To(Equal(wantShard1)) - Expect(gotShard2).To(Equal(wantShard2)) + Expect(gotShard1).To(BeIdenticalTo(wantShard1)) + Expect(gotShard2).To(BeIdenticalTo(wantShard2)) Expect(gotShard3).To(BeNil()) }) })