diff --git a/cluster.go b/cluster.go index 346d675f..51ceb09e 100644 --- a/cluster.go +++ b/cluster.go @@ -49,7 +49,7 @@ type ClusterOptions struct { // and load-balance read/write operations between master and slaves. // It can use service like ZooKeeper to maintain configuration information // and Cluster.ReloadState to manually trigger state reloading. - ClusterSlots func() ([]ClusterSlot, error) + ClusterSlots func(context.Context) ([]ClusterSlot, error) // Following options are copied from Options struct. @@ -987,7 +987,7 @@ func (c *ClusterClient) PoolStats() *PoolStats { func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) { if c.opt.ClusterSlots != nil { - slots, err := c.opt.ClusterSlots() + slots, err := c.opt.ClusterSlots(ctx) if err != nil { return nil, err } diff --git a/cluster_test.go b/cluster_test.go index 9cb41173..6bca752f 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -911,7 +911,7 @@ var _ = Describe("ClusterClient", func() { failover = true opt = redisClusterOptions() - opt.ClusterSlots = func() ([]redis.ClusterSlot, error) { + opt.ClusterSlots = func(ctx context.Context) ([]redis.ClusterSlot, error) { slots := []redis.ClusterSlot{{ Start: 0, End: 4999, @@ -965,7 +965,7 @@ var _ = Describe("ClusterClient", func() { opt = redisClusterOptions() opt.RouteRandomly = true - opt.ClusterSlots = func() ([]redis.ClusterSlot, error) { + opt.ClusterSlots = func(ctx context.Context) ([]redis.ClusterSlot, error) { slots := []redis.ClusterSlot{{ Start: 0, End: 4999, diff --git a/example_test.go b/example_test.go index 0d1cac97..63d9f8cf 100644 --- a/example_test.go +++ b/example_test.go @@ -80,7 +80,7 @@ func ExampleNewClusterClient_manualSetup() { // clusterSlots returns cluster slots information. // It can use service like ZooKeeper to maintain configuration information // and Cluster.ReloadState to manually trigger state reloading. - clusterSlots := func() ([]redis.ClusterSlot, error) { + clusterSlots := func(ctx context.Context) ([]redis.ClusterSlot, error) { slots := []redis.ClusterSlot{ // First node with 1 master and 1 slave. { diff --git a/go.sum b/go.sum index ccc8af1b..9b0733ee 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,7 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.14.0 h1:2mOpI4JVVPBN+WQRa0WKH2eXR+Ey+uK4n7Zj0aYpIQA= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.14.1 h1:jMU0WaQrP0a/YAEq8eJmJKjBoMs+pClEr1vDMlM/Do4= github.com/onsi/ginkgo v1.14.1/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1 h1:o0+MgICZLuZ7xjH7Vx6zS/zcu93/BEp1VwkIW1mEXCE= diff --git a/sentinel.go b/sentinel.go index 41b142a0..7a6ccf39 100644 --- a/sentinel.go +++ b/sentinel.go @@ -88,6 +88,33 @@ func (opt *FailoverOptions) options() *Options { } } +func (opt *FailoverOptions) clusterOptions() *ClusterOptions { + return &ClusterOptions{ + Dialer: opt.Dialer, + OnConnect: opt.OnConnect, + + Username: opt.Username, + Password: opt.Password, + + MaxRetries: opt.MaxRetries, + MinRetryBackoff: opt.MinRetryBackoff, + MaxRetryBackoff: opt.MaxRetryBackoff, + + DialTimeout: opt.DialTimeout, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, + IdleTimeout: opt.IdleTimeout, + IdleCheckFrequency: opt.IdleCheckFrequency, + MinIdleConns: opt.MinIdleConns, + MaxConnAge: opt.MaxConnAge, + + TLSConfig: opt.TLSConfig, + } +} + // NewFailoverClient returns a Redis client that uses Redis Sentinel // for automatic failover. It's safe for concurrent use by multiple // goroutines. @@ -103,8 +130,18 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt: opt, } + // TODO: this overwrites original dialer + opt.Dialer = failover.dial + + connPool := newConnPool(opt) + failover.onFailover = func(ctx context.Context, addr string) { + _ = connPool.Filter(func(cn *pool.Conn) bool { + return cn.RemoteAddr().String() != addr + }) + } + c := Client{ - baseClient: newBaseClient(opt, failover.Pool()), + baseClient: newBaseClient(opt, connPool), ctx: context.Background(), } c.cmdable = c.Process @@ -283,14 +320,14 @@ func (c *SentinelClient) Remove(ctx context.Context, name string) *StringCmd { return cmd } +//------------------------------------------------------------------------------ + type sentinelFailover struct { sentinelAddrs []string sentinelPassword string - opt *Options - - pool *pool.ConnPool - poolOnce sync.Once + opt *Options + onFailover func(ctx context.Context, addr string) mu sync.RWMutex masterName string @@ -321,15 +358,6 @@ func (c *sentinelFailover) closeSentinel() error { return firstErr } -func (c *sentinelFailover) Pool() *pool.ConnPool { - c.poolOnce.Do(func() { - opt := *c.opt - opt.Dialer = c.dial - c.pool = newConnPool(&opt) - }) - return c.pool -} - func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Conn, error) { var addr string var err error @@ -338,6 +366,9 @@ func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Con addr, err = c.RandomSlaveAddr(ctx) } else { addr, err = c.MasterAddr(ctx) + if err == nil { + c.trySwitchMaster(ctx, addr) + } } if err != nil { @@ -349,15 +380,6 @@ func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Con return net.DialTimeout("tcp", addr, c.opt.DialTimeout) } -func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { - addr, err := c.masterAddr(ctx) - if err != nil { - return "", err - } - c.switchMaster(ctx, addr) - return addr, nil -} - func (c *sentinelFailover) RandomSlaveAddr(ctx context.Context) (string, error) { addresses, err := c.slaveAddresses(ctx) if err != nil { @@ -369,7 +391,7 @@ func (c *sentinelFailover) RandomSlaveAddr(ctx context.Context) (string, error) return addresses[rand.Intn(len(addresses))], nil } -func (c *sentinelFailover) masterAddr(ctx context.Context) (string, error) { +func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { c.mu.RLock() sentinel := c.sentinel c.mu.RUnlock() @@ -553,27 +575,26 @@ func parseSlaveAddresses(addrs []interface{}) []string { return nodes } -func (c *sentinelFailover) switchMaster(ctx context.Context, addr string) { +func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) { c.mu.RLock() - masterAddr := c._masterAddr + currentAddr := c._masterAddr c.mu.RUnlock() - if masterAddr == addr { + + if addr == currentAddr { return } c.mu.Lock() defer c.mu.Unlock() - if c._masterAddr == addr { + if addr == c._masterAddr { return } + c._masterAddr = addr internal.Logger.Printf(ctx, "sentinel: new master=%q addr=%q", c.masterName, addr) - _ = c.Pool().Filter(func(cn *pool.Conn) bool { - return cn.RemoteAddr().String() != addr - }) - c._masterAddr = addr + go c.onFailover(ctx, addr) } func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelClient) { @@ -624,7 +645,7 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { continue } addr := net.JoinHostPort(parts[3], parts[4]) - c.switchMaster(pubsub.getContext(), addr) + c.trySwitchMaster(pubsub.getContext(), addr) } } } @@ -637,3 +658,54 @@ func contains(slice []string, str string) bool { } return false } + +//------------------------------------------------------------------------------ + +func NewFailoverClusterClient(failoverOpt *FailoverOptions) *ClusterClient { + failover := &sentinelFailover{ + masterName: failoverOpt.MasterName, + sentinelAddrs: failoverOpt.SentinelAddrs, + + opt: failoverOpt.options(), + } + + opt := failoverOpt.clusterOptions() + + opt.ClusterSlots = func(ctx context.Context) ([]ClusterSlot, error) { + masterAddr, err := failover.MasterAddr(ctx) + if err != nil { + return nil, err + } + + nodes := []ClusterNode{{ + Addr: masterAddr, + }} + + slaveAddrs, err := failover.slaveAddresses(ctx) + if err != nil { + return nil, err + } + + for _, slaveAddr := range slaveAddrs { + nodes = append(nodes, ClusterNode{ + Addr: slaveAddr, + }) + } + + slots := []ClusterSlot{ + { + Start: 0, + End: 16383, + Nodes: nodes, + }, + } + return slots, nil + } + + c := NewClusterClient(opt) + failover.onFailover = func(ctx context.Context, addr string) { + _ = c.ReloadState(ctx) + } + + return c +} diff --git a/sentinel_test.go b/sentinel_test.go index a4d0fe22..2bbb59c3 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -46,13 +46,13 @@ var _ = Describe("Sentinel", func() { // Wait until slaves are picked up by sentinel. Eventually(func() string { return sentinel1.Info(ctx).Val() - }, "10s", "100ms").Should(ContainSubstring("slaves=2")) + }, "15s", "100ms").Should(ContainSubstring("slaves=2")) Eventually(func() string { return sentinel2.Info(ctx).Val() - }, "10s", "100ms").Should(ContainSubstring("slaves=2")) + }, "15s", "100ms").Should(ContainSubstring("slaves=2")) Eventually(func() string { return sentinel3.Info(ctx).Val() - }, "10s", "100ms").Should(ContainSubstring("slaves=2")) + }, "15s", "100ms").Should(ContainSubstring("slaves=2")) // Kill master. sentinelMaster.Shutdown(ctx) @@ -63,7 +63,7 @@ var _ = Describe("Sentinel", func() { // Wait for Redis sentinel to elect new master. Eventually(func() string { return sentinelSlave1.Info(ctx).Val() + sentinelSlave2.Info(ctx).Val() - }, "30s", "100ms").Should(ContainSubstring("role:master")) + }, "15s", "100ms").Should(ContainSubstring("role:master")) // Check that client picked up new master. Eventually(func() error { @@ -75,7 +75,7 @@ var _ = Describe("Sentinel", func() { Eventually(func() <-chan *redis.Message { _ = client.Publish(ctx, "foo", "hello").Err() return ch - }, "15s").Should(Receive(&msg)) + }, "15s", "100ms").Should(Receive(&msg)) Expect(msg.Channel).To(Equal("foo")) Expect(msg.Payload).To(Equal("hello")) }) @@ -92,3 +92,77 @@ var _ = Describe("Sentinel", func() { Expect(err).NotTo(HaveOccurred()) }) }) + +var _ = Describe("NewFailoverClusterClient", func() { + var client *redis.ClusterClient + + BeforeEach(func() { + client = redis.NewFailoverClusterClient(&redis.FailoverOptions{ + MasterName: sentinelName, + SentinelAddrs: sentinelAddrs, + }) + Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("should facilitate failover", func() { + // Set value on master. + err := client.Set(ctx, "foo", "master", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + // Verify. + val, err := sentinelMaster.Get(ctx, "foo").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("master")) + + // Create subscription. + ch := client.Subscribe(ctx, "foo").Channel() + + // Wait until replicated. + Eventually(func() string { + return sentinelSlave1.Get(ctx, "foo").Val() + }, "15s", "100ms").Should(Equal("master")) + Eventually(func() string { + return sentinelSlave2.Get(ctx, "foo").Val() + }, "15s", "100ms").Should(Equal("master")) + + // Wait until slaves are picked up by sentinel. + Eventually(func() string { + return sentinel1.Info(ctx).Val() + }, "15s", "100ms").Should(ContainSubstring("slaves=2")) + Eventually(func() string { + return sentinel2.Info(ctx).Val() + }, "15s", "100ms").Should(ContainSubstring("slaves=2")) + Eventually(func() string { + return sentinel3.Info(ctx).Val() + }, "15s", "100ms").Should(ContainSubstring("slaves=2")) + + // Kill master. + sentinelMaster.Shutdown(ctx) + Eventually(func() error { + return sentinelMaster.Ping(ctx).Err() + }, "15s", "100ms").Should(HaveOccurred()) + + // Wait for Redis sentinel to elect new master. + Eventually(func() string { + return sentinelSlave1.Info(ctx).Val() + sentinelSlave2.Info(ctx).Val() + }, "15s", "100ms").Should(ContainSubstring("role:master")) + + // Check that client picked up new master. + Eventually(func() error { + return client.Get(ctx, "foo").Err() + }, "15s", "100ms").ShouldNot(HaveOccurred()) + + // Check if subscription is renewed. + var msg *redis.Message + Eventually(func() <-chan *redis.Message { + _ = client.Publish(ctx, "foo", "hello").Err() + return ch + }, "15s", "100ms").Should(Receive(&msg)) + Expect(msg.Channel).To(Equal("foo")) + Expect(msg.Payload).To(Equal("hello")) + }) +})