Rework ReceiveMessage

This commit is contained in:
Vladimir Mihailenco 2018-07-23 15:55:13 +03:00
parent f7e97f0a16
commit ea9da7c2e8
10 changed files with 227 additions and 201 deletions

View File

@ -4,6 +4,7 @@
- Ring got new options called `HashReplicas` and `Hash`. It is recommended to set `HashReplicas = 1000` for better keys distribution between shards. - Ring got new options called `HashReplicas` and `Hash`. It is recommended to set `HashReplicas = 1000` for better keys distribution between shards.
- Cluster client was optimized to use much less memory when reloading cluster state. - Cluster client was optimized to use much less memory when reloading cluster state.
- ReceiveMessage is re-worked to not use ReceiveTimeout so it does not lose data when timeout occurres.
## v6.12 ## v6.12

View File

@ -1500,11 +1500,9 @@ func (c *ClusterClient) txPipelineReadQueued(
} }
func (c *ClusterClient) pubSub(channels []string) *PubSub { func (c *ClusterClient) pubSub(channels []string) *PubSub {
opt := c.opt.clientOptions()
var node *clusterNode var node *clusterNode
return &PubSub{ pubsub := &PubSub{
opt: opt, opt: c.opt.clientOptions(),
newConn: func(channels []string) (*pool.Conn, error) { newConn: func(channels []string) (*pool.Conn, error) {
if node == nil { if node == nil {
@ -1527,6 +1525,8 @@ func (c *ClusterClient) pubSub(channels []string) *PubSub {
return node.Client.connPool.CloseConn(cn) return node.Client.connPool.CloseConn(cn)
}, },
} }
pubsub.init()
return pubsub
} }
// Subscribe subscribes the client to the specified channels. // Subscribe subscribes the client to the specified channels.

View File

@ -453,7 +453,7 @@ var _ = Describe("ClusterClient", func() {
ttl := cmds[(i*2)+1].(*redis.DurationCmd) ttl := cmds[(i*2)+1].(*redis.DurationCmd)
dur := time.Duration(i+1) * time.Hour dur := time.Duration(i+1) * time.Hour
Expect(ttl.Val()).To(BeNumerically("~", dur, 10*time.Second)) Expect(ttl.Val()).To(BeNumerically("~", dur, 30*time.Second))
} }
}) })

View File

@ -3,7 +3,6 @@ package redis
import ( import (
"fmt" "fmt"
"net" "net"
"time"
"github.com/go-redis/redis/internal/hashtag" "github.com/go-redis/redis/internal/hashtag"
"github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/pool"
@ -17,10 +16,6 @@ func (c *PubSub) SetNetConn(netConn net.Conn) {
c.cn = pool.NewConn(netConn) c.cn = pool.NewConn(netConn)
} }
func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error) {
return c.receiveMessage(timeout)
}
func (c *ClusterClient) LoadState() (*clusterState, error) { func (c *ClusterClient) LoadState() (*clusterState, error) {
return c.loadState() return c.loadState()
} }

191
pubsub.go
View File

