package redis import ( "context" "fmt" "reflect" "sync" "sync/atomic" "testing" "time" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("newClusterState", func() { var state *clusterState createClusterState := func(slots []ClusterSlot) *clusterState { opt := &ClusterOptions{} opt.init() nodes := newClusterNodes(opt) state, err := newClusterState(nodes, slots, "10.10.10.10:1234") Expect(err).NotTo(HaveOccurred()) return state } Describe("sorting", func() { BeforeEach(func() { state = createClusterState([]ClusterSlot{{ Start: 1000, End: 1999, }, { Start: 0, End: 999, }, { Start: 2000, End: 2999, }}) }) It("sorts slots", func() { Expect(state.slots).To(Equal([]*clusterSlot{ {start: 0, end: 999, nodes: nil}, {start: 1000, end: 1999, nodes: nil}, {start: 2000, end: 2999, nodes: nil}, })) }) }) Describe("loopback", func() { BeforeEach(func() { state = createClusterState([]ClusterSlot{{ Nodes: []ClusterNode{{Addr: "127.0.0.1:7001"}}, }, { Nodes: []ClusterNode{{Addr: "127.0.0.1:7002"}}, }, { Nodes: []ClusterNode{{Addr: "1.2.3.4:1234"}}, }, { Nodes: []ClusterNode{{Addr: ":1234"}}, }}) }) It("replaces loopback hosts in addresses", func() { slotAddr := func(slot *clusterSlot) string { return slot.nodes[0].Client.Options().Addr } Expect(slotAddr(state.slots[0])).To(Equal("10.10.10.10:7001")) Expect(slotAddr(state.slots[1])).To(Equal("10.10.10.10:7002")) Expect(slotAddr(state.slots[2])).To(Equal("1.2.3.4:1234")) Expect(slotAddr(state.slots[3])).To(Equal(":1234")) }) }) }) type fixedHash string func (h fixedHash) Get(string) string { return string(h) } func TestRingSetAddrsAndRebalanceRace(t *testing.T) { const ( ringShard1Name = "ringShardOne" ringShard2Name = "ringShardTwo" ringShard1Port = "6390" ringShard2Port = "6391" ) ring := NewRing(&RingOptions{ Addrs: map[string]string{ ringShard1Name: ":" + ringShard1Port, }, // Disable heartbeat HeartbeatFrequency: 1 * time.Hour, NewConsistentHash: func(shards []string) ConsistentHash { switch len(shards) { case 1: return fixedHash(ringShard1Name) case 2: return fixedHash(ringShard2Name) default: t.Fatalf("Unexpected number of shards: %v", shards) return nil } }, }) defer ring.Close() // Continuously update addresses by adding and removing one address updatesDone := make(chan struct{}) defer func() { close(updatesDone) }() go func() { for i := 0; ; i++ { select { case <-updatesDone: return default: if i%2 == 0 { ring.SetAddrs(map[string]string{ ringShard1Name: ":" + ringShard1Port, }) } else { ring.SetAddrs(map[string]string{ ringShard1Name: ":" + ringShard1Port, ringShard2Name: ":" + ringShard2Port, }) } } } }() timer := time.NewTimer(1 * time.Second) for running := true; running; { select { case <-timer.C: running = false default: shard, err := ring.sharding.GetByKey("whatever") if err == nil && shard == nil { t.Fatal("shard is nil") } } } } func BenchmarkRingShardingRebalanceLocked(b *testing.B) { opts := &RingOptions{ Addrs: make(map[string]string), // Disable heartbeat HeartbeatFrequency: 1 * time.Hour, } for i := 0; i < 100; i++ { opts.Addrs[fmt.Sprintf("shard%d", i)] = fmt.Sprintf(":63%02d", i) } ring := NewRing(opts) defer ring.Close() b.ResetTimer() for i := 0; i < b.N; i++ { ring.sharding.rebalanceLocked() } } type testCounter struct { mu sync.Mutex t *testing.T m map[string]int } func newTestCounter(t *testing.T) *testCounter { return &testCounter{t: t, m: make(map[string]int)} } func (ct *testCounter) increment(key string) { ct.mu.Lock() defer ct.mu.Unlock() ct.m[key]++ } func (ct *testCounter) expect(values map[string]int) { ct.mu.Lock() defer ct.mu.Unlock() ct.t.Helper() if !reflect.DeepEqual(values, ct.m) { ct.t.Errorf("expected %v != actual %v", values, ct.m) } } func TestRingShardsCleanup(t *testing.T) { const ( ringShard1Name = "ringShardOne" ringShard2Name = "ringShardTwo" ringShard1Addr = "shard1.test" ringShard2Addr = "shard2.test" ) t.Run("closes unused shards", func(t *testing.T) { closeCounter := newTestCounter(t) ring := NewRing(&RingOptions{ Addrs: map[string]string{ ringShard1Name: ringShard1Addr, ringShard2Name: ringShard2Addr, }, NewClient: func(opt *Options) *Client { c := NewClient(opt) c.baseClient.onClose = func() error { closeCounter.increment(opt.Addr) return nil } return c }, }) closeCounter.expect(map[string]int{}) // no change due to the same addresses ring.SetAddrs(map[string]string{ ringShard1Name: ringShard1Addr, ringShard2Name: ringShard2Addr, }) closeCounter.expect(map[string]int{}) ring.SetAddrs(map[string]string{ ringShard1Name: ringShard1Addr, }) closeCounter.expect(map[string]int{ringShard2Addr: 1}) ring.SetAddrs(map[string]string{ ringShard2Name: ringShard2Addr, }) closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1}) ring.Close() closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 2}) }) t.Run("closes created shards if ring was closed", func(t *testing.T) { createCounter := newTestCounter(t) closeCounter := newTestCounter(t) var ( ring *Ring shouldClose int32 ) ring = NewRing(&RingOptions{ Addrs: map[string]string{ ringShard1Name: ringShard1Addr, }, NewClient: func(opt *Options) *Client { if atomic.LoadInt32(&shouldClose) != 0 { ring.Close() } createCounter.increment(opt.Addr) c := NewClient(opt) c.baseClient.onClose = func() error { closeCounter.increment(opt.Addr) return nil } return c }, }) createCounter.expect(map[string]int{ringShard1Addr: 1}) closeCounter.expect(map[string]int{}) atomic.StoreInt32(&shouldClose, 1) ring.SetAddrs(map[string]string{ ringShard2Name: ringShard2Addr, }) createCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1}) closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1}) }) } //------------------------------------------------------------------------------ type timeoutErr struct { error } func (e timeoutErr) Timeout() bool { return true } func (e timeoutErr) Temporary() bool { return true } func (e timeoutErr) Error() string { return "i/o timeout" } var _ = Describe("withConn", func() { var client *Client BeforeEach(func() { client = NewClient(&Options{ PoolSize: 1, }) }) AfterEach(func() { client.Close() }) It("should replace the connection in the pool when there is no error", func() { var conn *pool.Conn client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error { conn = c return nil }) newConn, err := client.connPool.Get(ctx) Expect(err).To(BeNil()) Expect(newConn).To(Equal(conn)) }) It("should replace the connection in the pool when there is an error not related to a bad connection", func() { var conn *pool.Conn client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error { conn = c return proto.RedisError("LOADING") }) newConn, err := client.connPool.Get(ctx) Expect(err).To(BeNil()) Expect(newConn).To(Equal(conn)) }) It("should remove the connection from the pool when it times out", func() { var conn *pool.Conn client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error { conn = c return timeoutErr{} }) newConn, err := client.connPool.Get(ctx) Expect(err).To(BeNil()) Expect(newConn).NotTo(Equal(conn)) Expect(client.connPool.Len()).To(Equal(1)) }) })