package redis_test

import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"net"
	"strconv"
	"sync"
	"testing"
	"time"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"

	"github.com/go-redis/redis/v9"
)

var _ = Describe("Redis Ring", func() {
	const heartbeat = 100 * time.Millisecond

	var ring *redis.Ring

	// setRingKeys writes "key0".."key99" through the ring so the ring's
	// hashing decides which shard stores each key.
	setRingKeys := func() {
		for i := 0; i < 100; i++ {
			err := ring.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
			Expect(err).NotTo(HaveOccurred())
		}
	}

	BeforeEach(func() {
		opt := redisRingOptions()
		opt.HeartbeatFrequency = heartbeat
		ring = redis.NewRing(opt)

		// Start every spec from empty shards so the key counts below are exact.
		err := ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
			return cl.FlushDB(ctx).Err()
		})
		Expect(err).NotTo(HaveOccurred())
	})

	AfterEach(func() {
		Expect(ring.Close()).NotTo(HaveOccurred())
	})

	It("supports context", func() {
		ctx, cancel := context.WithCancel(ctx)
		cancel()

		err := ring.Ping(ctx).Err()
		Expect(err).To(MatchError("context canceled"))
	})

	It("distributes keys", func() {
		setRingKeys()

		// Both shards should have some keys now.
		Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=56"))
		Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=44"))
	})

	It("distributes keys when using EVAL", func() {
		script := redis.NewScript(`
		local r = redis.call('SET', KEYS[1], ARGV[1])
		return r
	`)

		for i := 0; i < 100; i++ {
			key := fmt.Sprintf("key%d", i)
			err := script.Run(ctx, ring, []string{key}, "value").Err()
			Expect(err).NotTo(HaveOccurred())
		}

		Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=56"))
		Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=44"))
	})

	It("uses single shard when one of the shards is down", func() {
		// Stop ringShard2.
		Expect(ringShard2.Close()).NotTo(HaveOccurred())

		// Shard liveness is detected asynchronously, so poll until the ring
		// drops the stopped shard.
		Eventually(func() int {
			return ring.Len()
		}, "30s").Should(Equal(1))

		setRingKeys()

		// RingShard1 should have all keys.
		Expect(ringShard1.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=100"))

		// Start ringShard2.
		var err error
		ringShard2, err = startRedis(ringShard2Port)
		Expect(err).NotTo(HaveOccurred())

		Eventually(func() int {
			return ring.Len()
		}, "30s").Should(Equal(2))

		setRingKeys()

		// RingShard2 should have its keys.
		Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=44"))
	})

	It("supports hash tags", func() {
		for i := 0; i < 100; i++ {
			err := ring.Set(ctx, fmt.Sprintf("key%d{tag}", i), "value", 0).Err()
			Expect(err).NotTo(HaveOccurred())
		}

		// Every key shares the {tag} hash tag, so all of them land on one shard.
		Expect(ringShard1.Info(ctx, "keyspace").Val()).ToNot(ContainSubstring("keys="))
		Expect(ringShard2.Info(ctx, "keyspace").Val()).To(ContainSubstring("keys=100"))
	})

	Describe("[new] dynamic setting ring shards", func() {
		It("downscale shard and check reuse shard, upscale shard and check reuse", func() {
			// Expect(actual) must be followed by a matcher; Expect(x, n) on its
			// own asserts nothing.
			Expect(ring.Len()).To(Equal(2))

			wantShard := ring.ShardByName("ringShardOne")
			ring.SetAddrs(map[string]string{
				"ringShardOne": ":" + ringShard1Port,
			})
			Expect(ring.Len()).To(Equal(1))
			gotShard := ring.ShardByName("ringShardOne")
			Expect(gotShard).To(BeIdenticalTo(wantShard))

			ring.SetAddrs(map[string]string{
				"ringShardOne": ":" + ringShard1Port,
				"ringShardTwo": ":" + ringShard2Port,
			})
			Expect(ring.Len()).To(Equal(2))
			gotShard = ring.ShardByName("ringShardOne")
			Expect(gotShard).To(BeIdenticalTo(wantShard))
		})

		It("uses 3 shards after setting it to 3 shards", func() {
			Expect(ring.Len()).To(Equal(2))

			shardName1 := "ringShardOne"
			shardAddr1 := ":" + ringShard1Port
			wantShard1 := ring.ShardByName(shardName1)
			shardName2 := "ringShardTwo"
			shardAddr2 := ":" + ringShard2Port
			wantShard2 := ring.ShardByName(shardName2)
			shardName3 := "ringShardThree"
			shardAddr3 := ":" + ringShard3Port

			ring.SetAddrs(map[string]string{
				shardName1: shardAddr1,
				shardName2: shardAddr2,
				shardName3: shardAddr3,
			})
			Expect(ring.Len()).To(Equal(3))
			gotShard1 := ring.ShardByName(shardName1)
			gotShard2 := ring.ShardByName(shardName2)
			gotShard3 := ring.ShardByName(shardName3)
			Expect(gotShard1).To(BeIdenticalTo(wantShard1))
			Expect(gotShard2).To(BeIdenticalTo(wantShard2))
			Expect(gotShard3).ToNot(BeNil())

			ring.SetAddrs(map[string]string{
				shardName1: shardAddr1,
				shardName2: shardAddr2,
			})
			Expect(ring.Len()).To(Equal(2))
			gotShard1 = ring.ShardByName(shardName1)
			gotShard2 = ring.ShardByName(shardName2)
			gotShard3 = ring.ShardByName(shardName3)
			Expect(gotShard1).To(BeIdenticalTo(wantShard1))
			Expect(gotShard2).To(BeIdenticalTo(wantShard2))
			Expect(gotShard3).To(BeNil())
		})
	})

	Describe("pipeline", func() {
		It("doesn't panic closed ring, returns error", func() {
			pipe := ring.Pipeline()
			for i := 0; i < 3; i++ {
				err := pipe.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
				Expect(err).NotTo(HaveOccurred())
			}

			Expect(ring.Close()).NotTo(HaveOccurred())

			Expect(func() {
				_, execErr := pipe.Exec(ctx)
				Expect(execErr).To(HaveOccurred())
			}).NotTo(Panic())
		})

		It("distributes keys", func() {
			pipe := ring.Pipeline()
			for i := 0; i < 100; i++ {
				err := pipe.Set(ctx, fmt.Sprintf("key%d", i), "value", 0).Err()
				Expect(err).NotTo(HaveOccurred())
			}

			cmds, err := pipe.Exec(ctx)
			Expect(err).NotTo(HaveOccurred())
			Expect(cmds).To(HaveLen(100))

			for _, cmd := range cmds {
				Expect(cmd.Err()).NotTo(HaveOccurred())
				Expect(cmd.(*redis.StatusCmd).Val()).To(Equal("OK"))
			}

			// Both shards should have some keys now.
			Expect(ringShard1.Info(ctx).Val()).To(ContainSubstring("keys=56"))
			Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=44"))
		})

		It("is consistent with ring", func() {
			// Random binary keys exercise the pipeline's key routing with
			// arbitrary hash inputs.
			var keys []string
			for i := 0; i < 100; i++ {
				key := make([]byte, 64)
				_, err := rand.Read(key)
				Expect(err).NotTo(HaveOccurred())
				keys = append(keys, string(key))
			}

			_, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
				for _, key := range keys {
					pipe.Set(ctx, key, "value", 0)
				}
				return nil
			})
			Expect(err).NotTo(HaveOccurred())

			// Every key written through the pipeline must be readable through
			// the ring, i.e. both paths pick the same shard.
			for _, key := range keys {
				val, err := ring.Get(ctx, key).Result()
				Expect(err).NotTo(HaveOccurred())
				Expect(val).To(Equal("value"))
			}
		})

		It("supports hash tags", func() {
			_, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
				for i := 0; i < 100; i++ {
					pipe.Set(ctx, fmt.Sprintf("key%d{tag}", i), "value", 0)
				}
				return nil
			})
			Expect(err).NotTo(HaveOccurred())

			Expect(ringShard1.Info(ctx).Val()).ToNot(ContainSubstring("keys="))
			Expect(ringShard2.Info(ctx).Val()).To(ContainSubstring("keys=100"))
		})
	})

	Describe("new client callback", func() {
		It("can be initialized with a new client callback", func() {
			opts := redisRingOptions()
			opts.NewClient = func(opt *redis.Options) *redis.Client {
				opt.Username = "username1"
				opt.Password = "password1"
				return redis.NewClient(opt)
			}

			// AfterEach only closes the ring currently held in the variable,
			// so close the BeforeEach ring before replacing it.
			Expect(ring.Close()).NotTo(HaveOccurred())
			ring = redis.NewRing(opts)

			err := ring.Ping(ctx).Err()
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("WRONGPASS"))
		})
	})

	Describe("Process hook", func() {
		BeforeEach(func() {
			// The health check leads to a data race on the "stack []string"
			// variable. Setting the heartbeat to 72 hours keeps the health
			// check from running during these specs.
			opt := redisRingOptions()
			opt.HeartbeatFrequency = 72 * time.Hour

			// Close the ring from the outer BeforeEach so it is not leaked.
			Expect(ring.Close()).NotTo(HaveOccurred())
			ring = redis.NewRing(opt)
		})

		It("supports Process hook", func() {
			err := ring.Ping(ctx).Err()
			Expect(err).NotTo(HaveOccurred())

			var stack []string

			ring.AddHook(&hook{
				processHook: func(hook redis.ProcessHook) redis.ProcessHook {
					return func(ctx context.Context, cmd redis.Cmder) error {
						Expect(cmd.String()).To(Equal("ping: "))
						stack = append(stack, "ring.BeforeProcess")

						err := hook(ctx, cmd)

						Expect(cmd.String()).To(Equal("ping: PONG"))
						stack = append(stack, "ring.AfterProcess")

						return err
					}
				},
			})

			err = ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
				shard.AddHook(&hook{
					processHook: func(hook redis.ProcessHook) redis.ProcessHook {
						return func(ctx context.Context, cmd redis.Cmder) error {
							Expect(cmd.String()).To(Equal("ping: "))
							stack = append(stack, "shard.BeforeProcess")

							err := hook(ctx, cmd)

							Expect(cmd.String()).To(Equal("ping: PONG"))
							stack = append(stack, "shard.AfterProcess")

							return err
						}
					},
				})
				return nil
			})
			Expect(err).NotTo(HaveOccurred())

			err = ring.Ping(ctx).Err()
			Expect(err).NotTo(HaveOccurred())

			// Ring hooks wrap shard hooks.
			Expect(stack).To(Equal([]string{
				"ring.BeforeProcess",
				"shard.BeforeProcess",
				"shard.AfterProcess",
				"ring.AfterProcess",
			}))
		})

		It("supports Pipeline hook", func() {
			err := ring.Ping(ctx).Err()
			Expect(err).NotTo(HaveOccurred())

			var stack []string

			ring.AddHook(&hook{
				processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
					return func(ctx context.Context, cmds []redis.Cmder) error {
						defer GinkgoRecover()

						Expect(cmds).To(HaveLen(1))
						Expect(cmds[0].String()).To(Equal("ping: "))
						stack = append(stack, "ring.BeforeProcessPipeline")

						err := hook(ctx, cmds)

						Expect(cmds).To(HaveLen(1))
						Expect(cmds[0].String()).To(Equal("ping: PONG"))
						stack = append(stack, "ring.AfterProcessPipeline")

						return err
					}
				},
			})

			err = ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
				shard.AddHook(&hook{
					processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
						return func(ctx context.Context, cmds []redis.Cmder) error {
							// Presumably the shard pipeline may run on a separate
							// goroutine (as the TxPipeline spec assumes), so recover
							// Gomega failures here as well.
							defer GinkgoRecover()

							Expect(cmds).To(HaveLen(1))
							Expect(cmds[0].String()).To(Equal("ping: "))
							stack = append(stack, "shard.BeforeProcessPipeline")

							err := hook(ctx, cmds)

							Expect(cmds).To(HaveLen(1))
							Expect(cmds[0].String()).To(Equal("ping: PONG"))
							stack = append(stack, "shard.AfterProcessPipeline")

							return err
						}
					},
				})
				return nil
			})
			Expect(err).NotTo(HaveOccurred())

			_, err = ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
				pipe.Ping(ctx)
				return nil
			})
			Expect(err).NotTo(HaveOccurred())

			Expect(stack).To(Equal([]string{
				"ring.BeforeProcessPipeline",
				"shard.BeforeProcessPipeline",
				"shard.AfterProcessPipeline",
				"ring.AfterProcessPipeline",
			}))
		})

		It("supports TxPipeline hook", func() {
			err := ring.Ping(ctx).Err()
			Expect(err).NotTo(HaveOccurred())

			var stack []string

			// A transaction pipeline of one PING is seen by hooks as three
			// commands; the PING is the middle one (index 1).
			ring.AddHook(&hook{
				processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
					return func(ctx context.Context, cmds []redis.Cmder) error {
						defer GinkgoRecover()

						Expect(cmds).To(HaveLen(3))
						Expect(cmds[1].String()).To(Equal("ping: "))
						stack = append(stack, "ring.BeforeProcessPipeline")

						err := hook(ctx, cmds)

						Expect(cmds).To(HaveLen(3))
						Expect(cmds[1].String()).To(Equal("ping: PONG"))
						stack = append(stack, "ring.AfterProcessPipeline")

						return err
					}
				},
			})

			err = ring.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
				shard.AddHook(&hook{
					processPipelineHook: func(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
						return func(ctx context.Context, cmds []redis.Cmder) error {
							defer GinkgoRecover()

							Expect(cmds).To(HaveLen(3))
							Expect(cmds[1].String()).To(Equal("ping: "))
							stack = append(stack, "shard.BeforeProcessPipeline")

							err := hook(ctx, cmds)

							Expect(cmds).To(HaveLen(3))
							Expect(cmds[1].String()).To(Equal("ping: PONG"))
							stack = append(stack, "shard.AfterProcessPipeline")

							return err
						}
					},
				})
				return nil
			})
			Expect(err).NotTo(HaveOccurred())

			_, err = ring.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
				pipe.Ping(ctx)
				return nil
			})
			Expect(err).NotTo(HaveOccurred())

			Expect(stack).To(Equal([]string{
				"ring.BeforeProcessPipeline",
				"shard.BeforeProcessPipeline",
				"shard.AfterProcessPipeline",
				"ring.AfterProcessPipeline",
			}))
		})
	})
})

