diff --git a/CHANGELOG.md b/CHANGELOG.md index b4cf05c0..b35c4dd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Ring got new options called `HashReplicas` and `Hash`. It is recommended to set `HashReplicas = 1000` for better keys distribution between shards. - Cluster client was optimized to use much less memory when reloading cluster state. +- ReceiveMessage is re-worked to not use ReceiveTimeout so it does not lose data when timeout occurres. ## v6.12 diff --git a/cluster.go b/cluster.go index 84a026c2..7a1af143 100644 --- a/cluster.go +++ b/cluster.go @@ -1500,11 +1500,9 @@ func (c *ClusterClient) txPipelineReadQueued( } func (c *ClusterClient) pubSub(channels []string) *PubSub { - opt := c.opt.clientOptions() - var node *clusterNode - return &PubSub{ - opt: opt, + pubsub := &PubSub{ + opt: c.opt.clientOptions(), newConn: func(channels []string) (*pool.Conn, error) { if node == nil { @@ -1527,6 +1525,8 @@ func (c *ClusterClient) pubSub(channels []string) *PubSub { return node.Client.connPool.CloseConn(cn) }, } + pubsub.init() + return pubsub } // Subscribe subscribes the client to the specified channels. diff --git a/cluster_test.go b/cluster_test.go index e94a5099..f9c3a90f 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -453,7 +453,7 @@ var _ = Describe("ClusterClient", func() { ttl := cmds[(i*2)+1].(*redis.DurationCmd) dur := time.Duration(i+1) * time.Hour - Expect(ttl.Val()).To(BeNumerically("~", dur, 10*time.Second)) + Expect(ttl.Val()).To(BeNumerically("~", dur, 30*time.Second)) } }) diff --git a/export_test.go b/export_test.go index e9afda93..fab91e2a 100644 --- a/export_test.go +++ b/export_test.go @@ -3,7 +3,6 @@ package redis import ( "fmt" "net" - "time" "github.com/go-redis/redis/internal/hashtag" "github.com/go-redis/redis/internal/pool" @@ -17,10 +16,6 @@ func (c *PubSub) SetNetConn(netConn net.Conn) { c.cn = pool.NewConn(netConn) } -func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error) { - return c.receiveMessage(timeout) -} - func (c *ClusterClient) LoadState() (*clusterState, error) { return c.loadState() } diff --git a/pubsub.go b/pubsub.go index dbf6d1d1..b76dbc6d 100644 --- a/pubsub.go +++ b/pubsub.go @@ -2,7 +2,6 @@ package redis import ( "fmt" - "net" "sync" "time" @@ -11,11 +10,11 @@ import ( ) // PubSub implements Pub/Sub commands as described in -// http://redis.io/topics/pubsub. It's NOT safe for concurrent use by -// multiple goroutines. +// http://redis.io/topics/pubsub. Message receiving is NOT safe +// for concurrent use by multiple goroutines. // -// PubSub automatically resubscribes to the channels and patterns -// when Redis becomes unavailable. +// PubSub automatically reconnects to Redis Server and resubscribes +// to the channels in case of network errors. type PubSub struct { opt *Options @@ -27,13 +26,21 @@ type PubSub struct { channels map[string]struct{} patterns map[string]struct{} closed bool + exit chan struct{} cmd *Cmd + pingOnce sync.Once + ping chan struct{} + chOnce sync.Once ch chan *Message } +func (c *PubSub) init() { + c.exit = make(chan struct{}) +} + func (c *PubSub) conn() (*pool.Conn, error) { c.mu.Lock() cn, err := c._conn(nil) @@ -66,31 +73,36 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { func (c *PubSub) resubscribe(cn *pool.Conn) error { var firstErr error + if len(c.channels) > 0 { - channels := make([]string, len(c.channels)) - i := 0 - for channel := range c.channels { - channels[i] = channel - i++ - } - if err := c._subscribe(cn, "subscribe", channels...); err != nil && firstErr == nil { + channels := mapKeys(c.channels) + err := c._subscribe(cn, "subscribe", channels...) + if err != nil && firstErr == nil { firstErr = err } } + if len(c.patterns) > 0 { - patterns := make([]string, len(c.patterns)) - i := 0 - for pattern := range c.patterns { - patterns[i] = pattern - i++ - } - if err := c._subscribe(cn, "psubscribe", patterns...); err != nil && firstErr == nil { + patterns := mapKeys(c.patterns) + err := c._subscribe(cn, "psubscribe", patterns...) + if err != nil && firstErr == nil { firstErr = err } } + return firstErr } +func mapKeys(m map[string]struct{}) []string { + s := make([]string, len(m)) + i := 0 + for k := range m { + s[i] = k + i++ + } + return s +} + func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { args := make([]interface{}, 1+len(channels)) args[0] = redisCmd @@ -114,16 +126,30 @@ func (c *PubSub) _releaseConn(cn *pool.Conn, err error) { return } if internal.IsBadConn(err, true) { - _ = c.closeTheCn() + c._reconnect() } } -func (c *PubSub) closeTheCn() error { - err := c.closeConn(c.cn) - c.cn = nil +func (c *PubSub) _closeTheCn() error { + var err error + if c.cn != nil { + err = c.closeConn(c.cn) + c.cn = nil + } return err } +func (c *PubSub) reconnect() { + c.mu.Lock() + c._reconnect() + c.mu.Unlock() +} + +func (c *PubSub) _reconnect() { + _ = c._closeTheCn() + _, _ = c._conn(nil) +} + func (c *PubSub) Close() error { c.mu.Lock() defer c.mu.Unlock() @@ -132,17 +158,18 @@ func (c *PubSub) Close() error { return pool.ErrClosed } c.closed = true + close(c.exit) - if c.cn != nil { - return c.closeTheCn() - } - return nil + err := c._closeTheCn() + return err } // Subscribe the client to the specified channels. It returns // empty subscription if there are no channels. func (c *PubSub) Subscribe(channels ...string) error { c.mu.Lock() + defer c.mu.Unlock() + err := c.subscribe("subscribe", channels...) if c.channels == nil { c.channels = make(map[string]struct{}) @@ -150,7 +177,6 @@ func (c *PubSub) Subscribe(channels ...string) error { for _, channel := range channels { c.channels[channel] = struct{}{} } - c.mu.Unlock() return err } @@ -158,6 +184,8 @@ func (c *PubSub) Subscribe(channels ...string) error { // empty subscription if there are no patterns. func (c *PubSub) PSubscribe(patterns ...string) error { c.mu.Lock() + defer c.mu.Unlock() + err := c.subscribe("psubscribe", patterns...) if c.patterns == nil { c.patterns = make(map[string]struct{}) @@ -165,7 +193,6 @@ func (c *PubSub) PSubscribe(patterns ...string) error { for _, pattern := range patterns { c.patterns[pattern] = struct{}{} } - c.mu.Unlock() return err } @@ -173,11 +200,12 @@ func (c *PubSub) PSubscribe(patterns ...string) error { // them if none is given. func (c *PubSub) Unsubscribe(channels ...string) error { c.mu.Lock() + defer c.mu.Unlock() + err := c.subscribe("unsubscribe", channels...) for _, channel := range channels { delete(c.channels, channel) } - c.mu.Unlock() return err } @@ -185,11 +213,12 @@ func (c *PubSub) Unsubscribe(channels ...string) error { // them if none is given. func (c *PubSub) PUnsubscribe(patterns ...string) error { c.mu.Lock() + defer c.mu.Unlock() + err := c.subscribe("punsubscribe", patterns...) for _, pattern := range patterns { delete(c.patterns, pattern) } - c.mu.Unlock() return err } @@ -298,7 +327,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { // ReceiveTimeout acts like Receive but returns an error if message // is not received in time. This is low-level API and most clients -// should use ReceiveMessage. +// should use ReceiveMessage instead. func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { if c.cmd == nil { c.cmd = NewCmd() @@ -309,7 +338,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { return nil, err } - cn.SetReadTimeout(readTimeout(timeout)) + cn.SetReadTimeout(timeout) err = c.cmd.readReply(cn) c.releaseConn(cn, err) if err != nil { @@ -321,48 +350,28 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { // Receive returns a message as a Subscription, Message, Pong or error. // See PubSub example for details. This is low-level API and most clients -// should use ReceiveMessage. +// should use ReceiveMessage instead. func (c *PubSub) Receive() (interface{}, error) { return c.ReceiveTimeout(0) } // ReceiveMessage returns a Message or error ignoring Subscription or Pong -// messages. It automatically reconnects to Redis Server and resubscribes -// to channels in case of network errors. +// messages. It periodically sends Ping messages to test connection health. func (c *PubSub) ReceiveMessage() (*Message, error) { - return c.receiveMessage(5 * time.Second) -} - -func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { - var errNum uint + c.pingOnce.Do(c.initPing) for { - msgi, err := c.ReceiveTimeout(timeout) + msg, err := c.Receive() if err != nil { - if !internal.IsNetworkError(err) { - return nil, err - } - - errNum++ - if errNum < 3 { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - err := c.Ping() - if err != nil { - internal.Logf("PubSub.Ping failed: %s", err) - } - } - } else { - // 3 consequent errors - connection is broken or - // Redis Server is down. - // Sleep to not exceed max number of open connections. - time.Sleep(time.Second) - } - continue + return nil, err } - // Reset error number, because we received a message. - errNum = 0 + // Any message is as good as a ping. + select { + case c.ping <- struct{}{}: + default: + } - switch msg := msgi.(type) { + switch msg := msg.(type) { case *Subscription: // Ignore. case *Pong: @@ -370,30 +379,74 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) { case *Message: return msg, nil default: - return nil, fmt.Errorf("redis: unknown message: %T", msgi) + err := fmt.Errorf("redis: unknown message: %T", msg) + return nil, err } } } // Channel returns a Go channel for concurrently receiving messages. -// The channel is closed with PubSub. Receive or ReceiveMessage APIs -// can not be used after channel is created. +// The channel is closed with PubSub. Receive* APIs can not be used +// after channel is created. func (c *PubSub) Channel() <-chan *Message { - c.chOnce.Do(func() { - c.ch = make(chan *Message, 100) - go func() { - for { - msg, err := c.ReceiveMessage() - if err != nil { - if err == pool.ErrClosed { - break - } - continue - } - c.ch <- msg - } - close(c.ch) - }() - }) + c.chOnce.Do(c.initChannel) return c.ch } + +func (c *PubSub) initChannel() { + c.ch = make(chan *Message, 100) + go func() { + var errCount int + for { + msg, err := c.ReceiveMessage() + if err != nil { + if err == pool.ErrClosed { + close(c.ch) + return + } + if errCount > 0 { + time.Sleep(c.retryBackoff(errCount)) + } + errCount++ + continue + } + errCount = 0 + c.ch <- msg + } + }() +} + +func (c *PubSub) initPing() { + const timeout = 5 * time.Second + + c.ping = make(chan struct{}, 10) + go func() { + timer := time.NewTimer(timeout) + timer.Stop() + + var hasPing bool + for { + timer.Reset(timeout) + select { + case <-c.ping: + hasPing = true + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + if hasPing { + hasPing = false + _ = c.Ping() + } else { + c.reconnect() + } + case <-c.exit: + return + } + } + }() +} + +func (c *PubSub) retryBackoff(attempt int) time.Duration { + return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) +} diff --git a/pubsub_test.go b/pubsub_test.go index 6a85bd03..059b4a60 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -255,45 +255,6 @@ var _ = Describe("PubSub", func() { Expect(msg.Payload).To(Equal("world")) }) - It("should ReceiveMessage after timeout", func() { - timeout := 100 * time.Millisecond - - pubsub := client.Subscribe("mychannel") - defer pubsub.Close() - - subscr, err := pubsub.ReceiveTimeout(time.Second) - Expect(err).NotTo(HaveOccurred()) - Expect(subscr).To(Equal(&redis.Subscription{ - Kind: "subscribe", - Channel: "mychannel", - Count: 1, - })) - - done := make(chan bool, 1) - go func() { - defer GinkgoRecover() - defer func() { - done <- true - }() - - time.Sleep(timeout + 100*time.Millisecond) - n, err := client.Publish("mychannel", "hello").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(n).To(Equal(int64(1))) - }() - - msg, err := pubsub.ReceiveMessageTimeout(timeout) - Expect(err).NotTo(HaveOccurred()) - Expect(msg.Channel).To(Equal("mychannel")) - Expect(msg.Payload).To(Equal("hello")) - - Eventually(done).Should(Receive()) - - stats := client.PoolStats() - Expect(stats.Hits).To(Equal(uint32(1))) - Expect(stats.Misses).To(Equal(uint32(1))) - }) - It("returns an error when subscribe fails", func() { pubsub := client.Subscribe() defer pubsub.Close() @@ -316,24 +277,27 @@ var _ = Describe("PubSub", func() { writeErr: io.EOF, }) - done := make(chan bool, 1) + step := make(chan struct{}, 3) + go func() { defer GinkgoRecover() - defer func() { - done <- true - }() - time.Sleep(100 * time.Millisecond) + Eventually(step).Should(Receive()) err := client.Publish("mychannel", "hello").Err() Expect(err).NotTo(HaveOccurred()) + step <- struct{}{} }() + _, err := pubsub.ReceiveMessage() + Expect(err).To(Equal(io.EOF)) + step <- struct{}{} + msg, err := pubsub.ReceiveMessage() Expect(err).NotTo(HaveOccurred()) Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal("hello")) - Eventually(done).Should(Receive()) + Eventually(step).Should(Receive()) } It("Subscribe should reconnect on ReceiveMessage error", func() { @@ -380,9 +344,9 @@ var _ = Describe("PubSub", func() { _, err := pubsub.ReceiveMessage() Expect(err).To(HaveOccurred()) - Expect(err).To(SatisfyAny( - MatchError("redis: client is closed"), - MatchError("use of closed network connection"), // Go 1.4 + Expect(err.Error()).To(SatisfyAny( + Equal("redis: client is closed"), + ContainSubstring("use of closed network connection"), )) }() @@ -406,7 +370,7 @@ var _ = Describe("PubSub", func() { defer GinkgoRecover() defer wg.Done() - time.Sleep(2 * timeout) + time.Sleep(timeout) err := pubsub.Subscribe("mychannel") Expect(err).NotTo(HaveOccurred()) @@ -417,7 +381,7 @@ var _ = Describe("PubSub", func() { Expect(err).NotTo(HaveOccurred()) }() - msg, err := pubsub.ReceiveMessageTimeout(timeout) + msg, err := pubsub.ReceiveMessage() Expect(err).NotTo(HaveOccurred()) Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal("hello")) diff --git a/redis.go b/redis.go index ff065148..c0f142cc 100644 --- a/redis.go +++ b/redis.go @@ -423,7 +423,7 @@ func (c *Client) TxPipeline() Pipeliner { } func (c *Client) pubSub() *PubSub { - return &PubSub{ + pubsub := &PubSub{ opt: c.opt, newConn: func(channels []string) (*pool.Conn, error) { @@ -431,6 +431,8 @@ func (c *Client) pubSub() *PubSub { }, closeConn: c.connPool.CloseConn, } + pubsub.init() + return pubsub } // Subscribe subscribes the client to the specified channels. diff --git a/ring.go b/ring.go index 8b20d476..ef855115 100644 --- a/ring.go +++ b/ring.go @@ -165,6 +165,7 @@ type ringShards struct { hash *consistenthash.Map shards map[string]*ringShard // read only list []*ringShard // read only + len int closed bool } @@ -269,17 +270,27 @@ func (c *ringShards) Heartbeat(frequency time.Duration) { // rebalance removes dead shards from the Ring. func (c *ringShards) rebalance() { hash := newConsistentHash(c.opt) + var shardsNum int for name, shard := range c.shards { if shard.IsUp() { hash.Add(name) + shardsNum++ } } c.mu.Lock() c.hash = hash + c.len = shardsNum c.mu.Unlock() } +func (c *ringShards) Len() int { + c.mu.RLock() + l := c.len + c.mu.RUnlock() + return l +} + func (c *ringShards) Close() error { c.mu.Lock() defer c.mu.Unlock() @@ -398,6 +409,11 @@ func (c *Ring) PoolStats() *PoolStats { return &acc } +// Len returns the current number of shards in the ring. +func (c *Ring) Len() int { + return c.shards.Len() +} + // Subscribe subscribes the client to the specified channels. func (c *Ring) Subscribe(channels ...string) *PubSub { if len(channels) == 0 { diff --git a/ring_test.go b/ring_test.go index 0cad4298..1f5bf0d6 100644 --- a/ring_test.go +++ b/ring_test.go @@ -42,8 +42,8 @@ var _ = Describe("Redis Ring", func() { setRingKeys() // Both shards should have some keys now. - Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57")) - Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) + Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57")) + Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43")) }) It("distributes keys when using EVAL", func() { @@ -59,41 +59,36 @@ var _ = Describe("Redis Ring", func() { Expect(err).NotTo(HaveOccurred()) } - Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57")) - Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) + Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57")) + Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43")) }) It("uses single shard when one of the shards is down", func() { // Stop ringShard2. Expect(ringShard2.Close()).NotTo(HaveOccurred()) - // Ring needs 3 * heartbeat time to detect that node is down. - // Give it more to be sure. - time.Sleep(2 * 3 * heartbeat) + Eventually(func() int { + return ring.Len() + }, "30s").Should(Equal(1)) setRingKeys() // RingShard1 should have all keys. - Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=100")) + Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=100")) // Start ringShard2. var err error ringShard2, err = startRedis(ringShard2Port) Expect(err).NotTo(HaveOccurred()) - // Wait for ringShard2 to come up. - Eventually(func() error { - return ringShard2.Ping().Err() - }, "1s").ShouldNot(HaveOccurred()) - - // Ring needs heartbeat time to detect that node is up. - // Give it more to be sure. - time.Sleep(heartbeat + heartbeat) + Eventually(func() int { + return ring.Len() + }, "30s").Should(Equal(2)) setRingKeys() // RingShard2 should have its keys. - Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) + Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43")) }) It("supports hash tags", func() { @@ -102,8 +97,8 @@ var _ = Describe("Redis Ring", func() { Expect(err).NotTo(HaveOccurred()) } - Expect(ringShard1.Info().Val()).ToNot(ContainSubstring("keys=")) - Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=100")) + Expect(ringShard1.Info("keyspace").Val()).ToNot(ContainSubstring("keys=")) + Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=100")) }) Describe("pipeline", func() { diff --git a/sentinel.go b/sentinel.go index 3cedf36e..12c29a71 100644 --- a/sentinel.go +++ b/sentinel.go @@ -116,7 +116,7 @@ func NewSentinelClient(opt *Options) *SentinelClient { } func (c *SentinelClient) PubSub() *PubSub { - return &PubSub{ + pubsub := &PubSub{ opt: c.opt, newConn: func(channels []string) (*pool.Conn, error) { @@ -124,6 +124,8 @@ func (c *SentinelClient) PubSub() *PubSub { }, closeConn: c.connPool.CloseConn, } + pubsub.init() + return pubsub } func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd { @@ -180,10 +182,7 @@ func (d *sentinelFailover) MasterAddr() (string, error) { if err != nil { return "", err } - - if d._masterAddr != addr { - d.switchMaster(addr) - } + d._switchMaster(addr) return addr, nil } @@ -194,11 +193,11 @@ func (d *sentinelFailover) masterAddr() (string, error) { addr, err := d.sentinel.GetMasterAddrByName(d.masterName).Result() if err == nil { addr := net.JoinHostPort(addr[0], addr[1]) - internal.Logf("sentinel: master=%q addr=%q", d.masterName, addr) return addr, nil } - internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s", d.masterName, err) + internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s", + d.masterName, err) d._resetSentinel() } @@ -234,15 +233,23 @@ func (d *sentinelFailover) masterAddr() (string, error) { return "", errors.New("redis: all sentinels are unreachable") } -func (d *sentinelFailover) switchMaster(masterAddr string) { - internal.Logf( - "sentinel: new master=%q addr=%q", - d.masterName, masterAddr, - ) - _ = d.Pool().Filter(func(cn *pool.Conn) bool { - return cn.RemoteAddr().String() != masterAddr +func (c *sentinelFailover) switchMaster(addr string) { + c.mu.Lock() + c._switchMaster(addr) + c.mu.Unlock() +} + +func (c *sentinelFailover) _switchMaster(addr string) { + if c._masterAddr == addr { + return + } + + internal.Logf("sentinel: new master=%q addr=%q", + c.masterName, addr) + _ = c.Pool().Filter(func(cn *pool.Conn) bool { + return cn.RemoteAddr().String() != addr }) - d._masterAddr = masterAddr + c._masterAddr = addr } func (d *sentinelFailover) setSentinel(sentinel *SentinelClient) { @@ -292,27 +299,25 @@ func (d *sentinelFailover) discoverSentinels(sentinel *SentinelClient) { } func (d *sentinelFailover) listen(sentinel *SentinelClient) { - var pubsub *PubSub - for { - if pubsub == nil { - pubsub = sentinel.PubSub() + pubsub := sentinel.PubSub() + defer pubsub.Close() - if err := pubsub.Subscribe("+switch-master"); err != nil { - internal.Logf("sentinel: Subscribe failed: %s", err) - pubsub.Close() + err := pubsub.Subscribe("+switch-master") + if err != nil { + internal.Logf("sentinel: Subscribe failed: %s", err) + d.resetSentinel() + return + } + + for { + msg, err := pubsub.ReceiveMessage() + if err != nil { + if err == pool.ErrClosed { d.resetSentinel() return } - } - - msg, err := pubsub.ReceiveMessage() - if err != nil { - if err != pool.ErrClosed { - internal.Logf("sentinel: ReceiveMessage failed: %s", err) - pubsub.Close() - } - d.resetSentinel() - return + internal.Logf("sentinel: ReceiveMessage failed: %s", err) + continue } switch msg.Channel { @@ -323,12 +328,7 @@ func (d *sentinelFailover) listen(sentinel *SentinelClient) { continue } addr := net.JoinHostPort(parts[3], parts[4]) - - d.mu.Lock() - if d._masterAddr != addr { - d.switchMaster(addr) - } - d.mu.Unlock() + d.switchMaster(addr) } } }