diff --git a/sentinel.go b/sentinel.go index 7a6ccf39..c7e1b15a 100644 --- a/sentinel.go +++ b/sentinel.go @@ -26,8 +26,8 @@ type FailoverOptions struct { // Sentinel password from "requirepass " (if enabled) in Sentinel configuration SentinelPassword string - // Enables read-only commands on slave nodes. - ReadOnly bool + // Route all commands to slave read-only nodes. + SlaveOnly bool // Following options are copied from Options struct. @@ -57,7 +57,7 @@ type FailoverOptions struct { } func (opt *FailoverOptions) options() *Options { - return &Options{ + redisOpt := &Options{ Addr: "FailoverClient", Dialer: opt.Dialer, @@ -83,13 +83,13 @@ func (opt *FailoverOptions) options() *Options { MaxConnAge: opt.MaxConnAge, TLSConfig: opt.TLSConfig, - - sentinelReadOnly: opt.ReadOnly, } + redisOpt.init() + return redisOpt } func (opt *FailoverOptions) clusterOptions() *ClusterOptions { - return &ClusterOptions{ + clusterOpt := &ClusterOptions{ Dialer: opt.Dialer, OnConnect: opt.OnConnect, @@ -113,25 +113,24 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions { TLSConfig: opt.TLSConfig, } + clusterOpt.init() + return clusterOpt } // NewFailoverClient returns a Redis client that uses Redis Sentinel // for automatic failover. It's safe for concurrent use by multiple // goroutines. func NewFailoverClient(failoverOpt *FailoverOptions) *Client { - opt := failoverOpt.options() - opt.init() - failover := &sentinelFailover{ masterName: failoverOpt.MasterName, sentinelAddrs: failoverOpt.SentinelAddrs, sentinelPassword: failoverOpt.SentinelPassword, - opt: opt, + opt: failoverOpt.options(), } - // TODO: this overwrites original dialer - opt.Dialer = failover.dial + opt := failoverOpt.options() + opt.Dialer = masterSlaveDialer(failover, failoverOpt.SlaveOnly) connPool := newConnPool(opt) failover.onFailover = func(ctx context.Context, addr string) { @@ -150,8 +149,35 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { return &c } +func masterSlaveDialer( + failover *sentinelFailover, slaveOnly bool, +) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, _ string) (net.Conn, error) { + var addr string + var err error + + if slaveOnly { + addr, err = failover.RandomSlaveAddr(ctx) + } else { + addr, err = failover.MasterAddr(ctx) + if err == nil { + failover.trySwitchMaster(ctx, addr) + } + } + + if err != nil { + return nil, err + } + if failover.opt.Dialer != nil { + return failover.opt.Dialer(ctx, network, addr) + } + return net.DialTimeout("tcp", addr, failover.opt.DialTimeout) + } +} + //------------------------------------------------------------------------------ +// SentinelClient is a client for a Redis Sentinel. type SentinelClient struct { *baseClient ctx context.Context @@ -358,34 +384,12 @@ func (c *sentinelFailover) closeSentinel() error { return firstErr } -func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Conn, error) { - var addr string - var err error - - if c.opt.sentinelReadOnly { - addr, err = c.RandomSlaveAddr(ctx) - } else { - addr, err = c.MasterAddr(ctx) - if err == nil { - c.trySwitchMaster(ctx, addr) - } - } - - if err != nil { - return nil, err - } - if c.opt.Dialer != nil { - return c.opt.Dialer(ctx, network, addr) - } - return net.DialTimeout("tcp", addr, c.opt.DialTimeout) -} - func (c *sentinelFailover) RandomSlaveAddr(ctx context.Context) (string, error) { addresses, err := c.slaveAddresses(ctx) if err != nil { return "", err } - if len(addresses) < 1 { + if len(addresses) == 0 { return c.MasterAddr(ctx) } return addresses[rand.Intn(len(addresses))], nil diff --git a/sentinel_test.go b/sentinel_test.go index 2bbb59c3..e47d4eb7 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -28,7 +28,7 @@ var _ = Describe("Sentinel", func() { Expect(err).NotTo(HaveOccurred()) // Verify. - val, err := sentinelMaster.Get(ctx, "foo").Result() + val, err := client.Get(ctx, "foo").Result() Expect(err).NotTo(HaveOccurred()) Expect(val).To(Equal("master")) @@ -78,6 +78,10 @@ var _ = Describe("Sentinel", func() { }, "15s", "100ms").Should(Receive(&msg)) Expect(msg.Channel).To(Equal("foo")) Expect(msg.Payload).To(Equal("hello")) + + Expect(sentinelMaster.Close()).NotTo(HaveOccurred()) + sentinelMaster, err = startRedis(sentinelMasterPort) + Expect(err).NotTo(HaveOccurred()) }) It("supports DB selection", func() { @@ -114,7 +118,7 @@ var _ = Describe("NewFailoverClusterClient", func() { Expect(err).NotTo(HaveOccurred()) // Verify. - val, err := sentinelMaster.Get(ctx, "foo").Result() + val, err := client.Get(ctx, "foo").Result() Expect(err).NotTo(HaveOccurred()) Expect(val).To(Equal("master")) @@ -164,5 +168,9 @@ var _ = Describe("NewFailoverClusterClient", func() { }, "15s", "100ms").Should(Receive(&msg)) Expect(msg.Channel).To(Equal("foo")) Expect(msg.Payload).To(Equal("hello")) + + Expect(sentinelMaster.Close()).NotTo(HaveOccurred()) + sentinelMaster, err = startRedis(sentinelMasterPort) + Expect(err).NotTo(HaveOccurred()) }) })