From 59798f9dbae5081018b69c2ddf852b9da6d4833d Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Wed, 17 Aug 2022 13:18:58 +0300 Subject: [PATCH] chore: cleanup --- export_test.go | 8 +- ring.go | 224 ++++++++++++++++++++++++------------------------- ring_test.go | 34 ++++---- 3 files changed, 130 insertions(+), 136 deletions(-) diff --git a/export_test.go b/export_test.go index cae7faa..f259fca 100644 --- a/export_test.go +++ b/export_test.go @@ -94,10 +94,10 @@ func GetSlavesAddrByName(ctx context.Context, c *SentinelClient, name string) [] return parseReplicaAddrs(addrs, false) } -func (c *Ring) GetAddr(addr string) *ringShard { - return c.shards.GetAddr(addr) +func (c *Ring) ShardByName(name string) *ringShard { + return c.sharding.ShardByName(name) } -func (c *ringShards) GetAddr(addr string) *ringShard { - return c.shards[addr] +func (c *ringSharding) ShardByName(name string) *ringShard { + return c.shards.m[name] } diff --git a/ring.go b/ring.go index 65c7ce8..a8e08df 100644 --- a/ring.go +++ b/ring.go @@ -48,8 +48,8 @@ type RingOptions struct { // Map of name => host:port addresses of ring shards. Addrs map[string]string - // NewClient creates a shard client with provided name and options. - NewClient func(name string, opt *Options) *Client + // NewClient creates a shard client with provided options. + NewClient func(opt *Options) *Client // Frequency of PING commands sent to check shards availability. // Shard is considered down after 3 subsequent failed checks. @@ -95,7 +95,7 @@ type RingOptions struct { func (opt *RingOptions) init() { if opt.NewClient == nil { - opt.NewClient = func(name string, opt *Options) *Client { + opt.NewClient = func(opt *Options) *Client { return NewClient(opt) } } @@ -163,12 +163,12 @@ type ringShard struct { addr string } -func newRingShard(opt *RingOptions, name, addr string) *ringShard { +func newRingShard(opt *RingOptions, addr string) *ringShard { clopt := opt.clientOptions() clopt.Addr = addr return &ringShard{ - Client: opt.NewClient(name, clopt), + Client: opt.NewClient(clopt), addr: addr, } } @@ -210,20 +210,23 @@ func (shard *ringShard) Vote(up bool) bool { //------------------------------------------------------------------------------ -type ringShards struct { +type ringSharding struct { opt *RingOptions mu sync.RWMutex - muClose sync.Mutex - hash ConsistentHash - shards map[string]*ringShard // read only, updated by SetAddrs - list []*ringShard // read only, updated by SetAddrs - numShard int + shards *ringShards closed bool + hash ConsistentHash + numShard int } -func newRingShards(opt *RingOptions) *ringShards { - c := &ringShards{ +type ringShards struct { + m map[string]*ringShard + list []*ringShard +} + +func newRingSharding(opt *RingOptions) *ringSharding { + c := &ringSharding{ opt: opt, } c.SetAddrs(opt.Addrs) @@ -234,63 +237,75 @@ func newRingShards(opt *RingOptions) *ringShards { // SetAddrs replaces the shards in use, such that you can increase and // decrease number of shards, that you use. It will reuse shards that // existed before and close the ones that will not be used anymore. -func (c *ringShards) SetAddrs(addrs map[string]string) { - c.muClose.Lock() - defer c.muClose.Unlock() +func (c *ringSharding) SetAddrs(addrs map[string]string) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() return } - shards := make(map[string]*ringShard) - unusedShards := make(map[string]*ringShard) - - for k, shard := range c.shards { - if addr, ok := addrs[k]; ok && shard.addr == addr { - shards[k] = shard - } else { - unusedShards[k] = shard - } - } - - for k, addr := range addrs { - if shard, ok := c.shards[k]; !ok || shard.addr != addr { - shards[k] = newRingShard(c.opt, k, addr) - } - } - - list := make([]*ringShard, 0, len(shards)) - for _, shard := range shards { - list = append(list, shard) - } - - c.mu.Lock() + shards, cleanup := newRingShards(c.opt, addrs, c.shards) c.shards = shards - c.list = list - - c.rebalanceLocked() c.mu.Unlock() - for k, shard := range unusedShards { - err := shard.Client.Close() - if err != nil { - internal.Logger.Printf(context.Background(), "Failed to close ring shard client %s %s: %v", k, shard.addr, err) + c.rebalance() + cleanup() +} + +func newRingShards( + opt *RingOptions, addrs map[string]string, existingShards *ringShards, +) (*ringShards, func()) { + shardMap := make(map[string]*ringShard) // indexed by addr + unusedShards := make(map[string]*ringShard) // indexed by addr + + if existingShards != nil { + for _, shard := range existingShards.list { + addr := shard.Client.opt.Addr + shardMap[addr] = shard + unusedShards[addr] = shard + } + } + + shards := &ringShards{ + m: make(map[string]*ringShard), + } + + for name, addr := range addrs { + if shard, ok := shardMap[addr]; ok { + shards.m[name] = shard + delete(unusedShards, addr) + } else { + shards.m[name] = newRingShard(opt, addr) + } + } + + for _, shard := range shards.m { + shards.list = append(shards.list, shard) + } + + return shards, func() { + for addr, shard := range unusedShards { + if err := shard.Client.Close(); err != nil { + internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err) + } } } } -func (c *ringShards) List() []*ringShard { +func (c *ringSharding) List() []*ringShard { var list []*ringShard c.mu.RLock() if !c.closed { - list = c.list + list = c.shards.list } c.mu.RUnlock() return list } -func (c *ringShards) Hash(key string) string { +func (c *ringSharding) Hash(key string) string { key = hashtag.Key(key) var hash string @@ -305,7 +320,7 @@ func (c *ringShards) Hash(key string) string { return hash } -func (c *ringShards) GetByKey(key string) (*ringShard, error) { +func (c *ringSharding) GetByKey(key string) (*ringShard, error) { key = hashtag.Key(key) c.mu.RLock() @@ -319,15 +334,14 @@ func (c *ringShards) GetByKey(key string) (*ringShard, error) { return nil, errRingShardsDown } - hash := c.hash.Get(key) - if hash == "" { + shardName := c.hash.Get(key) + if shardName == "" { return nil, errRingShardsDown } - - return c.shards[hash], nil + return c.shards.m[shardName], nil } -func (c *ringShards) GetByName(shardName string) (*ringShard, error) { +func (c *ringSharding) GetByName(shardName string) (*ringShard, error) { if shardName == "" { return c.Random() } @@ -335,15 +349,15 @@ func (c *ringShards) GetByName(shardName string) (*ringShard, error) { c.mu.RLock() defer c.mu.RUnlock() - return c.shards[shardName], nil + return c.shards.m[shardName], nil } -func (c *ringShards) Random() (*ringShard, error) { +func (c *ringSharding) Random() (*ringShard, error) { return c.GetByKey(strconv.Itoa(rand.Int())) } // Heartbeat monitors state of each shard in the ring. -func (c *ringShards) Heartbeat(ctx context.Context, frequency time.Duration) { +func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { ticker := time.NewTicker(frequency) defer ticker.Stop() @@ -371,14 +385,18 @@ func (c *ringShards) Heartbeat(ctx context.Context, frequency time.Duration) { } // rebalance removes dead shards from the Ring. -func (c *ringShards) rebalance() { +func (c *ringSharding) rebalance() { c.mu.RLock() shards := c.shards c.mu.RUnlock() - liveShards := make([]string, 0, len(shards)) + if shards == nil { + return + } - for name, shard := range shards { + liveShards := make([]string, 0, len(shards.m)) + + for name, shard := range shards.m { if shard.IsUp() { liveShards = append(liveShards, name) } @@ -387,38 +405,21 @@ func (c *ringShards) rebalance() { hash := c.opt.NewConsistentHash(liveShards) c.mu.Lock() - c.hash = hash - c.numShard = len(liveShards) + if !c.closed { + c.hash = hash + c.numShard = len(liveShards) + } c.mu.Unlock() } -// rebalanceLocked removes dead shards from the Ring and callers need to hold the locl -func (c *ringShards) rebalanceLocked() { - shards := c.shards - liveShards := make([]string, 0, len(shards)) - - for name, shard := range shards { - if shard.IsUp() { - liveShards = append(liveShards, name) - } - } - - hash := c.opt.NewConsistentHash(liveShards) - - c.hash = hash - c.numShard = len(liveShards) -} - -func (c *ringShards) Len() int { +func (c *ringSharding) Len() int { c.mu.RLock() defer c.mu.RUnlock() return c.numShard } -func (c *ringShards) Close() error { - c.muClose.Lock() - defer c.muClose.Unlock() +func (c *ringSharding) Close() error { c.mu.Lock() defer c.mu.Unlock() @@ -428,7 +429,8 @@ func (c *ringShards) Close() error { c.closed = true var firstErr error - for _, shard := range c.shards { + + for _, shard := range c.shards.list { if err := shard.Client.Close(); err != nil && firstErr == nil { firstErr = err } @@ -437,20 +439,12 @@ func (c *ringShards) Close() error { c.hash = nil c.shards = nil c.numShard = 0 - c.list = nil return firstErr } //------------------------------------------------------------------------------ -type ring struct { - opt *RingOptions - shards *ringShards - cmdsInfoCache *cmdsInfoCache //nolint:structcheck - heartbeatCancelFn context.CancelFunc -} - // Ring is a Redis client that uses consistent hashing to distribute // keys across multiple Redis servers (shards). It's safe for // concurrent use by multiple goroutines. @@ -466,7 +460,11 @@ type ring struct { // and can tolerate losing data when one of the servers dies. // Otherwise you should use Redis Cluster. type Ring struct { - *ring + opt *RingOptions + sharding *ringSharding + cmdsInfoCache *cmdsInfoCache + heartbeatCancelFn context.CancelFunc + cmdable hooks } @@ -477,23 +475,21 @@ func NewRing(opt *RingOptions) *Ring { hbCtx, hbCancel := context.WithCancel(context.Background()) ring := Ring{ - ring: &ring{ - opt: opt, - shards: newRingShards(opt), - heartbeatCancelFn: hbCancel, - }, + opt: opt, + sharding: newRingSharding(opt), + heartbeatCancelFn: hbCancel, } ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdable = ring.Process - go ring.shards.Heartbeat(hbCtx, opt.HeartbeatFrequency) + go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency) return &ring } -func (c *Ring) SetAddrs(ctx context.Context, addrs map[string]string) { - c.shards.SetAddrs(addrs) +func (c *Ring) SetAddrs(addrs map[string]string) { + c.sharding.SetAddrs(addrs) } // Do creates a Cmd from the args and processes the cmd. @@ -518,7 +514,7 @@ func (c *Ring) retryBackoff(attempt int) time.Duration { // PoolStats returns accumulated connection pool stats. func (c *Ring) PoolStats() *PoolStats { - shards := c.shards.List() + shards := c.sharding.List() var acc PoolStats for _, shard := range shards { s := shard.Client.connPool.Stats() @@ -533,7 +529,7 @@ func (c *Ring) PoolStats() *PoolStats { // Len returns the current number of shards in the ring. func (c *Ring) Len() int { - return c.shards.Len() + return c.sharding.Len() } // Subscribe subscribes the client to the specified channels. @@ -542,7 +538,7 @@ func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub { panic("at least one channel is required") } - shard, err := c.shards.GetByKey(channels[0]) + shard, err := c.sharding.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) @@ -556,7 +552,7 @@ func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub { panic("at least one channel is required") } - shard, err := c.shards.GetByKey(channels[0]) + shard, err := c.sharding.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) @@ -570,7 +566,7 @@ func (c *Ring) ForEachShard( ctx context.Context, fn func(ctx context.Context, client *Client) error, ) error { - shards := c.shards.List() + shards := c.sharding.List() var wg sync.WaitGroup errCh := make(chan error, 1) for _, shard := range shards { @@ -601,7 +597,7 @@ func (c *Ring) ForEachShard( } func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { - shards := c.shards.List() + shards := c.sharding.List() var firstErr error for _, shard := range shards { cmdsInfo, err := shard.Client.Command(ctx).Result() @@ -634,10 +630,10 @@ func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) { cmdInfo := c.cmdInfo(ctx, cmd.Name()) pos := cmdFirstKeyPos(cmd, cmdInfo) if pos == 0 { - return c.shards.Random() + return c.sharding.Random() } firstKey := cmd.stringArg(pos) - return c.shards.GetByKey(firstKey) + return c.sharding.GetByKey(firstKey) } func (c *Ring) process(ctx context.Context, cmd Cmder) error { @@ -706,7 +702,7 @@ func (c *Ring) generalProcessPipeline( cmdInfo := c.cmdInfo(ctx, cmd.Name()) hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo)) if hash != "" { - hash = c.shards.Hash(hash) + hash = c.sharding.Hash(hash) } cmdsMap[hash] = append(cmdsMap[hash], cmd) } @@ -729,7 +725,7 @@ func (c *Ring) processShardPipeline( ctx context.Context, hash string, cmds []Cmder, tx bool, ) error { // TODO: retry? - shard, err := c.shards.GetByName(hash) + shard, err := c.sharding.GetByName(hash) if err != nil { setCmdsErr(cmds, err) return err @@ -749,7 +745,7 @@ func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) er var shards []*ringShard for _, key := range keys { if key != "" { - shard, err := c.shards.GetByKey(hashtag.Key(key)) + shard, err := c.sharding.GetByKey(hashtag.Key(key)) if err != nil { return err } @@ -781,5 +777,5 @@ func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) er func (c *Ring) Close() error { c.heartbeatCancelFn() - return c.shards.Close() + return c.sharding.Close() } diff --git a/ring_test.go b/ring_test.go index d804062..ed4d3ba 100644 --- a/ring_test.go +++ b/ring_test.go @@ -117,26 +117,25 @@ var _ = Describe("Redis Ring", func() { It("downscale shard and check reuse shard, upscale shard and check reuse", func() { Expect(ring.Len(), 2) - wantShard := ring.GetAddr("ringShardOne") + wantShard := ring.ShardByName("ringShardOne") ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ring.SetAddrs(ctx, map[string]string{ + ring.SetAddrs(map[string]string{ "ringShardOne": ":" + ringShard1Port, }) Expect(ring.Len(), 1) - gotShard := ring.GetAddr("ringShardOne") + gotShard := ring.ShardByName("ringShardOne") Expect(gotShard).To(Equal(wantShard)) ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ring.SetAddrs(ctx, map[string]string{ + ring.SetAddrs(map[string]string{ "ringShardOne": ":" + ringShard1Port, "ringShardTwo": ":" + ringShard2Port, }) Expect(ring.Len(), 2) - gotShard = ring.GetAddr("ringShardOne") + gotShard = ring.ShardByName("ringShardOne") Expect(gotShard).To(Equal(wantShard)) - }) It("uses 3 shards after setting it to 3 shards", func() { @@ -149,42 +148,41 @@ var _ = Describe("Redis Ring", func() { shardName1 := "ringShardOne" shardAddr1 := ":" + ringShard1Port - wantShard1 := ring.GetAddr(shardName1) + wantShard1 := ring.ShardByName(shardName1) shardName2 := "ringShardTwo" shardAddr2 := ":" + ringShard2Port - wantShard2 := ring.GetAddr(shardName2) + wantShard2 := ring.ShardByName(shardName2) shardName3 := "ringShardThree" shardAddr3 := ":" + ringShard3Port ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ring.SetAddrs(ctx, map[string]string{ + ring.SetAddrs(map[string]string{ shardName1: shardAddr1, shardName2: shardAddr2, shardName3: shardAddr3, }) Expect(ring.Len(), 3) - gotShard1 := ring.GetAddr(shardName1) - gotShard2 := ring.GetAddr(shardName2) - gotShard3 := ring.GetAddr(shardName3) + gotShard1 := ring.ShardByName(shardName1) + gotShard2 := ring.ShardByName(shardName2) + gotShard3 := ring.ShardByName(shardName3) Expect(gotShard1).To(Equal(wantShard1)) Expect(gotShard2).To(Equal(wantShard2)) Expect(gotShard3).ToNot(BeNil()) ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ring.SetAddrs(ctx, map[string]string{ + ring.SetAddrs(map[string]string{ shardName1: shardAddr1, shardName2: shardAddr2, }) Expect(ring.Len(), 2) - gotShard1 = ring.GetAddr(shardName1) - gotShard2 = ring.GetAddr(shardName2) - gotShard3 = ring.GetAddr(shardName3) + gotShard1 = ring.ShardByName(shardName1) + gotShard2 = ring.ShardByName(shardName2) + gotShard3 = ring.ShardByName(shardName3) Expect(gotShard1).To(Equal(wantShard1)) Expect(gotShard2).To(Equal(wantShard2)) Expect(gotShard3).To(BeNil()) }) - }) Describe("pipeline", func() { It("doesn't panic closed ring, returns error", func() { @@ -263,7 +261,7 @@ var _ = Describe("Redis Ring", func() { Describe("new client callback", func() { It("can be initialized with a new client callback", func() { opts := redisRingOptions() - opts.NewClient = func(name string, opt *redis.Options) *redis.Client { + opts.NewClient = func(opt *redis.Options) *redis.Client { opt.Username = "username1" opt.Password = "password1" return redis.NewClient(opt)