var _ = Describe("empty Redis Ring", func() {
	var ring *redis.Ring

	BeforeEach(func() {
		ring = redis.NewRing(&redis.RingOptions{})
	})

	AfterEach(func() {
		Expect(ring.Close()).NotTo(HaveOccurred())
	})

	It("returns an error", func() {
		err := ring.Ping(ctx).Err()
		Expect(err).To(MatchError("redis: all ring shards are down"))
	})

	It("pipeline returns an error", func() {
		_, err := ring.Pipelined(ctx, func(pipe redis.Pipeliner) error {
			pipe.Ping(ctx)
			return nil
		})
		Expect(err).To(MatchError("redis: all ring shards are down"))
	})
})

var _ = Describe("Ring watch", func() {
	const heartbeat = 100 * time.Millisecond

	var ring *redis.Ring

	BeforeEach(func() {
		opt := redisRingOptions()
		opt.HeartbeatFrequency = heartbeat
		ring = redis.NewRing(opt)

		err := ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
			return cl.FlushDB(ctx).Err()
		})
		Expect(err).NotTo(HaveOccurred())
	})

	AfterEach(func() {
		Expect(ring.Close()).NotTo(HaveOccurred())
	})

	It("should Watch", func() {
		// incr transactionally increments key using GET and SET commands,
		// retrying whenever the optimistic transaction is aborted because
		// another client modified the watched key.
		incr := func(key string) error {
			for {
				err := ring.Watch(ctx, func(tx *redis.Tx) error {
					n, err := tx.Get(ctx, key).Int64()
					if err != nil && !errors.Is(err, redis.Nil) {
						return err
					}

					_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
						pipe.Set(ctx, key, strconv.FormatInt(n+1, 10), 0)
						return nil
					})
					return err
				}, key)
				if !errors.Is(err, redis.TxFailedErr) {
					return err
				}
			}
		}

		var wg sync.WaitGroup
		for i := 0; i < 100; i++ {
			wg.Add(1)
			go func() {
				defer GinkgoRecover()
				defer wg.Done()

				err := incr("key")
				Expect(err).NotTo(HaveOccurred())
			}()
		}
		wg.Wait()

		n, err := ring.Get(ctx, "key").Int64()
		Expect(err).NotTo(HaveOccurred())
		Expect(n).To(Equal(int64(100)))
	})

	It("should discard", func() {
		err := ring.Watch(ctx, func(tx *redis.Tx) error {
			cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
				pipe.Set(ctx, "{shard}key1", "hello1", 0)
				// Discard drops the queued SET of key1; only key2 is executed.
				pipe.Discard()
				pipe.Set(ctx, "{shard}key2", "hello2", 0)
				return nil
			})
			Expect(err).NotTo(HaveOccurred())
			Expect(cmds).To(HaveLen(1))
			return err
		}, "{shard}key1", "{shard}key2")
		Expect(err).NotTo(HaveOccurred())

		get := ring.Get(ctx, "{shard}key1")
		Expect(get.Err()).To(Equal(redis.Nil))
		Expect(get.Val()).To(Equal(""))

		get = ring.Get(ctx, "{shard}key2")
		Expect(get.Err()).NotTo(HaveOccurred())
		Expect(get.Val()).To(Equal("hello2"))
	})

	It("returns no error when there are no commands", func() {
		err := ring.Watch(ctx, func(tx *redis.Tx) error {
			_, err := tx.TxPipelined(ctx, func(redis.Pipeliner) error { return nil })
			return err
		}, "key")
		Expect(err).NotTo(HaveOccurred())

		v, err := ring.Ping(ctx).Result()
		Expect(err).NotTo(HaveOccurred())
		Expect(v).To(Equal("PONG"))
	})

	It("should exec bulks", func() {
		const N = 20000

		err := ring.Watch(ctx, func(tx *redis.Tx) error {
			cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
				for i := 0; i < N; i++ {
					pipe.Incr(ctx, "key")
				}
				return nil
			})
			Expect(err).NotTo(HaveOccurred())
			Expect(len(cmds)).To(Equal(N))
			for _, cmd := range cmds {
				Expect(cmd.Err()).NotTo(HaveOccurred())
			}
			return err
		}, "key")
		Expect(err).NotTo(HaveOccurred())

		num, err := ring.Get(ctx, "key").Int64()
		Expect(err).NotTo(HaveOccurred())
		Expect(num).To(Equal(int64(N)))
	})

	It("should Watch/Unwatch", func() {
		// C concurrent workers each perform N successful increments. Both
		// must be non-zero, otherwise perform runs nothing and the spec only
		// checks that 0 == 0.
		const C, N = 10, 50

		err := ring.Set(ctx, "key", "0", 0).Err()
		Expect(err).NotTo(HaveOccurred())

		perform(C, func(id int) {
			for i := 0; i < N; i++ {
				err := ring.Watch(ctx, func(tx *redis.Tx) error {
					val, err := tx.Get(ctx, "key").Result()
					Expect(err).NotTo(HaveOccurred())

					num, err := strconv.ParseInt(val, 10, 64)
					Expect(err).NotTo(HaveOccurred())

					cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
						pipe.Set(ctx, "key", strconv.FormatInt(num+1, 10), 0)
						return nil
					})
					Expect(cmds).To(HaveLen(1))
					return err
				}, "key")
				if errors.Is(err, redis.TxFailedErr) {
					// Another worker changed "key" first; redo this increment.
					i--
					continue
				}
				Expect(err).NotTo(HaveOccurred())
			}
		})

		val, err := ring.Get(ctx, "key").Int64()
		Expect(err).NotTo(HaveOccurred())
		Expect(val).To(Equal(int64(C * N)))
	})

	It("should close Tx without closing the client", func() {
		err := ring.Watch(ctx, func(tx *redis.Tx) error {
			_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
				pipe.Ping(ctx)
				return nil
			})
			return err
		}, "key")
		Expect(err).NotTo(HaveOccurred())

		Expect(ring.Ping(ctx).Err()).NotTo(HaveOccurred())
	})

	It("respects max size on multi", func() {
		// This spec checks the number of pool connections. A health check
		// running at the same time would also use a connection and skew the
		// count. redis.NewRing() has no option to disable health checks, so a
		// very long heartbeat is used to keep them from running.
		opt := redisRingOptions()
		opt.HeartbeatFrequency = 72 * time.Hour

		// Close the BeforeEach ring before replacing it so it is not leaked;
		// AfterEach closes the replacement.
		Expect(ring.Close()).NotTo(HaveOccurred())
		ring = redis.NewRing(opt)

		perform(1000, func(id int) {
			var ping *redis.StatusCmd

			err := ring.Watch(ctx, func(tx *redis.Tx) error {
				cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
					ping = pipe.Ping(ctx)
					return nil
				})
				Expect(err).NotTo(HaveOccurred())
				Expect(cmds).To(HaveLen(1))
				return err
			}, "key")
			Expect(err).NotTo(HaveOccurred())

			Expect(ping.Err()).NotTo(HaveOccurred())
			Expect(ping.Val()).To(Equal("PONG"))
		})

		err := ring.ForEachShard(ctx, func(ctx context.Context, cl *redis.Client) error {
			defer GinkgoRecover()

			pool := cl.Pool()
			Expect(pool.Len()).To(BeNumerically("<=", 10))
			Expect(pool.IdleLen()).To(BeNumerically("<=", 10))
			Expect(pool.Len()).To(Equal(pool.IdleLen()))

			return nil
		})
		Expect(err).NotTo(HaveOccurred())
	})
})

