diff --git a/ring_test.go b/ring_test.go index c64e107..eb03c32 100644 --- a/ring_test.go +++ b/ring_test.go @@ -7,6 +7,7 @@ import ( "net" "strconv" "sync" + "testing" "time" . "github.com/onsi/ginkgo" @@ -739,3 +740,88 @@ var _ = Describe("Ring Tx timeout", func() { testTimeout() }) }) + +func TestRingSetAddrsContention(t *testing.T) { + const ( + ringShard1Name = "ringShardOne" + ringShard2Name = "ringShardTwo" + ) + + for _, port := range []string{ringShard1Port, ringShard2Port} { + if _, err := startRedis(port); err != nil { + t.Fatal(err) + } + } + + t.Cleanup(func() { + for _, p := range processes { + if err := p.Close(); err != nil { + t.Errorf("Failed to stop redis process: %v", err) + } + } + processes = nil + }) + + ring := redis.NewRing(&redis.RingOptions{ + Addrs: map[string]string{ + "ringShardOne": ":" + ringShard1Port, + }, + NewClient: func(opt *redis.Options) *redis.Client { + // Simulate slow shard creation + time.Sleep(100 * time.Millisecond) + return redis.NewClient(opt) + }, + }) + + if _, err := ring.Ping(context.Background()).Result(); err != nil { + t.Fatal(err) + } + + // Continuously update addresses by adding and removing one address + updatesDone := make(chan struct{}) + defer func() { close(updatesDone) }() + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for i := 0; ; i++ { + select { + case <-ticker.C: + if i%2 == 0 { + ring.SetAddrs(map[string]string{ + ringShard1Name: ":" + ringShard1Port, + }) + } else { + ring.SetAddrs(map[string]string{ + ringShard1Name: ":" + ringShard1Port, + ringShard2Name: ":" + ringShard2Port, + }) + } + case <-updatesDone: + return + } + } + }() + + var pings, errClosed int + timer := time.NewTimer(1 * time.Second) + for running := true; running; pings++ { + select { + case <-timer.C: + running = false + default: + if _, err := ring.Ping(context.Background()).Result(); err != nil { + if err == redis.ErrClosed { + // The shard client could be closed while ping command is in progress + errClosed++ + } else { + t.Fatal(err) + } + } + } + } + + t.Logf("Number of pings: %d, errClosed: %d", pings, errClosed) + if pings < 10_000 { + t.Errorf("Expected at least 10k pings, got: %d", pings) + } +}