From 6060f097e1dd91aac7569de37bf49c8a5d4c2043 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sun, 9 Jul 2017 10:07:20 +0300 Subject: [PATCH 1/2] Add PubSub support to Cluster client --- cluster.go | 60 ++++++++++++++++++++++++++++++++++++++++++++----- cluster_test.go | 22 ++++++++++++++++++ pubsub.go | 34 +++++++++++++--------------- pubsub_test.go | 6 ++--- redis.go | 38 ++++++++++++++++++++++--------- ring.go | 4 ++-- sentinel.go | 24 +++++++++++--------- 7 files changed, 138 insertions(+), 50 deletions(-) diff --git a/cluster.go b/cluster.go index f758b01..2641b16 100644 --- a/cluster.go +++ b/cluster.go @@ -327,7 +327,7 @@ func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) { } func (c *clusterState) slotNodes(slot int) []*clusterNode { - if slot < len(c.slots) { + if slot >= 0 && slot < len(c.slots) { return c.slots[slot] } return nil @@ -720,14 +720,14 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { failedCmds := make(map[*clusterNode][]Cmder) for node, cmds := range cmdsMap { - cn, _, err := node.Client.conn() + cn, _, err := node.Client.getConn() if err != nil { setCmdsErr(cmds, err) continue } err = c.pipelineProcessCmds(cn, cmds, failedCmds) - node.Client.putConn(cn, err) + node.Client.releaseConn(cn, err) } if len(failedCmds) == 0 { @@ -855,14 +855,14 @@ func (c *ClusterClient) txPipelineExec(cmds []Cmder) error { failedCmds := make(map[*clusterNode][]Cmder) for node, cmds := range cmdsMap { - cn, _, err := node.Client.conn() + cn, _, err := node.Client.getConn() if err != nil { setCmdsErr(cmds, err) continue } err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) - node.Client.putConn(cn, err) + node.Client.releaseConn(cn, err) } if len(failedCmds) == 0 { @@ -966,6 +966,56 @@ func (c *ClusterClient) txPipelineReadQueued( return firstErr } +func (c *ClusterClient) pubSub(channels []string) *PubSub { + opt := c.opt.clientOptions() + + var node *clusterNode + return &PubSub{ + opt: opt, + + newConn: func(channels []string) (*pool.Conn, error) { + if node == nil { + var slot int + if len(channels) > 0 { + slot = hashtag.Slot(channels[0]) + } else { + slot = -1 + } + + masterNode, err := c.state().slotMasterNode(slot) + if err != nil { + return nil, err + } + node = masterNode + } + return node.Client.newConn() + }, + closeConn: func(cn *pool.Conn) error { + return node.Client.connPool.CloseConn(cn) + }, + } +} + +// Subscribe subscribes the client to the specified channels. +// Channels can be omitted to create empty subscription. +func (c *ClusterClient) Subscribe(channels ...string) *PubSub { + pubsub := c.pubSub(channels) + if len(channels) > 0 { + _ = pubsub.Subscribe(channels...) + } + return pubsub +} + +// PSubscribe subscribes the client to the given patterns. +// Patterns can be omitted to create empty subscription. +func (c *ClusterClient) PSubscribe(channels ...string) *PubSub { + pubsub := c.pubSub(channels) + if len(channels) > 0 { + _ = pubsub.PSubscribe(channels...) + } + return pubsub +} + func isLoopbackAddr(addr string) bool { host, _, err := net.SplitHostPort(addr) if err != nil { diff --git a/cluster_test.go b/cluster_test.go index 3a69255..1dc7229 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -472,6 +472,28 @@ var _ = Describe("ClusterClient", func() { }) }) + It("supports PubSub", func() { + pubsub := client.Subscribe("mychannel") + defer pubsub.Close() + + msgi, err := pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*redis.Subscription) + Expect(subscr.Kind).To(Equal("subscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + + n, err := client.Publish("mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + msgi, err = pubsub.ReceiveTimeout(time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*redis.Message) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + }) + It("calls fn for every master node", func() { for i := 0; i < 10; i++ { Expect(client.Set(strconv.Itoa(i), "", 0).Err()).NotTo(HaveOccurred()) diff --git a/pubsub.go b/pubsub.go index 4872b4e..74ac51c 100644 --- a/pubsub.go +++ b/pubsub.go @@ -17,7 +17,10 @@ import ( // PubSub automatically resubscribes to the channels and patterns // when Redis becomes unavailable. type PubSub struct { - base baseClient + opt *Options + + newConn func([]string) (*pool.Conn, error) + closeConn func(*pool.Conn) error mu sync.Mutex cn *pool.Conn @@ -30,12 +33,12 @@ type PubSub struct { func (c *PubSub) conn() (*pool.Conn, error) { c.mu.Lock() - cn, err := c._conn() + cn, err := c._conn(nil) c.mu.Unlock() return cn, err } -func (c *PubSub) _conn() (*pool.Conn, error) { +func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { if c.closed { return nil, pool.ErrClosed } @@ -44,20 +47,13 @@ func (c *PubSub) _conn() (*pool.Conn, error) { return c.cn, nil } - cn, err := c.base.connPool.NewConn() + cn, err := c.newConn(channels) if err != nil { return nil, err } - if !cn.Inited { - if err := c.base.initConn(cn); err != nil { - _ = c.base.connPool.CloseConn(cn) - return nil, err - } - } - if err := c.resubscribe(cn); err != nil { - _ = c.base.connPool.CloseConn(cn) + _ = c.closeConn(cn) return nil, err } @@ -88,7 +84,7 @@ func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) } cmd := NewSliceCmd(args...) - cn.SetWriteTimeout(c.base.opt.WriteTimeout) + cn.SetWriteTimeout(c.opt.WriteTimeout) return writeCmd(cn, cmd) } @@ -99,13 +95,13 @@ func (c *PubSub) putConn(cn *pool.Conn, err error) { c.mu.Lock() if c.cn == cn { - _ = c.closeConn() + _ = c.releaseConn() } c.mu.Unlock() } -func (c *PubSub) closeConn() error { - err := c.base.connPool.CloseConn(c.cn) +func (c *PubSub) releaseConn() error { + err := c.closeConn(c.cn) c.cn = nil return err } @@ -120,7 +116,7 @@ func (c *PubSub) Close() error { c.closed = true if c.cn != nil { - return c.closeConn() + return c.releaseConn() } return nil } @@ -166,7 +162,7 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { } func (c *PubSub) subscribe(redisCmd string, channels ...string) error { - cn, err := c._conn() + cn, err := c._conn(channels) if err != nil { return err } @@ -188,7 +184,7 @@ func (c *PubSub) Ping(payload ...string) error { return err } - cn.SetWriteTimeout(c.base.opt.WriteTimeout) + cn.SetWriteTimeout(c.opt.WriteTimeout) err = writeCmd(cn, cmd) c.putConn(cn, err) return err diff --git a/pubsub_test.go b/pubsub_test.go index e8589f4..3cb9627 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -159,9 +159,9 @@ var _ = Describe("PubSub", func() { { msgi, err := pubsub.ReceiveTimeout(time.Second) Expect(err).NotTo(HaveOccurred()) - subscr := msgi.(*redis.Message) - Expect(subscr.Channel).To(Equal("mychannel")) - Expect(subscr.Payload).To(Equal("hello")) + msg := msgi.(*redis.Message) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) } { diff --git a/redis.go b/redis.go index 9812daf..1a2ecb0 100644 --- a/redis.go +++ b/redis.go @@ -21,7 +21,23 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } -func (c *baseClient) conn() (*pool.Conn, bool, error) { +func (c *baseClient) newConn() (*pool.Conn, error) { + cn, err := c.connPool.NewConn() + if err != nil { + return nil, err + } + + if !cn.Inited { + if err := c.initConn(cn); err != nil { + _ = c.connPool.CloseConn(cn) + return nil, err + } + } + + return cn, nil +} + +func (c *baseClient) getConn() (*pool.Conn, bool, error) { cn, isNew, err := c.connPool.Get() if err != nil { return nil, false, err @@ -37,7 +53,7 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) { return cn, isNew, nil } -func (c *baseClient) putConn(cn *pool.Conn, err error) bool { +func (c *baseClient) releaseConn(cn *pool.Conn, err error) bool { if internal.IsBadConn(err, false) { _ = c.connPool.Remove(cn) return false @@ -112,7 +128,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { time.Sleep(internal.RetryBackoff(i, c.opt.MaxRetryBackoff)) } - cn, _, err := c.conn() + cn, _, err := c.getConn() if err != nil { cmd.setErr(err) if internal.IsRetryableError(err) { @@ -123,7 +139,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmd); err != nil { - c.putConn(cn, err) + c.releaseConn(cn, err) cmd.setErr(err) if internal.IsRetryableError(err) { continue @@ -133,7 +149,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetReadTimeout(c.cmdTimeout(cmd)) err = cmd.readReply(cn) - c.putConn(cn, err) + c.releaseConn(cn, err) if err != nil && internal.IsRetryableError(err) { continue } @@ -179,14 +195,14 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer { return func(cmds []Cmder) error { var firstErr error for i := 0; i <= c.opt.MaxRetries; i++ { - cn, _, err := c.conn() + cn, _, err := c.getConn() if err != nil { setCmdsErr(cmds, err) return err } canRetry, err := p(cn, cmds) - c.putConn(cn, err) + c.releaseConn(cn, err) if err == nil { return nil } @@ -375,10 +391,12 @@ func (c *Client) TxPipeline() Pipeliner { func (c *Client) pubSub() *PubSub { return &PubSub{ - base: baseClient{ - opt: c.opt, - connPool: c.connPool, + opt: c.opt, + + newConn: func(channels []string) (*pool.Conn, error) { + return c.newConn() }, + closeConn: c.connPool.CloseConn, } } diff --git a/ring.go b/ring.go index be92510..72d52bf 100644 --- a/ring.go +++ b/ring.go @@ -423,7 +423,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { continue } - cn, _, err := shard.Client.conn() + cn, _, err := shard.Client.getConn() if err != nil { setCmdsErr(cmds, err) if firstErr == nil { @@ -433,7 +433,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { } canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds) - shard.Client.putConn(cn, err) + shard.Client.releaseConn(cn, err) if err == nil { continue } diff --git a/sentinel.go b/sentinel.go index ed6e7ff..3bfdb4a 100644 --- a/sentinel.go +++ b/sentinel.go @@ -112,10 +112,12 @@ func newSentinel(opt *Options) *sentinelClient { func (c *sentinelClient) PubSub() *PubSub { return &PubSub{ - base: baseClient{ - opt: c.opt, - connPool: c.connPool, + opt: c.opt, + + newConn: func(channels []string) (*pool.Conn, error) { + return c.newConn() }, + closeConn: c.connPool.CloseConn, } } @@ -149,14 +151,6 @@ func (d *sentinelFailover) Close() error { return d.resetSentinel() } -func (d *sentinelFailover) dial() (net.Conn, error) { - addr, err := d.MasterAddr() - if err != nil { - return nil, err - } - return net.DialTimeout("tcp", addr, d.opt.DialTimeout) -} - func (d *sentinelFailover) Pool() *pool.ConnPool { d.poolOnce.Do(func() { d.opt.Dialer = d.dial @@ -165,6 +159,14 @@ func (d *sentinelFailover) Pool() *pool.ConnPool { return d.pool } +func (d *sentinelFailover) dial() (net.Conn, error) { + addr, err := d.MasterAddr() + if err != nil { + return nil, err + } + return net.DialTimeout("tcp", addr, d.opt.DialTimeout) +} + func (d *sentinelFailover) MasterAddr() (string, error) { d.mu.Lock() defer d.mu.Unlock() From 3ddda73a05f3ba7c3ea69ddce33b4667ff8f7aa2 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sun, 9 Jul 2017 13:10:07 +0300 Subject: [PATCH 2/2] Close connections to unused nodes --- cluster.go | 305 +++++++++++++++++++++++++++----------- cluster_test.go | 277 +++++++++++++++++++--------------- export_test.go | 7 +- internal/internal.go | 13 +- internal/internal_test.go | 9 +- options.go | 13 +- redis.go | 24 +-- 7 files changed, 418 insertions(+), 230 deletions(-) diff --git a/cluster.go b/cluster.go index 2641b16..30c5ce6 100644 --- a/cluster.go +++ b/cluster.go @@ -28,18 +28,19 @@ type ClusterOptions struct { // Default is 16. MaxRedirects int - // Enables read queries for a connection to a Redis Cluster slave node. + // Enables read-only commands on slave nodes. ReadOnly bool - - // Enables routing read-only queries to the closest master or slave node. + // Allows routing read-only commands to the closest master or slave node. RouteByLatency bool // Following options are copied from Options struct. OnConnect func(*Conn) error - MaxRetries int - Password string + MaxRetries int + MinRetryBackoff time.Duration + MaxRetryBackoff time.Duration + Password string DialTimeout time.Duration ReadTimeout time.Duration @@ -62,6 +63,19 @@ func (opt *ClusterOptions) init() { if opt.RouteByLatency { opt.ReadOnly = true } + + switch opt.MinRetryBackoff { + case -1: + opt.MinRetryBackoff = 0 + case 0: + opt.MinRetryBackoff = 8 * time.Millisecond + } + switch opt.MaxRetryBackoff { + case -1: + opt.MaxRetryBackoff = 0 + case 0: + opt.MaxRetryBackoff = 512 * time.Millisecond + } } func (opt *ClusterOptions) clientOptions() *Options { @@ -70,9 +84,11 @@ func (opt *ClusterOptions) clientOptions() *Options { return &Options{ OnConnect: opt.OnConnect, - MaxRetries: opt.MaxRetries, - Password: opt.Password, - ReadOnly: opt.ReadOnly, + MaxRetries: opt.MaxRetries, + MinRetryBackoff: opt.MinRetryBackoff, + MaxRetryBackoff: opt.MaxRetryBackoff, + Password: opt.Password, + ReadOnly: opt.ReadOnly, DialTimeout: opt.DialTimeout, ReadTimeout: opt.ReadTimeout, @@ -91,7 +107,9 @@ func (opt *ClusterOptions) clientOptions() *Options { type clusterNode struct { Client *Client Latency time.Duration - loading time.Time + + loading time.Time + generation uint32 } func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode { @@ -122,6 +140,17 @@ func (n *clusterNode) Loading() bool { return !n.loading.IsZero() && time.Since(n.loading) < time.Minute } +func (n *clusterNode) Generation() uint32 { + return n.generation +} + +func (n *clusterNode) SetGeneration(gen uint32) { + if gen < n.generation { + panic("gen < n.generation") + } + n.generation = gen +} + //------------------------------------------------------------------------------ type clusterNodes struct { @@ -131,6 +160,8 @@ type clusterNodes struct { addrs []string nodes map[string]*clusterNode closed bool + + generation uint32 } func newClusterNodes(opt *ClusterOptions) *clusterNodes { @@ -161,6 +192,39 @@ func (c *clusterNodes) Close() error { return firstErr } +func (c *clusterNodes) NextGeneration() uint32 { + c.generation++ + return c.generation +} + +// GC removes unused nodes. +func (c *clusterNodes) GC(generation uint32) error { + var collected []*clusterNode + c.mu.Lock() + for i := 0; i < len(c.addrs); { + addr := c.addrs[i] + node := c.nodes[addr] + if node.Generation() >= generation { + i++ + continue + } + + c.addrs = append(c.addrs[:i], c.addrs[i+1:]...) + delete(c.nodes, addr) + collected = append(collected, node) + } + c.mu.Unlock() + + var firstErr error + for _, node := range collected { + if err := node.Client.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr +} + func (c *clusterNodes) All() ([]*clusterNode, error) { c.mu.RLock() defer c.mu.RUnlock() @@ -176,7 +240,7 @@ func (c *clusterNodes) All() ([]*clusterNode, error) { return nodes, nil } -func (c *clusterNodes) Get(addr string) (*clusterNode, error) { +func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { var node *clusterNode var ok bool @@ -223,7 +287,7 @@ func (c *clusterNodes) Random() (*clusterNode, error) { var nodeErr error for i := 0; i <= c.opt.MaxRedirects; i++ { n := rand.Intn(len(addrs)) - node, err := c.Get(addrs[n]) + node, err := c.GetOrCreate(addrs[n]) if err != nil { return nil, err } @@ -239,30 +303,45 @@ func (c *clusterNodes) Random() (*clusterNode, error) { //------------------------------------------------------------------------------ type clusterState struct { - nodes *clusterNodes + nodes *clusterNodes + masters []*clusterNode + slaves []*clusterNode + slots [][]*clusterNode + + generation uint32 } func newClusterState(nodes *clusterNodes, slots []ClusterSlot, origin string) (*clusterState, error) { c := clusterState{ - nodes: nodes, + nodes: nodes, + generation: nodes.NextGeneration(), + slots: make([][]*clusterNode, hashtag.SlotNumber), } isLoopbackOrigin := isLoopbackAddr(origin) for _, slot := range slots { var nodes []*clusterNode - for _, slotNode := range slot.Nodes { + for i, slotNode := range slot.Nodes { addr := slotNode.Addr if !isLoopbackOrigin && isLoopbackAddr(addr) { addr = origin } - node, err := c.nodes.Get(addr) + node, err := c.nodes.GetOrCreate(addr) if err != nil { return nil, err } + + node.SetGeneration(c.generation) nodes = append(nodes, node) + + if i == 0 { + c.masters = appendNode(c.masters, node) + } else { + c.slaves = appendNode(c.slaves, node) + } } for i := slot.Start; i <= slot.End; i++ { @@ -348,7 +427,7 @@ type ClusterClient struct { cmdsInfoOnce internal.Once cmdsInfo map[string]*CommandInfo - // Reports where slots reloading is in progress. + // Reports whether slots reloading is in progress. reloading uint32 } @@ -365,12 +444,12 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { // Add initial nodes. for _, addr := range opt.Addrs { - _, _ = c.nodes.Get(addr) + _, _ = c.nodes.GetOrCreate(addr) } // Preload cluster slots. for i := 0; i < 10; i++ { - state, err := c.reloadSlots() + state, err := c.reloadState() if err == nil { c._state.Store(state) break @@ -394,7 +473,7 @@ func (c *ClusterClient) state() *clusterState { if v != nil { return v.(*clusterState) } - c.lazyReloadSlots() + c.lazyReloadState() return nil } @@ -476,6 +555,10 @@ func (c *ClusterClient) Process(cmd Cmder) error { var ask bool for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + if attempt > 0 { + time.Sleep(node.Client.retryBackoff(attempt)) + } + if ask { pipe := node.Client.Pipeline() pipe.Process(NewCmd("ASKING")) @@ -487,13 +570,14 @@ func (c *ClusterClient) Process(cmd Cmder) error { err = node.Client.Process(cmd) } - // If there is no (real) error - we are done. + // If there is no error - we are done. if err == nil { return nil } // If slave is loading - read from master. if c.opt.ReadOnly && internal.IsLoadingError(err) { + // TODO: race node.loading = time.Now() continue } @@ -516,11 +600,11 @@ func (c *ClusterClient) Process(cmd Cmder) error { if state != nil && slot >= 0 { master, _ := state.slotMasterNode(slot) if moved && (master == nil || master.Client.getAddr() != addr) { - c.lazyReloadSlots() + c.lazyReloadState() } } - node, err = c.nodes.Get(addr) + node, err = c.nodes.GetOrCreate(addr) if err != nil { cmd.setErr(err) return err @@ -535,39 +619,6 @@ func (c *ClusterClient) Process(cmd Cmder) error { return cmd.Err() } -// ForEachNode concurrently calls the fn on each ever known node in the cluster. -// It returns the first error if any. -func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { - nodes, err := c.nodes.All() - if err != nil { - return err - } - - var wg sync.WaitGroup - errCh := make(chan error, 1) - for _, node := range nodes { - wg.Add(1) - go func(node *clusterNode) { - defer wg.Done() - err := fn(node.Client) - if err != nil { - select { - case errCh <- err: - default: - } - } - }(node) - } - wg.Wait() - - select { - case err := <-errCh: - return err - default: - return nil - } -} - // ForEachMaster concurrently calls the fn on each master node in the cluster. // It returns the first error if any. func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { @@ -577,19 +628,8 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { } var wg sync.WaitGroup - visited := make(map[*clusterNode]struct{}) errCh := make(chan error, 1) - for _, nodes := range state.slots { - if len(nodes) == 0 { - continue - } - - master := nodes[0] - if _, ok := visited[master]; ok { - continue - } - visited[master] = struct{}{} - + for _, master := range state.masters { wg.Add(1) go func(node *clusterNode) { defer wg.Done() @@ -612,16 +652,88 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error { } } +// ForEachSlave concurrently calls the fn on each slave node in the cluster. +// It returns the first error if any. +func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error { + state := c.state() + if state == nil { + return errNilClusterState + } + + var wg sync.WaitGroup + errCh := make(chan error, 1) + for _, slave := range state.slaves { + wg.Add(1) + go func(node *clusterNode) { + defer wg.Done() + err := fn(node.Client) + if err != nil { + select { + case errCh <- err: + default: + } + } + }(slave) + } + wg.Wait() + + select { + case err := <-errCh: + return err + default: + return nil + } +} + +// ForEachNode concurrently calls the fn on each known node in the cluster. +// It returns the first error if any. +func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error { + state := c.state() + if state == nil { + return errNilClusterState + } + + var wg sync.WaitGroup + errCh := make(chan error, 1) + worker := func(node *clusterNode) { + defer wg.Done() + err := fn(node.Client) + if err != nil { + select { + case errCh <- err: + default: + } + } + } + + for _, node := range state.masters { + wg.Add(1) + go worker(node) + } + for _, node := range state.slaves { + wg.Add(1) + go worker(node) + } + + wg.Wait() + select { + case err := <-errCh: + return err + default: + return nil + } +} + // PoolStats returns accumulated connection pool stats. func (c *ClusterClient) PoolStats() *PoolStats { var acc PoolStats - nodes, err := c.nodes.All() - if err != nil { + state := c.state() + if state == nil { return &acc } - for _, node := range nodes { + for _, node := range state.masters { s := node.Client.connPool.Stats() acc.Requests += s.Requests acc.Hits += s.Hits @@ -629,33 +741,51 @@ func (c *ClusterClient) PoolStats() *PoolStats { acc.TotalConns += s.TotalConns acc.FreeConns += s.FreeConns } + + for _, node := range state.slaves { + s := node.Client.connPool.Stats() + acc.Requests += s.Requests + acc.Hits += s.Hits + acc.Timeouts += s.Timeouts + acc.TotalConns += s.TotalConns + acc.FreeConns += s.FreeConns + } + return &acc } -func (c *ClusterClient) lazyReloadSlots() { +func (c *ClusterClient) lazyReloadState() { if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) { return } go func() { - for i := 0; i < 1000; i++ { - state, err := c.reloadSlots() + defer atomic.StoreUint32(&c.reloading, 0) + + var state *clusterState + for { + var err error + state, err = c.reloadState() if err == pool.ErrClosed { - break + return } - if err == nil { - c._state.Store(state) - break + + if err != nil { + time.Sleep(time.Millisecond) + continue } - time.Sleep(time.Millisecond) + + c._state.Store(state) + break } time.Sleep(3 * time.Second) - atomic.StoreUint32(&c.reloading, 0) + c.nodes.GC(state.generation) }() } -func (c *ClusterClient) reloadSlots() (*clusterState, error) { +// Not thread-safe. +func (c *ClusterClient) reloadState() (*clusterState, error) { node, err := c.nodes.Random() if err != nil { return nil, err @@ -799,9 +929,9 @@ func (c *ClusterClient) pipelineReadCmds( func (c *ClusterClient) checkMovedErr(cmd Cmder, failedCmds map[*clusterNode][]Cmder) error { moved, ask, addr := internal.IsMovedError(cmd.Err()) if moved { - c.lazyReloadSlots() + c.lazyReloadState() - node, err := c.nodes.Get(addr) + node, err := c.nodes.GetOrCreate(addr) if err != nil { return err } @@ -809,7 +939,7 @@ func (c *ClusterClient) checkMovedErr(cmd Cmder, failedCmds map[*clusterNode][]C failedCmds[node] = append(failedCmds[node], cmd) } if ask { - node, err := c.nodes.Get(addr) + node, err := c.nodes.GetOrCreate(addr) if err != nil { return err } @@ -1029,3 +1159,12 @@ func isLoopbackAddr(addr string) bool { return ip.IsLoopback() } + +func appendNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { + for _, n := range nodes { + if n == node { + return nodes + } + } + return append(nodes, node) +} diff --git a/cluster_test.go b/cluster_test.go index 1dc7229..0176c68 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -75,7 +75,7 @@ func startCluster(scenario *clusterScenario) error { scenario.nodeIds[pos] = info[:40] } - // Meet cluster nodes + // Meet cluster nodes. for _, client := range scenario.clients { err := client.ClusterMeet("127.0.0.1", scenario.ports[0]).Err() if err != nil { @@ -83,7 +83,7 @@ func startCluster(scenario *clusterScenario) error { } } - // Bootstrap masters + // Bootstrap masters. slots := []int{0, 5000, 10000, 16384} for pos, master := range scenario.masters() { err := master.ClusterAddSlotsRange(slots[pos], slots[pos+1]-1).Err() @@ -92,7 +92,7 @@ func startCluster(scenario *clusterScenario) error { } } - // Bootstrap slaves + // Bootstrap slaves. for idx, slave := range scenario.slaves() { masterId := scenario.nodeIds[idx] @@ -115,7 +115,7 @@ func startCluster(scenario *clusterScenario) error { } } - // Wait until all nodes have consistent info + // Wait until all nodes have consistent info. for _, client := range scenario.clients { err := eventually(func() error { res, err := client.ClusterSlots().Result() @@ -189,62 +189,6 @@ var _ = Describe("ClusterClient", func() { var client *redis.ClusterClient assertClusterClient := func() { - It("should CLUSTER SLOTS", func() { - res, err := client.ClusterSlots().Result() - Expect(err).NotTo(HaveOccurred()) - Expect(res).To(HaveLen(3)) - - wanted := []redis.ClusterSlot{ - {0, 4999, []redis.ClusterNode{{"", "127.0.0.1:8220"}, {"", "127.0.0.1:8223"}}}, - {5000, 9999, []redis.ClusterNode{{"", "127.0.0.1:8221"}, {"", "127.0.0.1:8224"}}}, - {10000, 16383, []redis.ClusterNode{{"", "127.0.0.1:8222"}, {"", "127.0.0.1:8225"}}}, - } - Expect(assertSlotsEqual(res, wanted)).NotTo(HaveOccurred()) - }) - - It("should CLUSTER NODES", func() { - res, err := client.ClusterNodes().Result() - Expect(err).NotTo(HaveOccurred()) - Expect(len(res)).To(BeNumerically(">", 400)) - }) - - It("should CLUSTER INFO", func() { - res, err := client.ClusterInfo().Result() - Expect(err).NotTo(HaveOccurred()) - Expect(res).To(ContainSubstring("cluster_known_nodes:6")) - }) - - It("should CLUSTER KEYSLOT", func() { - hashSlot, err := client.ClusterKeySlot("somekey").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(hashSlot).To(Equal(int64(hashtag.Slot("somekey")))) - }) - - It("should CLUSTER COUNT-FAILURE-REPORTS", func() { - n, err := client.ClusterCountFailureReports(cluster.nodeIds[0]).Result() - Expect(err).NotTo(HaveOccurred()) - Expect(n).To(Equal(int64(0))) - }) - - It("should CLUSTER COUNTKEYSINSLOT", func() { - n, err := client.ClusterCountKeysInSlot(10).Result() - Expect(err).NotTo(HaveOccurred()) - Expect(n).To(Equal(int64(0))) - }) - - It("should CLUSTER SAVECONFIG", func() { - res, err := client.ClusterSaveConfig().Result() - Expect(err).NotTo(HaveOccurred()) - Expect(res).To(Equal("OK")) - }) - - It("should CLUSTER SLAVES", func() { - nodesList, err := client.ClusterSlaves(cluster.nodeIds[0]).Result() - Expect(err).NotTo(HaveOccurred()) - Expect(nodesList).Should(ContainElement(ContainSubstring("slave"))) - Expect(nodesList).Should(HaveLen(1)) - }) - It("should GET/SET/DEL", func() { val, err := client.Get("A").Result() Expect(err).To(Equal(redis.Nil)) @@ -254,55 +198,24 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) Expect(val).To(Equal("OK")) - val, err = client.Get("A").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(val).To(Equal("VALUE")) + Eventually(func() string { + return client.Get("A").Val() + }).Should(Equal("VALUE")) cnt, err := client.Del("A").Result() Expect(err).NotTo(HaveOccurred()) Expect(cnt).To(Equal(int64(1))) }) - It("returns pool stats", func() { - Expect(client.PoolStats()).To(BeAssignableToTypeOf(&redis.PoolStats{})) - }) - - It("removes idle connections", func() { - stats := client.PoolStats() - Expect(stats.TotalConns).NotTo(BeZero()) - Expect(stats.FreeConns).NotTo(BeZero()) - - time.Sleep(2 * time.Second) - - stats = client.PoolStats() - Expect(stats.TotalConns).To(BeZero()) - Expect(stats.FreeConns).To(BeZero()) - }) - It("follows redirects", func() { Expect(client.Set("A", "VALUE", 0).Err()).NotTo(HaveOccurred()) slot := hashtag.Slot("A") - Expect(client.SwapSlotNodes(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) + client.SwapSlotNodes(slot) - val, err := client.Get("A").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(val).To(Equal("VALUE")) - }) - - It("returns an error when there are no attempts left", func() { - opt := redisClusterOptions() - opt.MaxRedirects = -1 - client := cluster.clusterClient(opt) - - slot := hashtag.Slot("A") - Expect(client.SwapSlotNodes(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) - - err := client.Get("A").Err() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("MOVED")) - - Expect(client.Close()).NotTo(HaveOccurred()) + Eventually(func() string { + return client.Get("A").Val() + }).Should(Equal("VALUE")) }) It("distributes keys", func() { @@ -311,9 +224,14 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) } - wanted := []string{"keys=31", "keys=29", "keys=40"} - for i, master := range cluster.masters() { - Expect(master.Info().Val()).To(ContainSubstring(wanted[i])) + for _, master := range cluster.masters() { + Eventually(func() string { + return master.Info("keyspace").Val() + }, 5*time.Second).Should(Or( + ContainSubstring("keys=31"), + ContainSubstring("keys=29"), + ContainSubstring("keys=40"), + )) } }) @@ -330,9 +248,14 @@ var _ = Describe("ClusterClient", func() { Expect(err).NotTo(HaveOccurred()) } - wanted := []string{"keys=31", "keys=29", "keys=40"} - for i, master := range cluster.masters() { - Expect(master.Info().Val()).To(ContainSubstring(wanted[i])) + for _, master := range cluster.masters() { + Eventually(func() string { + return master.Info("keyspace").Val() + }, 5*time.Second).Should(Or( + ContainSubstring("keys=31"), + ContainSubstring("keys=29"), + ContainSubstring("keys=40"), + )) } }) @@ -447,7 +370,7 @@ var _ = Describe("ClusterClient", func() { }) } - Describe("Pipeline", func() { + Describe("with Pipeline", func() { BeforeEach(func() { pipe = client.Pipeline().(*redis.Pipeline) }) @@ -459,7 +382,7 @@ var _ = Describe("ClusterClient", func() { assertPipeline() }) - Describe("TxPipeline", func() { + Describe("with TxPipeline", func() { BeforeEach(func() { pipe = client.TxPipeline().(*redis.Pipeline) }) @@ -476,22 +399,70 @@ var _ = Describe("ClusterClient", func() { pubsub := client.Subscribe("mychannel") defer pubsub.Close() - msgi, err := pubsub.ReceiveTimeout(time.Second) - Expect(err).NotTo(HaveOccurred()) - subscr := msgi.(*redis.Subscription) - Expect(subscr.Kind).To(Equal("subscribe")) - Expect(subscr.Channel).To(Equal("mychannel")) - Expect(subscr.Count).To(Equal(1)) + Eventually(func() error { + _, err := client.Publish("mychannel", "hello").Result() + if err != nil { + return err + } - n, err := client.Publish("mychannel", "hello").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(n).To(Equal(int64(1))) + msg, err := pubsub.ReceiveTimeout(time.Second) + if err != nil { + return err + } - msgi, err = pubsub.ReceiveTimeout(time.Second) - Expect(err).NotTo(HaveOccurred()) - msg := msgi.(*redis.Message) - Expect(msg.Channel).To(Equal("mychannel")) - Expect(msg.Payload).To(Equal("hello")) + _, ok := msg.(*redis.Message) + if !ok { + return fmt.Errorf("got %T, wanted *redis.Message", msg) + } + + return nil + }, 30*time.Second).ShouldNot(HaveOccurred()) + }) + } + + Describe("ClusterClient", func() { + BeforeEach(func() { + opt = redisClusterOptions() + client = cluster.clusterClient(opt) + + _ = client.ForEachMaster(func(master *redis.Client) error { + return master.FlushDB().Err() + }) + }) + + AfterEach(func() { + Expect(client.Close()).NotTo(HaveOccurred()) + }) + + It("returns pool stats", func() { + Expect(client.PoolStats()).To(BeAssignableToTypeOf(&redis.PoolStats{})) + }) + + It("removes idle connections", func() { + stats := client.PoolStats() + Expect(stats.TotalConns).NotTo(BeZero()) + Expect(stats.FreeConns).NotTo(BeZero()) + + time.Sleep(2 * time.Second) + + stats = client.PoolStats() + Expect(stats.TotalConns).To(BeZero()) + Expect(stats.FreeConns).To(BeZero()) + }) + + It("returns an error when there are no attempts left", func() { + opt := redisClusterOptions() + opt.MaxRedirects = -1 + client := cluster.clusterClient(opt) + + slot := hashtag.Slot("A") + client.SwapSlotNodes(slot) + + err := client.Get("A").Err() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("MOVED")) + + Expect(client.Close()).NotTo(HaveOccurred()) }) It("calls fn for every master node", func() { @@ -510,9 +481,67 @@ var _ = Describe("ClusterClient", func() { Expect(keys).To(HaveLen(0)) } }) - } - Describe("default ClusterClient", func() { + It("should CLUSTER SLOTS", func() { + res, err := client.ClusterSlots().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res).To(HaveLen(3)) + + wanted := []redis.ClusterSlot{ + {0, 4999, []redis.ClusterNode{{"", "127.0.0.1:8220"}, {"", "127.0.0.1:8223"}}}, + {5000, 9999, []redis.ClusterNode{{"", "127.0.0.1:8221"}, {"", "127.0.0.1:8224"}}}, + {10000, 16383, []redis.ClusterNode{{"", "127.0.0.1:8222"}, {"", "127.0.0.1:8225"}}}, + } + Expect(assertSlotsEqual(res, wanted)).NotTo(HaveOccurred()) + }) + + It("should CLUSTER NODES", func() { + res, err := client.ClusterNodes().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(res)).To(BeNumerically(">", 400)) + }) + + It("should CLUSTER INFO", func() { + res, err := client.ClusterInfo().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res).To(ContainSubstring("cluster_known_nodes:6")) + }) + + It("should CLUSTER KEYSLOT", func() { + hashSlot, err := client.ClusterKeySlot("somekey").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(hashSlot).To(Equal(int64(hashtag.Slot("somekey")))) + }) + + It("should CLUSTER COUNT-FAILURE-REPORTS", func() { + n, err := client.ClusterCountFailureReports(cluster.nodeIds[0]).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(0))) + }) + + It("should CLUSTER COUNTKEYSINSLOT", func() { + n, err := client.ClusterCountKeysInSlot(10).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(0))) + }) + + It("should CLUSTER SAVECONFIG", func() { + res, err := client.ClusterSaveConfig().Result() + Expect(err).NotTo(HaveOccurred()) + Expect(res).To(Equal("OK")) + }) + + It("should CLUSTER SLAVES", func() { + nodesList, err := client.ClusterSlaves(cluster.nodeIds[0]).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(nodesList).Should(ContainElement(ContainSubstring("slave"))) + Expect(nodesList).Should(HaveLen(1)) + }) + + assertClusterClient() + }) + + Describe("ClusterClient failover", func() { BeforeEach(func() { opt = redisClusterOptions() client = cluster.clusterClient(opt) @@ -520,6 +549,10 @@ var _ = Describe("ClusterClient", func() { _ = client.ForEachMaster(func(master *redis.Client) error { return master.FlushDB().Err() }) + + _ = client.ForEachSlave(func(slave *redis.Client) error { + return slave.ClusterFailover().Err() + }) }) AfterEach(func() { diff --git a/export_test.go b/export_test.go index b88e41b..3b7965d 100644 --- a/export_test.go +++ b/export_test.go @@ -28,8 +28,9 @@ func (c *ClusterClient) SlotAddrs(slot int) []string { } // SwapSlot swaps a slot's master/slave address for testing MOVED redirects. -func (c *ClusterClient) SwapSlotNodes(slot int) []string { +func (c *ClusterClient) SwapSlotNodes(slot int) { nodes := c.state().slots[slot] - nodes[0], nodes[1] = nodes[1], nodes[0] - return c.SlotAddrs(slot) + if len(nodes) == 2 { + nodes[0], nodes[1] = nodes[1], nodes[0] + } } diff --git a/internal/internal.go b/internal/internal.go index fb4efa5..ad3fc3c 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -5,19 +5,20 @@ import ( "time" ) -const retryBackoff = 8 * time.Millisecond - // Retry backoff with jitter sleep to prevent overloaded conditions during intervals // https://www.awsarchitectureblog.com/2015/03/backoff.html -func RetryBackoff(retry int, maxRetryBackoff time.Duration) time.Duration { +func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { if retry < 0 { retry = 0 } - backoff := retryBackoff << uint(retry) - if backoff > maxRetryBackoff { - backoff = maxRetryBackoff + backoff := minBackoff << uint(retry) + if backoff > maxBackoff || backoff < minBackoff { + backoff = maxBackoff } + if backoff == 0 { + return 0 + } return time.Duration(rand.Int63n(int64(backoff))) } diff --git a/internal/internal_test.go b/internal/internal_test.go index 5c7000e..56ff611 100644 --- a/internal/internal_test.go +++ b/internal/internal_test.go @@ -2,15 +2,16 @@ package internal import ( "testing" - . "github.com/onsi/gomega" "time" + + . "github.com/onsi/gomega" ) func TestRetryBackoff(t *testing.T) { RegisterTestingT(t) - - for i := -1; i<= 8; i++ { - backoff := RetryBackoff(i, 512*time.Millisecond) + + for i := -1; i <= 16; i++ { + backoff := RetryBackoff(i, time.Millisecond, 512*time.Millisecond) Expect(backoff >= 0).To(BeTrue()) Expect(backoff <= 512*time.Millisecond).To(BeTrue()) } diff --git a/options.go b/options.go index cd6fa98..efa8f6a 100644 --- a/options.go +++ b/options.go @@ -37,9 +37,11 @@ type Options struct { // Maximum number of retries before giving up. // Default is to not retry failed commands. MaxRetries int - + // Minimum backoff between each retry. + // Default is 8 milliseconds; -1 disables backoff. + MinRetryBackoff time.Duration // Maximum backoff between each retry. - // Default is 512 seconds; -1 disables backoff. + // Default is 512 milliseconds; -1 disables backoff. MaxRetryBackoff time.Duration // Dial timeout for establishing new connections. @@ -118,6 +120,13 @@ func (opt *Options) init() { if opt.IdleCheckFrequency == 0 { opt.IdleCheckFrequency = time.Minute } + + switch opt.MinRetryBackoff { + case -1: + opt.MinRetryBackoff = 0 + case 0: + opt.MinRetryBackoff = 8 * time.Millisecond + } switch opt.MaxRetryBackoff { case -1: opt.MaxRetryBackoff = 0 diff --git a/redis.go b/redis.go index 1a2ecb0..b3a215e 100644 --- a/redis.go +++ b/redis.go @@ -107,13 +107,6 @@ func (c *baseClient) initConn(cn *pool.Conn) error { return nil } -func (c *baseClient) Process(cmd Cmder) error { - if c.process != nil { - return c.process(cmd) - } - return c.defaultProcess(cmd) -} - // WrapProcess replaces the process func. It takes a function createWrapper // which is supplied by the user. createWrapper takes the old process func as // an input and returns the new wrapper process func. createWrapper should @@ -122,10 +115,17 @@ func (c *baseClient) WrapProcess(fn func(oldProcess func(cmd Cmder) error) func( c.process = fn(c.defaultProcess) } +func (c *baseClient) Process(cmd Cmder) error { + if c.process != nil { + return c.process(cmd) + } + return c.defaultProcess(cmd) +} + func (c *baseClient) defaultProcess(cmd Cmder) error { - for i := 0; i <= c.opt.MaxRetries; i++ { - if i > 0 { - time.Sleep(internal.RetryBackoff(i, c.opt.MaxRetryBackoff)) + for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { + if attempt > 0 { + time.Sleep(c.retryBackoff(attempt)) } cn, _, err := c.getConn() @@ -160,6 +160,10 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { return cmd.Err() } +func (c *baseClient) retryBackoff(attempt int) time.Duration { + return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) +} + func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { if timeout := cmd.readTimeout(); timeout != nil { return *timeout