diff --git a/internal/pool/pool.go b/internal/pool/pool.go index bef000ea..a4e65084 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -140,47 +140,6 @@ func (p *ConnPool) lastDialError() error { return p._lastDialError.Load().(error) } -func (p *ConnPool) PopFree() *Conn { - select { - case p.queue <- struct{}{}: - default: - timer := timers.Get().(*time.Timer) - timer.Reset(p.opt.PoolTimeout) - - select { - case p.queue <- struct{}{}: - if !timer.Stop() { - <-timer.C - } - timers.Put(timer) - case <-timer.C: - timers.Put(timer) - atomic.AddUint32(&p.stats.Timeouts, 1) - return nil - } - } - - p.freeConnsMu.Lock() - cn := p.popFree() - p.freeConnsMu.Unlock() - - if cn == nil { - <-p.queue - } - return cn -} - -func (p *ConnPool) popFree() *Conn { - if len(p.freeConns) == 0 { - return nil - } - - idx := len(p.freeConns) - 1 - cn := p.freeConns[idx] - p.freeConns = p.freeConns[:idx] - return cn -} - // Get returns existed connection from the pool or creates a new one. func (p *ConnPool) Get() (*Conn, bool, error) { if p.closed() { @@ -235,6 +194,17 @@ func (p *ConnPool) Get() (*Conn, bool, error) { return newcn, true, nil } +func (p *ConnPool) popFree() *Conn { + if len(p.freeConns) == 0 { + return nil + } + + idx := len(p.freeConns) - 1 + cn := p.freeConns[idx] + p.freeConns = p.freeConns[:idx] + return cn +} + func (p *ConnPool) Put(cn *Conn) error { if data := cn.Rd.PeekBuffered(); data != nil { internal.Logf("connection has unread data: %q", data) @@ -303,17 +273,28 @@ func (p *ConnPool) closed() bool { return atomic.LoadUint32(&p._closed) == 1 } +func (p *ConnPool) Filter(fn func(*Conn) bool) error { + var firstErr error + p.connsMu.Lock() + for _, cn := range p.conns { + if fn(cn) { + if err := p.closeConn(cn); err != nil && firstErr == nil { + firstErr = err + } + } + } + p.connsMu.Unlock() + return firstErr +} + func (p *ConnPool) Close() error { if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) { return ErrClosed } - p.connsMu.Lock() var firstErr error + p.connsMu.Lock() for _, cn := range p.conns { - if cn == nil { - continue - } if err := p.closeConn(cn); err != nil && firstErr == nil { firstErr = err } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index f86327a4..68c9a1be 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -238,30 +238,4 @@ var _ = Describe("race", func() { } }) }) - - It("does not happen on Get and PopFree", func() { - connPool = pool.NewConnPool( - &pool.Options{ - Dialer: dummyDialer, - PoolSize: 10, - PoolTimeout: time.Minute, - IdleTimeout: time.Second, - IdleCheckFrequency: time.Millisecond, - }) - - perform(C, func(id int) { - for i := 0; i < N; i++ { - cn, _, err := connPool.Get() - Expect(err).NotTo(HaveOccurred()) - if err == nil { - Expect(connPool.Put(cn)).NotTo(HaveOccurred()) - } - - cn = connPool.PopFree() - if cn != nil { - Expect(connPool.Put(cn)).NotTo(HaveOccurred()) - } - } - }) - }) }) diff --git a/main_test.go b/main_test.go index 7c5a6a96..30f09c61 100644 --- a/main_test.go +++ b/main_test.go @@ -50,6 +50,10 @@ var cluster = &clusterScenario{ clients: make(map[string]*redis.Client, 6), } +func init() { + //redis.SetLogger(log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)) +} + var _ = BeforeSuite(func() { var err error diff --git a/pubsub.go b/pubsub.go index 19016b9d..4872b4e8 100644 --- a/pubsub.go +++ b/pubsub.go @@ -19,54 +19,53 @@ import ( type PubSub struct { base baseClient - mu sync.Mutex - cn *pool.Conn - closed bool - - subMu sync.Mutex + mu sync.Mutex + cn *pool.Conn channels []string patterns []string + closed bool cmd *Cmd } -func (c *PubSub) conn() (*pool.Conn, bool, error) { +func (c *PubSub) conn() (*pool.Conn, error) { c.mu.Lock() - defer c.mu.Unlock() + cn, err := c._conn() + c.mu.Unlock() + return cn, err +} +func (c *PubSub) _conn() (*pool.Conn, error) { if c.closed { - return nil, false, pool.ErrClosed + return nil, pool.ErrClosed } if c.cn != nil { - return c.cn, false, nil + return c.cn, nil } cn, err := c.base.connPool.NewConn() if err != nil { - return nil, false, err + return nil, err } if !cn.Inited { if err := c.base.initConn(cn); err != nil { _ = c.base.connPool.CloseConn(cn) - return nil, false, err + return nil, err } } if err := c.resubscribe(cn); err != nil { _ = c.base.connPool.CloseConn(cn) - return nil, false, err + return nil, err } c.cn = cn - return cn, true, nil + return cn, nil } func (c *PubSub) resubscribe(cn *pool.Conn) error { - c.subMu.Lock() - defer c.subMu.Unlock() - var firstErr error if len(c.channels) > 0 { if err := c._subscribe(cn, "subscribe", c.channels...); err != nil && firstErr == nil { @@ -81,6 +80,18 @@ func (c *PubSub) resubscribe(cn *pool.Conn) error { return firstErr } +func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { + args := make([]interface{}, 1+len(channels)) + args[0] = redisCmd + for i, channel := range channels { + args[1+i] = channel + } + cmd := NewSliceCmd(args...) + + cn.SetWriteTimeout(c.base.opt.WriteTimeout) + return writeCmd(cn, cmd) +} + func (c *PubSub) putConn(cn *pool.Conn, err error) { if !internal.IsBadConn(err, true) { return @@ -114,67 +125,55 @@ func (c *PubSub) Close() error { return nil } -func (c *PubSub) subscribe(redisCmd string, channels ...string) error { - cn, isNew, err := c.conn() - if err != nil { - return err - } - - if isNew { - return nil - } - - err = c._subscribe(cn, redisCmd, channels...) - c.putConn(cn, err) - return err -} - -func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { - args := make([]interface{}, 1+len(channels)) - args[0] = redisCmd - for i, channel := range channels { - args[1+i] = channel - } - cmd := NewSliceCmd(args...) - - cn.SetWriteTimeout(c.base.opt.WriteTimeout) - return writeCmd(cn, cmd) -} - // Subscribes the client to the specified channels. It returns // empty subscription if there are no channels. func (c *PubSub) Subscribe(channels ...string) error { - c.subMu.Lock() + c.mu.Lock() + err := c.subscribe("subscribe", channels...) c.channels = appendIfNotExists(c.channels, channels...) - c.subMu.Unlock() - return c.subscribe("subscribe", channels...) + c.mu.Unlock() + return err } // Subscribes the client to the given patterns. It returns // empty subscription if there are no patterns. func (c *PubSub) PSubscribe(patterns ...string) error { - c.subMu.Lock() + c.mu.Lock() + err := c.subscribe("psubscribe", patterns...) c.patterns = appendIfNotExists(c.patterns, patterns...) - c.subMu.Unlock() - return c.subscribe("psubscribe", patterns...) + c.mu.Unlock() + return err } // Unsubscribes the client from the given channels, or from all of // them if none is given. func (c *PubSub) Unsubscribe(channels ...string) error { - c.subMu.Lock() + c.mu.Lock() + err := c.subscribe("unsubscribe", channels...) c.channels = remove(c.channels, channels...) - c.subMu.Unlock() - return c.subscribe("unsubscribe", channels...) + c.mu.Unlock() + return err } // Unsubscribes the client from the given patterns, or from all of // them if none is given. func (c *PubSub) PUnsubscribe(patterns ...string) error { - c.subMu.Lock() + c.mu.Lock() + err := c.subscribe("punsubscribe", patterns...) c.patterns = remove(c.patterns, patterns...) - c.subMu.Unlock() - return c.subscribe("punsubscribe", patterns...) + c.mu.Unlock() + return err +} + +func (c *PubSub) subscribe(redisCmd string, channels ...string) error { + cn, err := c._conn() + if err != nil { + return err + } + + err = c._subscribe(cn, redisCmd, channels...) + c.putConn(cn, err) + return err } func (c *PubSub) Ping(payload ...string) error { @@ -184,7 +183,7 @@ func (c *PubSub) Ping(payload ...string) error { } cmd := NewCmd(args...) - cn, _, err := c.conn() + cn, err := c.conn() if err != nil { return err } @@ -277,7 +276,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { c.cmd = NewCmd() } - cn, _, err := c.conn() + cn, err := c.conn() if err != nil { return nil, err } diff --git a/sentinel.go b/sentinel.go index b28c3706..ed6e7ffb 100644 --- a/sentinel.go +++ b/sentinel.go @@ -132,7 +132,6 @@ func (c *sentinelClient) Sentinels(name string) *SliceCmd { } type sentinelFailover struct { - masterName string sentinelAddrs []string opt *Options @@ -140,8 +139,10 @@ type sentinelFailover struct { pool *pool.ConnPool poolOnce sync.Once - mu sync.RWMutex - sentinel *sentinelClient + mu sync.RWMutex + masterName string + _masterAddr string + sentinel *sentinelClient } func (d *sentinelFailover) Close() error { @@ -168,17 +169,30 @@ func (d *sentinelFailover) MasterAddr() (string, error) { d.mu.Lock() defer d.mu.Unlock() + addr, err := d.masterAddr() + if err != nil { + return "", err + } + + if d._masterAddr != addr { + d.switchMaster(addr) + } + + return addr, nil +} + +func (d *sentinelFailover) masterAddr() (string, error) { // Try last working sentinel. if d.sentinel != nil { addr, err := d.sentinel.GetMasterAddrByName(d.masterName).Result() - if err != nil { - internal.Logf("sentinel: GetMasterAddrByName %q failed: %s", d.masterName, err) - d._resetSentinel() - } else { + if err == nil { addr := net.JoinHostPort(addr[0], addr[1]) - internal.Logf("sentinel: %q addr is %s", d.masterName, addr) + internal.Logf("sentinel: master=%q addr=%q", d.masterName, addr) return addr, nil } + + internal.Logf("sentinel: GetMasterAddrByName name=%q failed: %s", d.masterName, err) + d._resetSentinel() } for i, sentinelAddr := range d.sentinelAddrs { @@ -193,25 +207,36 @@ func (d *sentinelFailover) MasterAddr() (string, error) { PoolTimeout: d.opt.PoolTimeout, IdleTimeout: d.opt.IdleTimeout, }) + masterAddr, err := sentinel.GetMasterAddrByName(d.masterName).Result() if err != nil { - internal.Logf("sentinel: GetMasterAddrByName %q failed: %s", d.masterName, err) + internal.Logf("sentinel: GetMasterAddrByName master=%q failed: %s", d.masterName, err) sentinel.Close() continue } // Push working sentinel to the top. d.sentinelAddrs[0], d.sentinelAddrs[i] = d.sentinelAddrs[i], d.sentinelAddrs[0] - d.setSentinel(sentinel) + addr := net.JoinHostPort(masterAddr[0], masterAddr[1]) - internal.Logf("sentinel: %q addr is %s", d.masterName, addr) return addr, nil } return "", errors.New("redis: all sentinels are unreachable") } +func (d *sentinelFailover) switchMaster(masterAddr string) { + internal.Logf( + "sentinel: new master=%q addr=%q", + d.masterName, masterAddr, + ) + _ = d.Pool().Filter(func(cn *pool.Conn) bool { + return cn.RemoteAddr().String() != masterAddr + }) + d._masterAddr = masterAddr +} + func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) { d.discoverSentinels(sentinel) d.sentinel = sentinel @@ -219,25 +244,25 @@ func (d *sentinelFailover) setSentinel(sentinel *sentinelClient) { } func (d *sentinelFailover) resetSentinel() error { + var err error d.mu.Lock() - err := d._resetSentinel() + if d.sentinel != nil { + err = d._resetSentinel() + } d.mu.Unlock() return err } func (d *sentinelFailover) _resetSentinel() error { - var err error - if d.sentinel != nil { - err = d.sentinel.Close() - d.sentinel = nil - } + err := d.sentinel.Close() + d.sentinel = nil return err } func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { sentinels, err := sentinel.Sentinels(d.masterName).Result() if err != nil { - internal.Logf("sentinel: Sentinels %q failed: %s", d.masterName, err) + internal.Logf("sentinel: Sentinels master=%q failed: %s", d.masterName, err) return } for _, sentinel := range sentinels { @@ -248,8 +273,8 @@ func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { sentinelAddr := vals[i+1].(string) if !contains(d.sentinelAddrs, sentinelAddr) { internal.Logf( - "sentinel: discovered new %q sentinel: %s", - d.masterName, sentinelAddr, + "sentinel: discovered new sentinel=%q for master=%q", + sentinelAddr, d.masterName, ) d.sentinelAddrs = append(d.sentinelAddrs, sentinelAddr) } @@ -258,34 +283,6 @@ func (d *sentinelFailover) discoverSentinels(sentinel *sentinelClient) { } } -// closeOldConns closes connections to the old master after failover switch. -func (d *sentinelFailover) closeOldConns(newMaster string) { - // Good connections that should be put back to the pool. They - // can't be put immediately, because pool.PopFree will return them - // again on next iteration. - cnsToPut := make([]*pool.Conn, 0) - - for { - cn := d.pool.PopFree() - if cn == nil { - break - } - if cn.RemoteAddr().String() != newMaster { - internal.Logf( - "sentinel: closing connection to the old master %s", - cn.RemoteAddr(), - ) - d.pool.Remove(cn) - } else { - cnsToPut = append(cnsToPut, cn) - } - } - - for _, cn := range cnsToPut { - d.pool.Put(cn) - } -} - func (d *sentinelFailover) listen(sentinel *sentinelClient) { var pubsub *PubSub for { @@ -312,17 +309,16 @@ func (d *sentinelFailover) listen(sentinel *sentinelClient) { case "+switch-master": parts := strings.Split(msg.Payload, " ") if parts[0] != d.masterName { - internal.Logf("sentinel: ignore new %s addr", parts[0]) + internal.Logf("sentinel: ignore addr for master=%q", parts[0]) continue } - addr := net.JoinHostPort(parts[3], parts[4]) - internal.Logf( - "sentinel: new %q addr is %s", - d.masterName, addr, - ) - d.closeOldConns(addr) + d.mu.Lock() + if d._masterAddr != addr { + d.switchMaster(addr) + } + d.mu.Unlock() } } } diff --git a/sentinel_test.go b/sentinel_test.go index f1f580f3..c67713cd 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -23,15 +23,19 @@ var _ = Describe("Sentinel", func() { }) It("should facilitate failover", func() { - // Set value on master, verify + // Set value on master. err := client.Set("foo", "master", 0).Err() Expect(err).NotTo(HaveOccurred()) + // Verify. val, err := sentinelMaster.Get("foo").Result() Expect(err).NotTo(HaveOccurred()) Expect(val).To(Equal("master")) - // Wait until replicated + // Create subscription. + ch := client.Subscribe("foo").Channel() + + // Wait until replicated. Eventually(func() string { return sentinelSlave1.Get("foo").Val() }, "1s", "100ms").Should(Equal("master")) @@ -59,6 +63,15 @@ var _ = Describe("Sentinel", func() { Eventually(func() error { return client.Get("foo").Err() }, "5s", "100ms").ShouldNot(HaveOccurred()) + + // Publish message to check if subscription is renewed. + err = client.Publish("foo", "hello").Err() + Expect(err).NotTo(HaveOccurred()) + + var msg *redis.Message + Eventually(ch).Should(Receive(&msg)) + Expect(msg.Channel).To(Equal("foo")) + Expect(msg.Payload).To(Equal("hello")) }) It("supports DB selection", func() {