var _ = Describe("Ring Tx timeout", func() {
	const heartbeat = 100 * time.Millisecond

	var ring *redis.Ring

	AfterEach(func() {
		_ = ring.Close()
	})

	// expectTimeout asserts that err is, or wraps, a net.Error that reports a
	// timeout. It fails the spec instead of panicking when err has another type.
	expectTimeout := func(err error) {
		Expect(err).To(HaveOccurred())

		var netErr net.Error
		Expect(errors.As(err, &netErr)).To(BeTrue(), "expected a net.Error, got %T: %v", err, err)
		Expect(netErr.Timeout()).To(BeTrue())
	}

	testTimeout := func() {
		It("Tx timeouts", func() {
			err := ring.Watch(ctx, func(tx *redis.Tx) error {
				return tx.Ping(ctx).Err()
			}, "foo")
			expectTimeout(err)
		})

		It("Tx Pipeline timeouts", func() {
			err := ring.Watch(ctx, func(tx *redis.Tx) error {
				_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
					pipe.Ping(ctx)
					return nil
				})
				return err
			}, "foo")
			expectTimeout(err)
		})
	}

	const pause = 5 * time.Second

	Context("read/write timeout", func() {
		BeforeEach(func() {
			opt := redisRingOptions()
			opt.ReadTimeout = 250 * time.Millisecond
			opt.WriteTimeout = 250 * time.Millisecond
			opt.HeartbeatFrequency = heartbeat
			ring = redis.NewRing(opt)

			// Pause every shard for longer than the read/write timeouts so
			// the commands under test time out.
			err := ring.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error {
				return client.ClientPause(ctx, pause).Err()
			})
			Expect(err).NotTo(HaveOccurred())
		})

		AfterEach(func() {
			// Wait until every shard answers again before the next spec runs.
			_ = ring.ForEachShard(ctx, func(ctx context.Context, client *redis.Client) error {
				defer GinkgoRecover()
				Eventually(func() error {
					return client.Ping(ctx).Err()
				}, 2*pause).ShouldNot(HaveOccurred())
				return nil
			})
		})

		testTimeout()
	})
})

