diff --git a/export_test.go b/export_test.go index 36b2c547..95715b5f 100644 --- a/export_test.go +++ b/export_test.go @@ -2,7 +2,6 @@ package redis import ( "net" - "sync" "time" ) @@ -20,18 +19,12 @@ func (cn *conn) SetNetConn(netcn net.Conn) { cn.netcn = netcn } -var timeMu sync.Mutex - func SetTime(tm time.Time) { - timeMu.Lock() now = func() time.Time { return tm } - timeMu.Unlock() } func RestoreTime() { - timeMu.Lock() now = time.Now - timeMu.Unlock() } diff --git a/main_test.go b/main_test.go index b9b3e218..d298dd21 100644 --- a/main_test.go +++ b/main_test.go @@ -98,9 +98,10 @@ func TestGinkgoSuite(t *testing.T) { //------------------------------------------------------------------------------ -func eventually(fn func() error, timeout time.Duration) (err error) { +func eventually(fn func() error, timeout time.Duration) error { done := make(chan struct{}) var exit int32 + var err error go func() { for atomic.LoadInt32(&exit) == 0 { err = fn() diff --git a/redis.go b/redis.go index 5af7d683..5558ad10 100644 --- a/redis.go +++ b/redis.go @@ -13,6 +13,8 @@ var Logger = log.New(os.Stderr, "redis: ", log.LstdFlags) type baseClient struct { connPool pool opt *Options + + onClose func() error // hook called when client is closed } func (c *baseClient) String() string { @@ -83,7 +85,16 @@ func (c *baseClient) process(cmd Cmder) { // It is rare to Close a Client, as the Client is meant to be // long-lived and shared between many goroutines. func (c *baseClient) Close() error { - return c.connPool.Close() + var retErr error + if c.onClose != nil { + if err := c.onClose(); err != nil && retErr == nil { + retErr = err + } + } + if err := c.connPool.Close(); err != nil && retErr == nil { + retErr = err + } + return retErr } //------------------------------------------------------------------------------ @@ -186,8 +197,10 @@ type Client struct { func newClient(opt *Options, pool pool) *Client { base := baseClient{opt: opt, connPool: pool} return &Client{ - baseClient: base, - commandable: commandable{process: base.process}, + baseClient: base, + commandable: commandable{ + process: base.process, + }, } } diff --git a/sentinel.go b/sentinel.go index 175c57e8..db5db64d 100644 --- a/sentinel.go +++ b/sentinel.go @@ -65,18 +65,31 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt: opt, } - return newClient(opt, failover.Pool()) + base := baseClient{ + opt: opt, + connPool: failover.Pool(), + + onClose: func() error { + return failover.Close() + }, + } + return &Client{ + baseClient: base, + commandable: commandable{ + process: base.process, + }, + } } //------------------------------------------------------------------------------ type sentinelClient struct { + baseClient commandable - *baseClient } func newSentinel(opt *Options) *sentinelClient { - base := &baseClient{ + base := baseClient{ opt: opt, connPool: newConnPool(opt), } @@ -116,8 +129,12 @@ type sentinelFailover struct { pool pool poolOnce sync.Once - lock sync.RWMutex - _sentinel *sentinelClient + mu sync.RWMutex + sentinel *sentinelClient +} + +func (d *sentinelFailover) Close() error { + return d.resetSentinel() } func (d *sentinelFailover) dial() (net.Conn, error) { @@ -137,15 +154,15 @@ func (d *sentinelFailover) Pool() pool { } func (d *sentinelFailover) MasterAddr() (string, error) { - defer d.lock.Unlock() - d.lock.Lock() + defer d.mu.Unlock() + d.mu.Lock() // Try last working sentinel. - if d._sentinel != nil { - addr, err := d._sentinel.GetMasterAddrByName(d.masterName).Result() + if d.sentinel != nil { + addr, err := d.sentinel.GetMasterAddrByName(d.masterName).Result() if err != nil { Logger.Printf("sentinel: GetMasterAddrByName %q failed: %s", d.masterName, err) - d.resetSentinel() + d._resetSentinel() } else { addr := net.JoinHostPort(addr[0], addr[1]) Logger.Printf("sentinel: %q addr is %s", d.masterName, addr) @@ -186,10 +203,26 @@ func (d *sentinelFailover) MasterAddr() (string, error) { func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) { d.discoverSentinels(sentinel) - d._sentinel = sentinel + d.sentinel = sentinel go d.listen() } +func (d *sentinelFailover) resetSentinel() error { + d.mu.Lock() + err := d._resetSentinel() + d.mu.Unlock() + return err +} + +func (d *sentinelFailover) _resetSentinel() error { + var err error + if d.sentinel != nil { + err = d.sentinel.Close() + d.sentinel = nil + } + return err +} + func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { sentinels, err := sentinel.Sentinels(d.masterName).Result() if err != nil { @@ -247,55 +280,41 @@ func (d *sentinelFailover) listen() { var pubsub *PubSub for { if pubsub == nil { - pubsub = d._sentinel.PubSub() + pubsub = d.sentinel.PubSub() if err := pubsub.Subscribe("+switch-master"); err != nil { Logger.Printf("sentinel: Subscribe failed: %s", err) - d.lock.Lock() d.resetSentinel() - d.lock.Unlock() return } } - msg, err := pubsub.Receive() + msg, err := pubsub.ReceiveMessage() if err != nil { - Logger.Printf("sentinel: Receive failed: %s", err) + Logger.Printf("sentinel: ReceiveMessage failed: %s", err) pubsub.Close() + d.resetSentinel() return } - switch msg := msg.(type) { - case *Message: - switch msg.Channel { - case "+switch-master": - parts := strings.Split(msg.Payload, " ") - if parts[0] != d.masterName { - Logger.Printf("sentinel: ignore new %s addr", parts[0]) - continue - } - addr := net.JoinHostPort(parts[3], parts[4]) - Logger.Printf( - "sentinel: new %q addr is %s", - d.masterName, addr, - ) - - d.closeOldConns(addr) - default: - Logger.Printf("sentinel: unsupported message: %s", msg) + switch msg.Channel { + case "+switch-master": + parts := strings.Split(msg.Payload, " ") + if parts[0] != d.masterName { + Logger.Printf("sentinel: ignore new %s addr", parts[0]) + continue } - case *Subscription: - // Ignore. - default: - Logger.Printf("sentinel: unsupported message: %s", msg) + + addr := net.JoinHostPort(parts[3], parts[4]) + Logger.Printf( + "sentinel: new %q addr is %s", + d.masterName, addr, + ) + + d.closeOldConns(addr) } } } -func (d *sentinelFailover) resetSentinel() { - d._sentinel.Close() - d._sentinel = nil -} - func contains(slice []string, str string) bool { for _, s := range slice { if s == str {