diff --git a/internal_test.go b/internal_test.go index 9f53ae6..4d42f4b 100644 --- a/internal_test.go +++ b/internal_test.go @@ -3,6 +3,9 @@ package redis import ( "context" "fmt" + "reflect" + "sync" + "sync/atomic" "testing" "time" @@ -107,6 +110,7 @@ func TestRingSetAddrsAndRebalanceRace(t *testing.T) { } }, }) + defer ring.Close() // Continuously update addresses by adding and removing one address updatesDone := make(chan struct{}) @@ -156,6 +160,7 @@ func BenchmarkRingShardingRebalanceLocked(b *testing.B) { } ring := NewRing(opts) + defer ring.Close() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -163,6 +168,119 @@ func BenchmarkRingShardingRebalanceLocked(b *testing.B) { } } +type testCounter struct { + mu sync.Mutex + t *testing.T + m map[string]int +} + +func newTestCounter(t *testing.T) *testCounter { + return &testCounter{t: t, m: make(map[string]int)} +} + +func (ct *testCounter) increment(key string) { + ct.mu.Lock() + defer ct.mu.Unlock() + ct.m[key]++ +} + +func (ct *testCounter) expect(values map[string]int) { + ct.mu.Lock() + defer ct.mu.Unlock() + ct.t.Helper() + if !reflect.DeepEqual(values, ct.m) { + ct.t.Errorf("expected %v != actual %v", values, ct.m) + } +} + +func TestRingShardsCleanup(t *testing.T) { + const ( + ringShard1Name = "ringShardOne" + ringShard2Name = "ringShardTwo" + + ringShard1Addr = "shard1.test" + ringShard2Addr = "shard2.test" + ) + + t.Run("closes unused shards", func(t *testing.T) { + closeCounter := newTestCounter(t) + + ring := NewRing(&RingOptions{ + Addrs: map[string]string{ + ringShard1Name: ringShard1Addr, + ringShard2Name: ringShard2Addr, + }, + NewClient: func(opt *Options) *Client { + c := NewClient(opt) + c.baseClient.onClose = func() error { + closeCounter.increment(opt.Addr) + return nil + } + return c + }, + }) + closeCounter.expect(map[string]int{}) + + // no change due to the same addresses + ring.SetAddrs(map[string]string{ + ringShard1Name: ringShard1Addr, + ringShard2Name: ringShard2Addr, + }) + closeCounter.expect(map[string]int{}) + + ring.SetAddrs(map[string]string{ + ringShard1Name: ringShard1Addr, + }) + closeCounter.expect(map[string]int{ringShard2Addr: 1}) + + ring.SetAddrs(map[string]string{ + ringShard2Name: ringShard2Addr, + }) + closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1}) + + ring.Close() + closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 2}) + }) + + t.Run("closes created shards if ring was closed", func(t *testing.T) { + createCounter := newTestCounter(t) + closeCounter := newTestCounter(t) + + var ( + ring *Ring + shouldClose int32 + ) + + ring = NewRing(&RingOptions{ + Addrs: map[string]string{ + ringShard1Name: ringShard1Addr, + }, + NewClient: func(opt *Options) *Client { + if atomic.LoadInt32(&shouldClose) != 0 { + ring.Close() + } + createCounter.increment(opt.Addr) + c := NewClient(opt) + c.baseClient.onClose = func() error { + closeCounter.increment(opt.Addr) + return nil + } + return c + }, + }) + createCounter.expect(map[string]int{ringShard1Addr: 1}) + closeCounter.expect(map[string]int{}) + + atomic.StoreInt32(&shouldClose, 1) + + ring.SetAddrs(map[string]string{ + ringShard2Name: ringShard2Addr, + }) + createCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1}) + closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1}) + }) +} + //------------------------------------------------------------------------------ type timeoutErr struct { diff --git a/ring_test.go b/ring_test.go index c76a783..ba4fa8a 100644 --- a/ring_test.go +++ b/ring_test.go @@ -772,6 +772,7 @@ func TestRingSetAddrsContention(t *testing.T) { return redis.NewClient(opt) }, }) + defer ring.Close() if _, err := ring.Ping(context.Background()).Result(); err != nil { t.Fatal(err)