diff --git a/export_test.go b/export_test.go index f259fca0..0d032351 100644 --- a/export_test.go +++ b/export_test.go @@ -95,9 +95,6 @@ func GetSlavesAddrByName(ctx context.Context, c *SentinelClient, name string) [] } func (c *Ring) ShardByName(name string) *ringShard { - return c.sharding.ShardByName(name) -} - -func (c *ringSharding) ShardByName(name string) *ringShard { - return c.shards.m[name] + shard, _ := c.sharding.GetByName(name) + return shard } diff --git a/internal_test.go b/internal_test.go index b1dd0bdd..fcf1235b 100644 --- a/internal_test.go +++ b/internal_test.go @@ -1,6 +1,10 @@ package redis import ( + "fmt" + "testing" + "time" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -65,3 +69,92 @@ var _ = Describe("newClusterState", func() { }) }) }) + +type fixedHash string + +func (h fixedHash) Get(string) string { + return string(h) +} + +func TestRingSetAddrsAndRebalanceRace(t *testing.T) { + const ( + ringShard1Name = "ringShardOne" + ringShard2Name = "ringShardTwo" + + ringShard1Port = "6390" + ringShard2Port = "6391" + ) + + ring := NewRing(&RingOptions{ + Addrs: map[string]string{ + ringShard1Name: ":" + ringShard1Port, + }, + // Disable heartbeat + HeartbeatFrequency: 1 * time.Hour, + NewConsistentHash: func(shards []string) ConsistentHash { + switch len(shards) { + case 1: + return fixedHash(ringShard1Name) + case 2: + return fixedHash(ringShard2Name) + default: + t.Fatalf("Unexpected number of shards: %v", shards) + return nil + } + }, + }) + + // Continuously update addresses by adding and removing one address + updatesDone := make(chan struct{}) + defer func() { close(updatesDone) }() + go func() { + for i := 0; ; i++ { + select { + case <-updatesDone: + return + default: + if i%2 == 0 { + ring.SetAddrs(map[string]string{ + ringShard1Name: ":" + ringShard1Port, + }) + } else { + ring.SetAddrs(map[string]string{ + ringShard1Name: ":" + ringShard1Port, + ringShard2Name: ":" + ringShard2Port, + }) + } + } + } + }() + + timer := time.NewTimer(1 * time.Second) + for running := true; running; { + select { + case <-timer.C: + running = false + default: + shard, err := ring.sharding.GetByKey("whatever") + if err == nil && shard == nil { + t.Fatal("shard is nil") + } + } + } +} + +func BenchmarkRingShardingRebalanceLocked(b *testing.B) { + opts := &RingOptions{ + Addrs: make(map[string]string), + // Disable heartbeat + HeartbeatFrequency: 1 * time.Hour, + } + for i := 0; i < 100; i++ { + opts.Addrs[fmt.Sprintf("shard%d", i)] = fmt.Sprintf(":63%02d", i) + } + + ring := NewRing(opts) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ring.sharding.rebalanceLocked() + } +} diff --git a/ring.go b/ring.go index fee2fe89..de1b0a32 100644 --- a/ring.go +++ b/ring.go @@ -254,9 +254,9 @@ func (c *ringSharding) SetAddrs(addrs map[string]string) { shards, cleanup := c.newRingShards(addrs, c.shards) c.shards = shards + c.rebalanceLocked() c.mu.Unlock() - c.rebalance() cleanup() } @@ -388,7 +388,9 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { } if rebalance { - c.rebalance() + c.mu.Lock() + c.rebalanceLocked() + c.mu.Unlock() } case <-ctx.Done(): return @@ -396,32 +398,26 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { } } -// rebalance removes dead shards from the Ring. -func (c *ringSharding) rebalance() { - c.mu.RLock() - shards := c.shards - c.mu.RUnlock() - - if shards == nil { +// rebalanceLocked removes dead shards from the Ring. +// Requires c.mu locked. +func (c *ringSharding) rebalanceLocked() { + if c.closed { + return + } + if c.shards == nil { return } - liveShards := make([]string, 0, len(shards.m)) + liveShards := make([]string, 0, len(c.shards.m)) - for name, shard := range shards.m { + for name, shard := range c.shards.m { if shard.IsUp() { liveShards = append(liveShards, name) } } - hash := c.opt.NewConsistentHash(liveShards) - - c.mu.Lock() - if !c.closed { - c.hash = hash - c.numShard = len(liveShards) - } - c.mu.Unlock() + c.hash = c.opt.NewConsistentHash(liveShards) + c.numShard = len(liveShards) } func (c *ringSharding) Len() int {