diff --git a/ring.go b/ring.go index 5ea4ae5..444acae 100644 --- a/ring.go +++ b/ring.go @@ -310,11 +310,10 @@ func (c *ringShards) Random() (*ringShard, error) { } // Heartbeat monitors state of each shard in the ring. -func (c *ringShards) Heartbeat(frequency time.Duration, closeCh chan struct{}) { +func (c *ringShards) Heartbeat(ctx context.Context, frequency time.Duration) { ticker := time.NewTicker(frequency) defer ticker.Stop() - ctx := context.Background() for { select { case <-ticker.C: @@ -332,7 +331,7 @@ func (c *ringShards) Heartbeat(frequency time.Duration, closeCh chan struct{}) { if rebalance { c.rebalance() } - case <-closeCh: + case <-ctx.Done(): return } } @@ -392,10 +391,10 @@ func (c *ringShards) Close() error { //------------------------------------------------------------------------------ type ring struct { - opt *RingOptions - shards *ringShards - cmdsInfoCache *cmdsInfoCache //nolint:structcheck - hearbeatCloseSignal chan struct{} + opt *RingOptions + shards *ringShards + cmdsInfoCache *cmdsInfoCache //nolint:structcheck + heartbeatCancelFn context.CancelFunc } // Ring is a Redis client that uses consistent hashing to distribute @@ -421,20 +420,20 @@ type Ring struct { func NewRing(opt *RingOptions) *Ring { opt.init() - hearbeatCloseSignal := make(chan struct{}) + hbCtx, hbCancel := context.WithCancel(context.Background()) ring := Ring{ ring: &ring{ - opt: opt, - shards: newRingShards(opt), - hearbeatCloseSignal: hearbeatCloseSignal, + opt: opt, + shards: newRingShards(opt), + heartbeatCancelFn: hbCancel, }, } ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdable = ring.Process - go ring.shards.Heartbeat(opt.HeartbeatFrequency, hearbeatCloseSignal) + go ring.shards.Heartbeat(hbCtx, opt.HeartbeatFrequency) return &ring } @@ -722,6 +721,7 @@ func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) er // It is rare to Close a Ring, as the Ring is meant to be long-lived // and shared between many goroutines. func (c *Ring) Close() error { - close(c.hearbeatCloseSignal) + c.heartbeatCancelFn() + return c.shards.Close() }