@ -2,7 +2,6 @@ package redis
import ( import (
"fmt" "fmt"
"net"
"sync" "sync"
"time" "time"
@ -11,11 +10,11 @@ import (
) )
// PubSub implements Pub/Sub commands as described in // PubSub implements Pub/Sub commands as described in
// http://redis.io/topics/pubsub. It's NOT safe for concurrent use by // http://redis.io/topics/pubsub. Message receiving is NOT safe
// multiple goroutines. // for concurrent use by multiple goroutines.
// //
// PubSub automatically resubscribes to the channels and patterns // PubSub automatically reconnects to Redis Server and resubscribes
// when Redis becomes unavailable. // to the channels in case of network errors.
type PubSub struct { type PubSub struct {
opt *Options opt *Options
@ -27,13 +26,21 @@ type PubSub struct {
channels map[string]struct{} channels map[string]struct{}
patterns map[string]struct{} patterns map[string]struct{}
closed bool closed bool
exit chan struct{}
cmd *Cmd cmd *Cmd
pingOnce sync.Once
ping chan struct{}
chOnce sync.Once chOnce sync.Once
ch chan *Message ch chan *Message
} }
func (c *PubSub) init() {
c.exit = make(chan struct{})
}
func (c *PubSub) conn() (*pool.Conn, error) { func (c *PubSub) conn() (*pool.Conn, error) {
c.mu.Lock() c.mu.Lock()
cn, err := c._conn(nil) cn, err := c._conn(nil)
@ -66,31 +73,36 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) {
func (c *PubSub) resubscribe(cn *pool.Conn) error { func (c *PubSub) resubscribe(cn *pool.Conn) error {
var firstErr error var firstErr error
if len(c.channels) > 0 { if len(c.channels) > 0 {
channels := make([]string, len(c.channels)) channels := mapKeys(c.channels)
i := 0 err := c._subscribe(cn, "subscribe", channels...)
for channel := range c.channels { if err != nil && firstErr == nil {
channels[i] = channel
i++
}
if err := c._subscribe(cn, "subscribe", channels...); err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
} }
if len(c.patterns) > 0 { if len(c.patterns) > 0 {
patterns := make([]string, len(c.patterns)) patterns := mapKeys(c.patterns)
i := 0 err := c._subscribe(cn, "psubscribe", patterns...)
for pattern := range c.patterns { if err != nil && firstErr == nil {
patterns[i] = pattern
i++
}
if err := c._subscribe(cn, "psubscribe", patterns...); err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
} }
return firstErr return firstErr
} }
func mapKeys(m map[string]struct{}) []string {
s := make([]string, len(m))
i := 0
for k := range m {
s[i] = k
i++
}
return s
}
func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error {
args := make([]interface{}, 1+len(channels)) args := make([]interface{}, 1+len(channels))
args[0] = redisCmd args[0] = redisCmd
@ -114,16 +126,30 @@ func (c *PubSub) _releaseConn(cn *pool.Conn, err error) {
return return
} }
if internal.IsBadConn(err, true) { if internal.IsBadConn(err, true) {
_ = c.closeTheCn() c._reconnect()
} }
} }
func (c *PubSub) closeTheCn() error { func (c *PubSub) _closeTheCn() error {
err := c.closeConn(c.cn) var err error
if c.cn != nil {
err = c.closeConn(c.cn)
c.cn = nil c.cn = nil
}
return err return err
} }
func (c *PubSub) reconnect() {
c.mu.Lock()
c._reconnect()
c.mu.Unlock()
}
func (c *PubSub) _reconnect() {
_ = c._closeTheCn()
_, _ = c._conn(nil)
}
func (c *PubSub) Close() error { func (c *PubSub) Close() error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -132,17 +158,18 @@ func (c *PubSub) Close() error {
return pool.ErrClosed return pool.ErrClosed
} }
c.closed = true c.closed = true
close(c.exit)
if c.cn != nil { err := c._closeTheCn()
return c.closeTheCn() return err
}
return nil
} }
// Subscribe the client to the specified channels. It returns // Subscribe the client to the specified channels. It returns
// empty subscription if there are no channels. // empty subscription if there are no channels.
func (c *PubSub) Subscribe(channels ...string) error { func (c *PubSub) Subscribe(channels ...string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe("subscribe", channels...) err := c.subscribe("subscribe", channels...)
if c.channels == nil { if c.channels == nil {
c.channels = make(map[string]struct{}) c.channels = make(map[string]struct{})
@ -150,7 +177,6 @@ func (c *PubSub) Subscribe(channels ...string) error {
for _, channel := range channels { for _, channel := range channels {
c.channels[channel] = struct{}{} c.channels[channel] = struct{}{}
} }
c.mu.Unlock()
return err return err
} }
@ -158,6 +184,8 @@ func (c *PubSub) Subscribe(channels ...string) error {
// empty subscription if there are no patterns. // empty subscription if there are no patterns.
func (c *PubSub) PSubscribe(patterns ...string) error { func (c *PubSub) PSubscribe(patterns ...string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe("psubscribe", patterns...) err := c.subscribe("psubscribe", patterns...)
if c.patterns == nil { if c.patterns == nil {
c.patterns = make(map[string]struct{}) c.patterns = make(map[string]struct{})
@ -165,7 +193,6 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
for _, pattern := range patterns { for _, pattern := range patterns {
c.patterns[pattern] = struct{}{} c.patterns[pattern] = struct{}{}
} }
c.mu.Unlock()
return err return err
} }
@ -173,11 +200,12 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
// them if none is given. // them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error { func (c *PubSub) Unsubscribe(channels ...string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe("unsubscribe", channels...) err := c.subscribe("unsubscribe", channels...)
for _, channel := range channels { for _, channel := range channels {
delete(c.channels, channel) delete(c.channels, channel)
} }
c.mu.Unlock()
return err return err
} }
@ -185,11 +213,12 @@ func (c *PubSub) Unsubscribe(channels ...string) error {
// them if none is given. // them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error { func (c *PubSub) PUnsubscribe(patterns ...string) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe("punsubscribe", patterns...) err := c.subscribe("punsubscribe", patterns...)
for _, pattern := range patterns { for _, pattern := range patterns {
delete(c.patterns, pattern) delete(c.patterns, pattern)
} }
c.mu.Unlock()
return err return err
} }
@ -298,7 +327,7 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
// ReceiveTimeout acts like Receive but returns an error if message // ReceiveTimeout acts like Receive but returns an error if message
// is not received in time. This is low-level API and most clients // is not received in time. This is low-level API and most clients
// should use ReceiveMessage. // should use ReceiveMessage instead.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
if c.cmd == nil { if c.cmd == nil {
c.cmd = NewCmd() c.cmd = NewCmd()
@ -309,7 +338,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
return nil, err return nil, err
} }
cn.SetReadTimeout(readTimeout(timeout)) cn.SetReadTimeout(timeout)
err = c.cmd.readReply(cn) err = c.cmd.readReply(cn)
c.releaseConn(cn, err) c.releaseConn(cn, err)
if err != nil { if err != nil {
@ -321,48 +350,28 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
// Receive returns a message as a Subscription, Message, Pong or error. // Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and most clients // See PubSub example for details. This is low-level API and most clients
// should use ReceiveMessage. // should use ReceiveMessage instead.
func (c *PubSub) Receive() (interface{}, error) { func (c *PubSub) Receive() (interface{}, error) {
return c.ReceiveTimeout(0) return c.ReceiveTimeout(0)
} }
// ReceiveMessage returns a Message or error ignoring Subscription or Pong // ReceiveMessage returns a Message or error ignoring Subscription or Pong
// messages. It automatically reconnects to Redis Server and resubscribes // messages. It periodically sends Ping messages to test connection health.
// to channels in case of network errors.
func (c *PubSub) ReceiveMessage() (*Message, error) { func (c *PubSub) ReceiveMessage() (*Message, error) {
return c.receiveMessage(5 * time.Second) c.pingOnce.Do(c.initPing)
}
func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) {
var errNum uint
for { for {
msgi, err := c.ReceiveTimeout(timeout) msg, err := c.Receive()
if err != nil { if err != nil {
if !internal.IsNetworkError(err) {
return nil, err return nil, err
} }
errNum++ // Any message is as good as a ping.
if errNum < 3 { select {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { case c.ping <- struct{}{}:
err := c.Ping() default:
if err != nil {
internal.Logf("PubSub.Ping failed: %s", err)
}
}
} else {
// 3 consequent errors - connection is broken or
// Redis Server is down.
// Sleep to not exceed max number of open connections.
time.Sleep(time.Second)
}
continue
} }
// Reset error number, because we received a message. switch msg := msg.(type) {
errNum = 0
switch msg := msgi.(type) {
case *Subscription: case *Subscription:
// Ignore. // Ignore.
case *Pong: case *Pong:
@ -370,30 +379,74 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) {
case *Message: case *Message:
return msg, nil return msg, nil
default: default:
return nil, fmt.Errorf("redis: unknown message: %T", msgi) err := fmt.Errorf("redis: unknown message: %T", msg)
return nil, err
} }
} }
} }
// Channel returns a Go channel for concurrently receiving messages. // Channel returns a Go channel for concurrently receiving messages.
// The channel is closed with PubSub. Receive or ReceiveMessage APIs // The channel is closed with PubSub. Receive* APIs can not be used
// can not be used after channel is created. // after channel is created.
func (c *PubSub) Channel() <-chan *Message { func (c *PubSub) Channel() <-chan *Message {
c.chOnce.Do(func() { c.chOnce.Do(c.initChannel)
return c.ch
}
func (c *PubSub) initChannel() {
c.ch = make(chan *Message, 100) c.ch = make(chan *Message, 100)
go func() { go func() {
var errCount int
for { for {
msg, err := c.ReceiveMessage() msg, err := c.ReceiveMessage()
if err != nil { if err != nil {
if err == pool.ErrClosed { if err == pool.ErrClosed {
break close(c.ch)
return
} }
if errCount > 0 {
time.Sleep(c.retryBackoff(errCount))
}
errCount++
continue continue
} }
errCount = 0
c.ch <- msg c.ch <- msg
} }
close(c.ch)
}() }()
}) }
return c.ch
func (c *PubSub) initPing() {
const timeout = 5 * time.Second
c.ping = make(chan struct{}, 10)
go func() {
timer := time.NewTimer(timeout)
timer.Stop()
var hasPing bool
for {
timer.Reset(timeout)
select {
case <-c.ping:
hasPing = true
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
if hasPing {
hasPing = false
_ = c.Ping()
} else {
c.reconnect()
}
case <-c.exit:
return
}
}
}()
}
func (c *PubSub) retryBackoff(attempt int) time.Duration {
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
} }

View File

@ -255,45 +255,6 @@ var _ = Describe("PubSub", func() {
Expect(msg.Payload).To(Equal("world")) Expect(msg.Payload).To(Equal("world"))
}) })
It("should ReceiveMessage after timeout", func() {
timeout := 100 * time.Millisecond
pubsub := client.Subscribe("mychannel")
defer pubsub.Close()
subscr, err := pubsub.ReceiveTimeout(time.Second)
Expect(err).NotTo(HaveOccurred())
Expect(subscr).To(Equal(&redis.Subscription{
Kind: "subscribe",
Channel: "mychannel",
Count: 1,
}))
done := make(chan bool, 1)
go func() {
defer GinkgoRecover()
defer func() {
done <- true
}()
time.Sleep(timeout + 100*time.Millisecond)
n, err := client.Publish("mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1)))
}()
msg, err := pubsub.ReceiveMessageTimeout(timeout)
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello"))
Eventually(done).Should(Receive())
stats := client.PoolStats()
Expect(stats.Hits).To(Equal(uint32(1)))
Expect(stats.Misses).To(Equal(uint32(1)))
})
It("returns an error when subscribe fails", func() { It("returns an error when subscribe fails", func() {
pubsub := client.Subscribe() pubsub := client.Subscribe()
defer pubsub.Close() defer pubsub.Close()
@ -316,24 +277,27 @@ var _ = Describe("PubSub", func() {
writeErr: io.EOF, writeErr: io.EOF,
}) })
done := make(chan bool, 1) step := make(chan struct{}, 3)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer func() {
done <- true
}()
time.Sleep(100 * time.Millisecond) Eventually(step).Should(Receive())
err := client.Publish("mychannel", "hello").Err() err := client.Publish("mychannel", "hello").Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
step <- struct{}{}
}() }()
_, err := pubsub.ReceiveMessage()
Expect(err).To(Equal(io.EOF))
step <- struct{}{}
msg, err := pubsub.ReceiveMessage() msg, err := pubsub.ReceiveMessage()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
Eventually(done).Should(Receive()) Eventually(step).Should(Receive())
} }
It("Subscribe should reconnect on ReceiveMessage error", func() { It("Subscribe should reconnect on ReceiveMessage error", func() {
@ -380,9 +344,9 @@ var _ = Describe("PubSub", func() {
_, err := pubsub.ReceiveMessage() _, err := pubsub.ReceiveMessage()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err).To(SatisfyAny( Expect(err.Error()).To(SatisfyAny(
MatchError("redis: client is closed"), Equal("redis: client is closed"),
MatchError("use of closed network connection"), // Go 1.4 ContainSubstring("use of closed network connection"),
)) ))
}() }()
@ -406,7 +370,7 @@ var _ = Describe("PubSub", func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()
time.Sleep(2 * timeout) time.Sleep(timeout)
err := pubsub.Subscribe("mychannel") err := pubsub.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -417,7 +381,7 @@ var _ = Describe("PubSub", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}() }()
msg, err := pubsub.ReceiveMessageTimeout(timeout) msg, err := pubsub.ReceiveMessage()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))

View File

@ -423,7 +423,7 @@ func (c *Client) TxPipeline() Pipeliner {
} }
func (c *Client) pubSub() *PubSub { func (c *Client) pubSub() *PubSub {
return &PubSub{ pubsub := &PubSub{
opt: c.opt, opt: c.opt,
newConn: func(channels []string) (*pool.Conn, error) { newConn: func(channels []string) (*pool.Conn, error) {
@ -431,6 +431,8 @@ func (c *Client) pubSub() *PubSub {
}, },
closeConn: c.connPool.CloseConn, closeConn: c.connPool.CloseConn,
} }
pubsub.init()
return pubsub
} }
// Subscribe subscribes the client to the specified channels. // Subscribe subscribes the client to the specified channels.

16
ring.go
View File

@ -165,6 +165,7 @@ type ringShards struct {
hash *consistenthash.Map hash *consistenthash.Map
shards map[string]*ringShard // read only shards map[string]*ringShard // read only
list []*ringShard // read only list []*ringShard // read only
len int
closed bool closed bool
} }
@ -269,17 +270,27 @@ func (c *ringShards) Heartbeat(frequency time.Duration) {
// rebalance removes dead shards from the Ring. // rebalance removes dead shards from the Ring.
func (c *ringShards) rebalance() { func (c *ringShards) rebalance() {
hash := newConsistentHash(c.opt) hash := newConsistentHash(c.opt)
var shardsNum int
for name, shard := range c.shards { for name, shard := range c.shards {
if shard.IsUp() { if shard.IsUp() {
hash.Add(name) hash.Add(name)
shardsNum++
} }
} }
c.mu.Lock() c.mu.Lock()
c.hash = hash c.hash = hash
c.len = shardsNum
c.mu.Unlock() c.mu.Unlock()
} }
func (c *ringShards) Len() int {
c.mu.RLock()
l := c.len
c.mu.RUnlock()
return l
}
func (c *ringShards) Close() error { func (c *ringShards) Close() error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -398,6 +409,11 @@ func (c *Ring) PoolStats() *PoolStats {
return &acc return &acc
} }
// Len returns the current number of shards in the ring.
func (c *Ring) Len() int {
return c.shards.Len()
}
// Subscribe subscribes the client to the specified channels. // Subscribe subscribes the client to the specified channels.
func (c *Ring) Subscribe(channels ...string) *PubSub { func (c *Ring) Subscribe(channels ...string) *PubSub {
if len(channels) == 0 { if len(channels) == 0 {

View File

@ -42,8 +42,8 @@ var _ = Describe("Redis Ring", func() {
setRingKeys() setRingKeys()
// Both shards should have some keys now. // Both shards should have some keys now.
Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57")) Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
}) })
It("distributes keys when using EVAL", func() { It("distributes keys when using EVAL", func() {
@ -59,41 +59,36 @@ var _ = Describe("Redis Ring", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=57")) Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=57"))
Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
}) })
It("uses single shard when one of the shards is down", func() { It("uses single shard when one of the shards is down", func() {
// Stop ringShard2. // Stop ringShard2.
Expect(ringShard2.Close()).NotTo(HaveOccurred()) Expect(ringShard2.Close()).NotTo(HaveOccurred())
// Ring needs 3 * heartbeat time to detect that node is down. Eventually(func() int {
// Give it more to be sure. return ring.Len()
time.Sleep(2 * 3 * heartbeat) }, "30s").Should(Equal(1))
setRingKeys() setRingKeys()
// RingShard1 should have all keys. // RingShard1 should have all keys.
Expect(ringShard1.Info().Val()).To(ContainSubstring("keys=100")) Expect(ringShard1.Info("keyspace").Val()).To(ContainSubstring("keys=100"))
// Start ringShard2. // Start ringShard2.
var err error var err error
ringShard2, err = startRedis(ringShard2Port) ringShard2, err = startRedis(ringShard2Port)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Wait for ringShard2 to come up. Eventually(func() int {
Eventually(func() error { return ring.Len()
return ringShard2.Ping().Err() }, "30s").Should(Equal(2))
}, "1s").ShouldNot(HaveOccurred())
// Ring needs heartbeat time to detect that node is up.
// Give it more to be sure.
time.Sleep(heartbeat + heartbeat)
setRingKeys() setRingKeys()
// RingShard2 should have its keys. // RingShard2 should have its keys.
Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=43")) Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=43"))
}) })
It("supports hash tags", func() { It("supports hash tags", func() {
@ -102,8 +97,8 @@ var _ = Describe("Redis Ring", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
Expect(ringShard1.Info().Val()).ToNot(ContainSubstring("keys=")) Expect(ringShard1.Info("keyspace").Val()).ToNot(ContainSubstring("keys="))
Expect(ringShard2.Info().Val()).To(ContainSubstring("keys=100")) Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=100"))
}) })
Describe("pipeline", func() { Describe("pipeline", func() {

View File

@ -116,7 +116,7 @@ func NewSentinelClient(opt *Options) *SentinelClient {
} }
func (c *SentinelClient) PubSub() *PubSub { func (c *SentinelClient) PubSub() *PubSub {
return &PubSub{ pubsub := &PubSub{
opt: c.opt, opt: c.opt,
newConn: func(channels []string) (*pool.Conn, error) { newConn: func(channels []string) (*pool.Conn, error) {
@ -124,6 +124,8 @@ func (c *SentinelClient) PubSub() *PubSub {
}, },
closeConn: c.connPool.CloseConn, closeConn: c.connPool.CloseConn,
} }
pubsub.init()
return pubsub
} }
func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd { func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd {
@ -180,10 +182,7 @@ func (d *sentinelFailover) MasterAddr() (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
d._switchMaster(addr)
if d._masterAddr != addr {
d.switchMaster(addr)
}
return addr, nil return addr, nil
} }
@ -194,11 +193,11 @@ func (d *sentinelFailover) masterAddr() (string, error) {
addr, err := d.sentinel.GetMasterAddrByName(d.masterName).Result() addr, err := d.sentinel.GetMasterAddrByName(d.masterName).Result()
if err == nil { if err == nil {
addr := net.JoinHostPort(addr[0], addr[1]) addr := net.JoinHostPort(addr[0], addr[1])
internal.Logf("sentinel: master=%q addr=%q", d.masterName, addr)
return addr, nil return addr, nil
} }
internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s", d.masterName, err) internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s",
d.masterName, err)
d._resetSentinel() d._resetSentinel()
} }
@ -234,15 +233,23 @@ func (d *sentinelFailover) masterAddr() (string, error) {
return "", errors.New("redis: all sentinels are unreachable") return "", errors.New("redis: all sentinels are unreachable")
} }
func (d *sentinelFailover) switchMaster(masterAddr string) { func (c *sentinelFailover) switchMaster(addr string) {
internal.Logf( c.mu.Lock()
"sentinel: new master=%q addr=%q", c._switchMaster(addr)
d.masterName, masterAddr, c.mu.Unlock()
) }
_ = d.Pool().Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != masterAddr func (c *sentinelFailover) _switchMaster(addr string) {
if c._masterAddr == addr {
return
}
internal.Logf("sentinel: new master=%q addr=%q",
c.masterName, addr)
_ = c.Pool().Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
}) })
d._masterAddr = masterAddr c._masterAddr = addr
} }
func (d *sentinelFailover) setSentinel(sentinel *SentinelClient) { func (d *sentinelFailover) setSentinel(sentinel *SentinelClient) {
@ -292,28 +299,26 @@ func (d *sentinelFailover) discoverSentinels(sentinel *SentinelClient) {
} }
func (d *sentinelFailover) listen(sentinel *SentinelClient) { func (d *sentinelFailover) listen(sentinel *SentinelClient) {
var pubsub *PubSub pubsub := sentinel.PubSub()
for { defer pubsub.Close()
if pubsub == nil {
pubsub = sentinel.PubSub()
if err := pubsub.Subscribe("+switch-master"); err != nil { err := pubsub.Subscribe("+switch-master")
if err != nil {
internal.Logf("sentinel: Subscribe failed: %s", err) internal.Logf("sentinel: Subscribe failed: %s", err)
pubsub.Close()
d.resetSentinel() d.resetSentinel()
return return
} }
}
for {
msg, err := pubsub.ReceiveMessage() msg, err := pubsub.ReceiveMessage()
if err != nil { if err != nil {
if err != pool.ErrClosed { if err == pool.ErrClosed {
internal.Logf("sentinel: ReceiveMessage failed: %s", err)
pubsub.Close()
}
d.resetSentinel() d.resetSentinel()
return return
} }
internal.Logf("sentinel: ReceiveMessage failed: %s", err)
continue
}
switch msg.Channel { switch msg.Channel {
case "+switch-master": case "+switch-master":
@ -323,13 +328,8 @@ func (d *sentinelFailover) listen(sentinel *SentinelClient) {
continue continue
} }
addr := net.JoinHostPort(parts[3], parts[4]) addr := net.JoinHostPort(parts[3], parts[4])
d.mu.Lock()
if d._masterAddr != addr {
d.switchMaster(addr) d.switchMaster(addr)
} }
d.mu.Unlock()
}
} }
} }