// TestRingSetAddrsContention pings the ring in a tight loop for one second
// while another goroutine repeatedly adds and removes a shard via SetAddrs.
// Creating a shard client is artificially slowed down. The only error
// tolerated is redis.ErrClosed, returned when a shard is closed mid-command.
func TestRingSetAddrsContention(t *testing.T) {
	const (
		ringShard1Name = "ringShardOne"
		ringShard2Name = "ringShardTwo"
	)

	for _, port := range []string{ringShard1Port, ringShard2Port} {
		if _, err := startRedis(port); err != nil {
			t.Fatal(err)
		}
	}

	// Cleanup runs after the deferred calls below, so the ring is closed and
	// the updater goroutine has exited before the redis processes are stopped.
	t.Cleanup(func() {
		for _, p := range processes {
			if err := p.Close(); err != nil {
				t.Errorf("Failed to stop redis process: %v", err)
			}
		}
		processes = nil
	})

	ring := redis.NewRing(&redis.RingOptions{
		Addrs: map[string]string{
			ringShard1Name: ":" + ringShard1Port,
		},
		NewClient: func(opt *redis.Options) *redis.Client {
			// Simulate slow shard creation.
			time.Sleep(100 * time.Millisecond)
			return redis.NewClient(opt)
		},
	})
	defer func() {
		if err := ring.Close(); err != nil {
			t.Errorf("closing ring: %v", err)
		}
	}()

	if _, err := ring.Ping(context.Background()).Result(); err != nil {
		t.Fatal(err)
	}

	// Continuously update addresses by adding and removing one address.
	stopUpdates := make(chan struct{})
	updatesStopped := make(chan struct{})
	go func() {
		defer close(updatesStopped)

		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()

		for i := 0; ; i++ {
			select {
			case <-ticker.C:
				if i%2 == 0 {
					ring.SetAddrs(map[string]string{
						ringShard1Name: ":" + ringShard1Port,
					})
				} else {
					ring.SetAddrs(map[string]string{
						ringShard1Name: ":" + ringShard1Port,
						ringShard2Name: ":" + ringShard2Port,
					})
				}
			case <-stopUpdates:
				return
			}
		}
	}()
	// Registered after the ring.Close defer, so it runs first: the updater
	// must be fully stopped before the ring is closed.
	defer func() {
		close(stopUpdates)
		<-updatesStopped
	}()

	var pings, errClosed int
	deadline := time.NewTimer(1 * time.Second)
	defer deadline.Stop()

pingLoop:
	for {
		select {
		case <-deadline.C:
			break pingLoop
		default:
		}

		pings++
		if _, err := ring.Ping(context.Background()).Result(); err != nil {
			if !errors.Is(err, redis.ErrClosed) {
				t.Fatal(err)
			}
			// The shard client could be closed while ping command is in progress.
			errClosed++
		}
	}

	t.Logf("Number of pings: %d, errClosed: %d", pings, errClosed)
	if pings < 10_000 {
		t.Errorf("Expected at least 10k pings, got: %d", pings)
	}
}