package redis import ( "context" "crypto/tls" "errors" "fmt" "net" "strconv" "sync" "sync/atomic" "time" "github.com/cespare/xxhash/v2" rendezvous "github.com/dgryski/go-rendezvous" //nolint "github.com/go-redis/redis/v9/internal" "github.com/go-redis/redis/v9/internal/hashtag" "github.com/go-redis/redis/v9/internal/pool" "github.com/go-redis/redis/v9/internal/rand" ) var errRingShardsDown = errors.New("redis: all ring shards are down") //------------------------------------------------------------------------------ type ConsistentHash interface { Get(string) string } type rendezvousWrapper struct { *rendezvous.Rendezvous } func (w rendezvousWrapper) Get(key string) string { return w.Lookup(key) } func newRendezvous(shards []string) ConsistentHash { return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)} } //------------------------------------------------------------------------------ // RingOptions are used to configure a ring client and should be // passed to NewRing. type RingOptions struct { // Map of name => host:port addresses of ring shards. Addrs map[string]string // NewClient creates a shard client with provided options. NewClient func(opt *Options) *Client // Frequency of PING commands sent to check shards availability. // Shard is considered down after 3 subsequent failed checks. HeartbeatFrequency time.Duration // NewConsistentHash returns a consistent hash that is used // to distribute keys across the shards. // // See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8 // for consistent hashing algorithmic tradeoffs. NewConsistentHash func(shards []string) ConsistentHash // Following options are copied from Options struct. Dialer func(ctx context.Context, network, addr string) (net.Conn, error) OnConnect func(ctx context.Context, cn *Conn) error Username string Password string DB int MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration DialTimeout time.Duration ReadTimeout time.Duration WriteTimeout time.Duration // PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO). PoolFIFO bool PoolSize int PoolTimeout time.Duration MinIdleConns int MaxIdleConns int ConnMaxIdleTime time.Duration ConnMaxLifetime time.Duration TLSConfig *tls.Config Limiter Limiter } func (opt *RingOptions) init() { if opt.NewClient == nil { opt.NewClient = func(opt *Options) *Client { return NewClient(opt) } } if opt.HeartbeatFrequency == 0 { opt.HeartbeatFrequency = 500 * time.Millisecond } if opt.NewConsistentHash == nil { opt.NewConsistentHash = newRendezvous } if opt.MaxRetries == -1 { opt.MaxRetries = 0 } else if opt.MaxRetries == 0 { opt.MaxRetries = 3 } switch opt.MinRetryBackoff { case -1: opt.MinRetryBackoff = 0 case 0: opt.MinRetryBackoff = 8 * time.Millisecond } switch opt.MaxRetryBackoff { case -1: opt.MaxRetryBackoff = 0 case 0: opt.MaxRetryBackoff = 512 * time.Millisecond } } func (opt *RingOptions) clientOptions() *Options { return &Options{ Dialer: opt.Dialer, OnConnect: opt.OnConnect, Username: opt.Username, Password: opt.Password, DB: opt.DB, MaxRetries: -1, DialTimeout: opt.DialTimeout, ReadTimeout: opt.ReadTimeout, WriteTimeout: opt.WriteTimeout, PoolFIFO: opt.PoolFIFO, PoolSize: opt.PoolSize, PoolTimeout: opt.PoolTimeout, MinIdleConns: opt.MinIdleConns, MaxIdleConns: opt.MaxIdleConns, ConnMaxIdleTime: opt.ConnMaxIdleTime, ConnMaxLifetime: opt.ConnMaxLifetime, TLSConfig: opt.TLSConfig, Limiter: opt.Limiter, } } //------------------------------------------------------------------------------ type ringShard struct { Client *Client down int32 addr string } func newRingShard(opt *RingOptions, addr string) *ringShard { clopt := opt.clientOptions() clopt.Addr = addr return &ringShard{ Client: opt.NewClient(clopt), addr: addr, } } func (shard *ringShard) String() string { var state string if shard.IsUp() { state = "up" } else { state = "down" } return fmt.Sprintf("%s is %s", shard.Client, state) } func (shard *ringShard) IsDown() bool { const threshold = 3 return atomic.LoadInt32(&shard.down) >= threshold } func (shard *ringShard) IsUp() bool { return !shard.IsDown() } // Vote votes to set shard state and returns true if state was changed. func (shard *ringShard) Vote(up bool) bool { if up { changed := shard.IsDown() atomic.StoreInt32(&shard.down, 0) return changed } if shard.IsDown() { return false } atomic.AddInt32(&shard.down, 1) return shard.IsDown() } //------------------------------------------------------------------------------ type ringSharding struct { opt *RingOptions mu sync.RWMutex shards *ringShards closed bool hash ConsistentHash numShard int onNewNode []func(rdb *Client) } type ringShards struct { m map[string]*ringShard list []*ringShard } func newRingSharding(opt *RingOptions) *ringSharding { c := &ringSharding{ opt: opt, } c.SetAddrs(opt.Addrs) return c } func (c *ringSharding) OnNewNode(fn func(rdb *Client)) { c.mu.Lock() c.onNewNode = append(c.onNewNode, fn) c.mu.Unlock() } // SetAddrs replaces the shards in use, such that you can increase and // decrease number of shards, that you use. It will reuse shards that // existed before and close the ones that will not be used anymore. func (c *ringSharding) SetAddrs(addrs map[string]string) { c.mu.Lock() if c.closed { c.mu.Unlock() return } shards, cleanup := c.newRingShards(addrs, c.shards) c.shards = shards c.mu.Unlock() c.rebalance() cleanup() } func (c *ringSharding) newRingShards( addrs map[string]string, existingShards *ringShards, ) (*ringShards, func()) { shardMap := make(map[string]*ringShard) // indexed by addr unusedShards := make(map[string]*ringShard) // indexed by addr if existingShards != nil { for _, shard := range existingShards.list { addr := shard.Client.opt.Addr shardMap[addr] = shard unusedShards[addr] = shard } } shards := &ringShards{ m: make(map[string]*ringShard), } for name, addr := range addrs { if shard, ok := shardMap[addr]; ok { shards.m[name] = shard delete(unusedShards, addr) } else { shard := newRingShard(c.opt, addr) shards.m[name] = shard for _, fn := range c.onNewNode { fn(shard.Client) } } } for _, shard := range shards.m { shards.list = append(shards.list, shard) } return shards, func() { for addr, shard := range unusedShards { if err := shard.Client.Close(); err != nil { internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err) } } } } func (c *ringSharding) List() []*ringShard { var list []*ringShard c.mu.RLock() if !c.closed { list = c.shards.list } c.mu.RUnlock() return list } func (c *ringSharding) Hash(key string) string { key = hashtag.Key(key) var hash string c.mu.RLock() defer c.mu.RUnlock() if c.numShard > 0 { hash = c.hash.Get(key) } return hash } func (c *ringSharding) GetByKey(key string) (*ringShard, error) { key = hashtag.Key(key) c.mu.RLock() defer c.mu.RUnlock() if c.closed { return nil, pool.ErrClosed } if c.numShard == 0 { return nil, errRingShardsDown } shardName := c.hash.Get(key) if shardName == "" { return nil, errRingShardsDown } return c.shards.m[shardName], nil } func (c *ringSharding) GetByName(shardName string) (*ringShard, error) { if shardName == "" { return c.Random() } c.mu.RLock() defer c.mu.RUnlock() return c.shards.m[shardName], nil } func (c *ringSharding) Random() (*ringShard, error) { return c.GetByKey(strconv.Itoa(rand.Int())) } // Heartbeat monitors state of each shard in the ring. func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { ticker := time.NewTicker(frequency) defer ticker.Stop() for { select { case <-ticker.C: var rebalance bool for _, shard := range c.List() { err := shard.Client.Ping(ctx).Err() isUp := err == nil || err == pool.ErrPoolTimeout if shard.Vote(isUp) { internal.Logger.Printf(ctx, "ring shard state changed: %s", shard) rebalance = true } } if rebalance { c.rebalance() } case <-ctx.Done(): return } } } // rebalance removes dead shards from the Ring. func (c *ringSharding) rebalance() { c.mu.RLock() shards := c.shards c.mu.RUnlock() if shards == nil { return } liveShards := make([]string, 0, len(shards.m)) for name, shard := range shards.m { if shard.IsUp() { liveShards = append(liveShards, name) } } hash := c.opt.NewConsistentHash(liveShards) c.mu.Lock() if !c.closed { c.hash = hash c.numShard = len(liveShards) } c.mu.Unlock() } func (c *ringSharding) Len() int { c.mu.RLock() defer c.mu.RUnlock() return c.numShard } func (c *ringSharding) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.closed { return nil } c.closed = true var firstErr error for _, shard := range c.shards.list { if err := shard.Client.Close(); err != nil && firstErr == nil { firstErr = err } } c.hash = nil c.shards = nil c.numShard = 0 return firstErr } //------------------------------------------------------------------------------ // Ring is a Redis client that uses consistent hashing to distribute // keys across multiple Redis servers (shards). It's safe for // concurrent use by multiple goroutines. // // Ring monitors the state of each shard and removes dead shards from // the ring. When a shard comes online it is added back to the ring. This // gives you maximum availability and partition tolerance, but no // consistency between different shards or even clients. Each client // uses shards that are available to the client and does not do any // coordination when shard state is changed. // // Ring should be used when you need multiple Redis servers for caching // and can tolerate losing data when one of the servers dies. // Otherwise you should use Redis Cluster. type Ring struct { cmdable hooks opt *RingOptions sharding *ringSharding cmdsInfoCache *cmdsInfoCache heartbeatCancelFn context.CancelFunc } func NewRing(opt *RingOptions) *Ring { opt.init() hbCtx, hbCancel := context.WithCancel(context.Background()) ring := Ring{ opt: opt, sharding: newRingSharding(opt), heartbeatCancelFn: hbCancel, } ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) ring.cmdable = ring.Process ring.hooks.setProcess(ring.process) ring.hooks.setProcessPipeline(func(ctx context.Context, cmds []Cmder) error { return ring.generalProcessPipeline(ctx, cmds, false) }) ring.hooks.setProcessTxPipeline(func(ctx context.Context, cmds []Cmder) error { return ring.generalProcessPipeline(ctx, cmds, true) }) go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency) return &ring } func (c *Ring) SetAddrs(addrs map[string]string) { c.sharding.SetAddrs(addrs) } // Do creates a Cmd from the args and processes the cmd. func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd { cmd := NewCmd(ctx, args...) _ = c.Process(ctx, cmd) return cmd } func (c *Ring) Process(ctx context.Context, cmd Cmder) error { err := c.hooks.process(ctx, cmd) cmd.SetErr(err) return err } // Options returns read-only Options that were used to create the client. func (c *Ring) Options() *RingOptions { return c.opt } func (c *Ring) retryBackoff(attempt int) time.Duration { return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) } // PoolStats returns accumulated connection pool stats. func (c *Ring) PoolStats() *PoolStats { shards := c.sharding.List() var acc PoolStats for _, shard := range shards { s := shard.Client.connPool.Stats() acc.Hits += s.Hits acc.Misses += s.Misses acc.Timeouts += s.Timeouts acc.TotalConns += s.TotalConns acc.IdleConns += s.IdleConns } return &acc } // Len returns the current number of shards in the ring. func (c *Ring) Len() int { return c.sharding.Len() } // Subscribe subscribes the client to the specified channels. func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub { if len(channels) == 0 { panic("at least one channel is required") } shard, err := c.sharding.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) } return shard.Client.Subscribe(ctx, channels...) } // PSubscribe subscribes the client to the given patterns. func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub { if len(channels) == 0 { panic("at least one channel is required") } shard, err := c.sharding.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) } return shard.Client.PSubscribe(ctx, channels...) } // SSubscribe Subscribes the client to the specified shard channels. func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub { if len(channels) == 0 { panic("at least one channel is required") } shard, err := c.sharding.GetByKey(channels[0]) if err != nil { // TODO: return PubSub with sticky error panic(err) } return shard.Client.SSubscribe(ctx, channels...) } func (c *Ring) OnNewNode(fn func(rdb *Client)) { c.sharding.OnNewNode(fn) } // ForEachShard concurrently calls the fn on each live shard in the ring. // It returns the first error if any. func (c *Ring) ForEachShard( ctx context.Context, fn func(ctx context.Context, client *Client) error, ) error { shards := c.sharding.List() var wg sync.WaitGroup errCh := make(chan error, 1) for _, shard := range shards { if shard.IsDown() { continue } wg.Add(1) go func(shard *ringShard) { defer wg.Done() err := fn(ctx, shard.Client) if err != nil { select { case errCh <- err: default: } } }(shard) } wg.Wait() select { case err := <-errCh: return err default: return nil } } func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { shards := c.sharding.List() var firstErr error for _, shard := range shards { cmdsInfo, err := shard.Client.Command(ctx).Result() if err == nil { return cmdsInfo, nil } if firstErr == nil { firstErr = err } } if firstErr == nil { return nil, errRingShardsDown } return nil, firstErr } func (c *Ring) cmdInfo(ctx context.Context, name string) *CommandInfo { cmdsInfo, err := c.cmdsInfoCache.Get(ctx) if err != nil { return nil } info := cmdsInfo[name] if info == nil { internal.Logger.Printf(ctx, "info for cmd=%s not found", name) } return info } func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) { cmdInfo := c.cmdInfo(ctx, cmd.Name()) pos := cmdFirstKeyPos(cmd, cmdInfo) if pos == 0 { return c.sharding.Random() } firstKey := cmd.stringArg(pos) return c.sharding.GetByKey(firstKey) } func (c *Ring) process(ctx context.Context, cmd Cmder) error { var lastErr error for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { if attempt > 0 { if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { return err } } shard, err := c.cmdShard(ctx, cmd) if err != nil { return err } lastErr = shard.Client.Process(ctx, cmd) if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) { return lastErr } } return lastErr } func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(ctx, fn) } func (c *Ring) Pipeline() Pipeliner { pipe := Pipeline{ exec: pipelineExecer(c.hooks.processPipeline), } pipe.init() return &pipe } func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { return c.TxPipeline().Pipelined(ctx, fn) } func (c *Ring) TxPipeline() Pipeliner { pipe := Pipeline{ exec: func(ctx context.Context, cmds []Cmder) error { cmds = wrapMultiExec(ctx, cmds) return c.hooks.processTxPipeline(ctx, cmds) }, } pipe.init() return &pipe } func (c *Ring) generalProcessPipeline( ctx context.Context, cmds []Cmder, tx bool, ) error { if tx { // Trim multi .. exec. cmds = cmds[1 : len(cmds)-1] } cmdsMap := make(map[string][]Cmder) for _, cmd := range cmds { cmdInfo := c.cmdInfo(ctx, cmd.Name()) hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo)) if hash != "" { hash = c.sharding.Hash(hash) } cmdsMap[hash] = append(cmdsMap[hash], cmd) } var wg sync.WaitGroup for hash, cmds := range cmdsMap { wg.Add(1) go func(hash string, cmds []Cmder) { defer wg.Done() // TODO: retry? shard, err := c.sharding.GetByName(hash) if err != nil { setCmdsErr(cmds, err) return } if tx { cmds = wrapMultiExec(ctx, cmds) shard.Client.hooks.processTxPipeline(ctx, cmds) } else { shard.Client.hooks.processPipeline(ctx, cmds) } }(hash, cmds) } wg.Wait() return cmdsFirstErr(cmds) } func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { if len(keys) == 0 { return fmt.Errorf("redis: Watch requires at least one key") } var shards []*ringShard for _, key := range keys { if key != "" { shard, err := c.sharding.GetByKey(hashtag.Key(key)) if err != nil { return err } shards = append(shards, shard) } } if len(shards) == 0 { return fmt.Errorf("redis: Watch requires at least one shard") } if len(shards) > 1 { for _, shard := range shards[1:] { if shard.Client != shards[0].Client { err := fmt.Errorf("redis: Watch requires all keys to be in the same shard") return err } } } return shards[0].Client.Watch(ctx, fn, keys...) } // Close closes the ring client, releasing any open resources. // // It is rare to Close a Ring, as the Ring is meant to be long-lived // and shared between many goroutines. func (c *Ring) Close() error { c.heartbeatCancelFn() return c.sharding.Close() }