diff --git a/internal_test.go b/internal_test.go
index 9f53ae6..4d42f4b 100644
--- a/internal_test.go
+++ b/internal_test.go
@@ -3,6 +3,9 @@ package redis
 import (
 	"context"
 	"fmt"
+	"reflect"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -107,6 +110,7 @@ func TestRingSetAddrsAndRebalanceRace(t *testing.T) {
 			}
 		},
 	})
+	defer ring.Close()
 
 	// Continuously update addresses by adding and removing one address
 	updatesDone := make(chan struct{})
@@ -156,6 +160,7 @@ func BenchmarkRingShardingRebalanceLocked(b *testing.B) {
 	}
 
 	ring := NewRing(opts)
+	defer ring.Close()
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
@@ -163,6 +168,119 @@ func BenchmarkRingShardingRebalanceLocked(b *testing.B) {
 	}
 }
 
+type testCounter struct {
+	mu sync.Mutex
+	t  *testing.T
+	m  map[string]int
+}
+
+func newTestCounter(t *testing.T) *testCounter {
+	return &testCounter{t: t, m: make(map[string]int)}
+}
+
+func (ct *testCounter) increment(key string) {
+	ct.mu.Lock()
+	defer ct.mu.Unlock()
+	ct.m[key]++
+}
+
+func (ct *testCounter) expect(values map[string]int) {
+	ct.mu.Lock()
+	defer ct.mu.Unlock()
+	ct.t.Helper()
+	if !reflect.DeepEqual(values, ct.m) {
+		ct.t.Errorf("expected %v != actual %v", values, ct.m)
+	}
+}
+
+func TestRingShardsCleanup(t *testing.T) {
+	const (
+		ringShard1Name = "ringShardOne"
+		ringShard2Name = "ringShardTwo"
+
+		ringShard1Addr = "shard1.test"
+		ringShard2Addr = "shard2.test"
+	)
+
+	t.Run("closes unused shards", func(t *testing.T) {
+		closeCounter := newTestCounter(t)
+
+		ring := NewRing(&RingOptions{
+			Addrs: map[string]string{
+				ringShard1Name: ringShard1Addr,
+				ringShard2Name: ringShard2Addr,
+			},
+			NewClient: func(opt *Options) *Client {
+				c := NewClient(opt)
+				c.baseClient.onClose = func() error {
+					closeCounter.increment(opt.Addr)
+					return nil
+				}
+				return c
+			},
+		})
+		closeCounter.expect(map[string]int{})
+
+		// no change due to the same addresses
+		ring.SetAddrs(map[string]string{
+			ringShard1Name: ringShard1Addr,
+			ringShard2Name: ringShard2Addr,
+		})
+		closeCounter.expect(map[string]int{})
+
+		ring.SetAddrs(map[string]string{
+			ringShard1Name: ringShard1Addr,
+		})
+		closeCounter.expect(map[string]int{ringShard2Addr: 1})
+
+		ring.SetAddrs(map[string]string{
+			ringShard2Name: ringShard2Addr,
+		})
+		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
+
+		ring.Close()
+		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 2})
+	})
+
+	t.Run("closes created shards if ring was closed", func(t *testing.T) {
+		createCounter := newTestCounter(t)
+		closeCounter := newTestCounter(t)
+
+		var (
+			ring        *Ring
+			shouldClose int32
+		)
+
+		ring = NewRing(&RingOptions{
+			Addrs: map[string]string{
+				ringShard1Name: ringShard1Addr,
+			},
+			NewClient: func(opt *Options) *Client {
+				if atomic.LoadInt32(&shouldClose) != 0 {
+					ring.Close()
+				}
+				createCounter.increment(opt.Addr)
+				c := NewClient(opt)
+				c.baseClient.onClose = func() error {
+					closeCounter.increment(opt.Addr)
+					return nil
+				}
+				return c
+			},
+		})
+		createCounter.expect(map[string]int{ringShard1Addr: 1})
+		closeCounter.expect(map[string]int{})
+
+		atomic.StoreInt32(&shouldClose, 1)
+
+		ring.SetAddrs(map[string]string{
+			ringShard2Name: ringShard2Addr,
+		})
+		createCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
+		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
+	})
+}
+
 //------------------------------------------------------------------------------
 
 type timeoutErr struct {
diff --git a/ring.go b/ring.go
index de1b0a3..0a1069d 100644
--- a/ring.go
+++ b/ring.go
@@ -219,6 +219,10 @@ type ringSharding struct {
 	hash      ConsistentHash
 	numShard  int
 	onNewNode []func(rdb *Client)
+
+	// ensures exclusive access to SetAddrs so there is no need
+	// to hold mu for the duration of potentially long shard creation
+	setAddrsMu sync.Mutex
 }
 
 type ringShards struct {
@@ -245,46 +249,62 @@ func (c *ringSharding) OnNewNode(fn func(rdb *Client)) {
 // decrease number of shards, that you use. It will reuse shards that
 // existed before and close the ones that will not be used anymore.
 func (c *ringSharding) SetAddrs(addrs map[string]string) {
-	c.mu.Lock()
+	c.setAddrsMu.Lock()
+	defer c.setAddrsMu.Unlock()
 
+	cleanup := func(shards map[string]*ringShard) {
+		for addr, shard := range shards {
+			if err := shard.Client.Close(); err != nil {
+				internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err)
+			}
+		}
+	}
+
+	c.mu.RLock()
 	if c.closed {
+		c.mu.RUnlock()
+		return
+	}
+	existing := c.shards
+	c.mu.RUnlock()
+
+	shards, created, unused := c.newRingShards(addrs, existing)
+
+	c.mu.Lock()
+	if c.closed {
+		cleanup(created)
 		c.mu.Unlock()
 		return
 	}
-
-	shards, cleanup := c.newRingShards(addrs, c.shards)
 	c.shards = shards
 	c.rebalanceLocked()
 	c.mu.Unlock()
 
-	cleanup()
+	cleanup(unused)
 }
 
 func (c *ringSharding) newRingShards(
-	addrs map[string]string, existingShards *ringShards,
-) (*ringShards, func()) {
-	shardMap := make(map[string]*ringShard)     // indexed by addr
-	unusedShards := make(map[string]*ringShard) // indexed by addr
+	addrs map[string]string, existing *ringShards,
+) (shards *ringShards, created, unused map[string]*ringShard) {
 
-	if existingShards != nil {
-		for _, shard := range existingShards.list {
-			addr := shard.Client.opt.Addr
-			shardMap[addr] = shard
-			unusedShards[addr] = shard
+	shards = &ringShards{m: make(map[string]*ringShard, len(addrs))}
+	created = make(map[string]*ringShard) // indexed by addr
+	unused = make(map[string]*ringShard)  // indexed by addr
+
+	if existing != nil {
+		for _, shard := range existing.list {
+			unused[shard.addr] = shard
 		}
 	}
 
-	shards := &ringShards{
-		m: make(map[string]*ringShard),
-	}
-
 	for name, addr := range addrs {
-		if shard, ok := shardMap[addr]; ok {
+		if shard, ok := unused[addr]; ok {
 			shards.m[name] = shard
-			delete(unusedShards, addr)
+			delete(unused, addr)
 		} else {
 			shard := newRingShard(c.opt, addr)
 			shards.m[name] = shard
+			created[addr] = shard
 
 			for _, fn := range c.onNewNode {
 				fn(shard.Client)
@@ -296,13 +316,7 @@ func (c *ringSharding) newRingShards(
 		shards.list = append(shards.list, shard)
 	}
 
-	return shards, func() {
-		for addr, shard := range unusedShards {
-			if err := shard.Client.Close(); err != nil {
-				internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err)
-			}
-		}
-	}
+	return
 }
 
 func (c *ringSharding) List() []*ringShard {
diff --git a/ring_test.go b/ring_test.go
index c64e107..ba4fa8a 100644
--- a/ring_test.go
+++ b/ring_test.go
@@ -7,6 +7,7 @@ import (
 	"net"
 	"strconv"
 	"sync"
+	"testing"
 	"time"
 
 	. "github.com/onsi/ginkgo"
@@ -123,7 +124,7 @@ var _ = Describe("Redis Ring", func() {
 		})
 		Expect(ring.Len(), 1)
 		gotShard := ring.ShardByName("ringShardOne")
-		Expect(gotShard).To(Equal(wantShard))
+		Expect(gotShard).To(BeIdenticalTo(wantShard))
 
 		ring.SetAddrs(map[string]string{
 			"ringShardOne": ":" + ringShard1Port,
@@ -131,7 +132,7 @@ var _ = Describe("Redis Ring", func() {
 		})
 		Expect(ring.Len(), 2)
 		gotShard = ring.ShardByName("ringShardOne")
-		Expect(gotShard).To(Equal(wantShard))
+		Expect(gotShard).To(BeIdenticalTo(wantShard))
 	})
 
 	It("uses 3 shards after setting it to 3 shards", func() {
@@ -155,8 +156,8 @@ var _ = Describe("Redis Ring", func() {
 		gotShard1 := ring.ShardByName(shardName1)
 		gotShard2 := ring.ShardByName(shardName2)
 		gotShard3 := ring.ShardByName(shardName3)
-		Expect(gotShard1).To(Equal(wantShard1))
-		Expect(gotShard2).To(Equal(wantShard2))
+		Expect(gotShard1).To(BeIdenticalTo(wantShard1))
+		Expect(gotShard2).To(BeIdenticalTo(wantShard2))
 		Expect(gotShard3).ToNot(BeNil())
 
 		ring.SetAddrs(map[string]string{
@@ -167,8 +168,8 @@ var _ = Describe("Redis Ring", func() {
 		gotShard1 = ring.ShardByName(shardName1)
 		gotShard2 = ring.ShardByName(shardName2)
 		gotShard3 = ring.ShardByName(shardName3)
-		Expect(gotShard1).To(Equal(wantShard1))
-		Expect(gotShard2).To(Equal(wantShard2))
+		Expect(gotShard1).To(BeIdenticalTo(wantShard1))
+		Expect(gotShard2).To(BeIdenticalTo(wantShard2))
 		Expect(gotShard3).To(BeNil())
 	})
 })
@@ -739,3 +740,89 @@ var _ = Describe("Ring Tx timeout", func() {
 		testTimeout()
 	})
 })
+
+func TestRingSetAddrsContention(t *testing.T) {
+	const (
+		ringShard1Name = "ringShardOne"
+		ringShard2Name = "ringShardTwo"
+	)
+
+	for _, port := range []string{ringShard1Port, ringShard2Port} {
+		if _, err := startRedis(port); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	t.Cleanup(func() {
+		for _, p := range processes {
+			if err := p.Close(); err != nil {
+				t.Errorf("Failed to stop redis process: %v", err)
+			}
+		}
+		processes = nil
+	})
+
+	ring := redis.NewRing(&redis.RingOptions{
+		Addrs: map[string]string{
+			"ringShardOne": ":" + ringShard1Port,
+		},
+		NewClient: func(opt *redis.Options) *redis.Client {
+			// Simulate slow shard creation
+			time.Sleep(100 * time.Millisecond)
+			return redis.NewClient(opt)
+		},
+	})
+	defer ring.Close()
+
+	if _, err := ring.Ping(context.Background()).Result(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Continuously update addresses by adding and removing one address
+	updatesDone := make(chan struct{})
+	defer func() { close(updatesDone) }()
+	go func() {
+		ticker := time.NewTicker(10 * time.Millisecond)
+		defer ticker.Stop()
+		for i := 0; ; i++ {
+			select {
+			case <-ticker.C:
+				if i%2 == 0 {
+					ring.SetAddrs(map[string]string{
+						ringShard1Name: ":" + ringShard1Port,
+					})
+				} else {
+					ring.SetAddrs(map[string]string{
+						ringShard1Name: ":" + ringShard1Port,
+						ringShard2Name: ":" + ringShard2Port,
+					})
+				}
+			case <-updatesDone:
+				return
+			}
+		}
+	}()
+
+	var pings, errClosed int
+	timer := time.NewTimer(1 * time.Second)
+	for running := true; running; pings++ {
+		select {
+		case <-timer.C:
+			running = false
+		default:
+			if _, err := ring.Ping(context.Background()).Result(); err != nil {
+				if err == redis.ErrClosed {
+					// The shard client could be closed while ping command is in progress
+					errClosed++
+				} else {
+					t.Fatal(err)
+				}
+			}
+		}
+	}
+
+	t.Logf("Number of pings: %d, errClosed: %d", pings, errClosed)
+	if pings < 10_000 {
+		t.Errorf("Expected at least 10k pings, got: %d", pings)
+	}
+}