Simplify resubscribing in PubSub.

This commit is contained in:
Vladimir Mihailenco 2016-09-29 12:07:04 +00:00
parent 833b0c68df
commit e57ac63b6e
14 changed files with 90 additions and 93 deletions

View File

@ -516,7 +516,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
} }
} }
cn, err := node.Client.conn() cn, _, err := node.Client.conn()
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
setRetErr(err) setRetErr(err)

View File

@ -16,7 +16,7 @@ func benchmarkPoolGetPut(b *testing.B, poolSize int) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
cn, err := connPool.Get() cn, _, err := connPool.Get()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -48,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
cn, err := connPool.Get() cn, _, err := connPool.Get()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@ -36,7 +36,7 @@ type PoolStats struct {
} }
type Pooler interface { type Pooler interface {
Get() (*Conn, error) Get() (*Conn, bool, error)
Put(*Conn) error Put(*Conn) error
Remove(*Conn, error) error Remove(*Conn, error) error
Len() int Len() int
@ -152,9 +152,9 @@ func (p *ConnPool) popFree() *Conn {
} }
// Get returns existed connection from the pool or creates a new one. // Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (*Conn, error) { func (p *ConnPool) Get() (*Conn, bool, error) {
if p.Closed() { if p.Closed() {
return nil, ErrClosed return nil, false, ErrClosed
} }
atomic.AddUint32(&p.stats.Requests, 1) atomic.AddUint32(&p.stats.Requests, 1)
@ -170,7 +170,7 @@ func (p *ConnPool) Get() (*Conn, error) {
case <-timer.C: case <-timer.C:
timers.Put(timer) timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1) atomic.AddUint32(&p.stats.Timeouts, 1)
return nil, ErrPoolTimeout return nil, false, ErrPoolTimeout
} }
p.freeConnsMu.Lock() p.freeConnsMu.Lock()
@ -180,7 +180,7 @@ func (p *ConnPool) Get() (*Conn, error) {
if cn != nil { if cn != nil {
atomic.AddUint32(&p.stats.Hits, 1) atomic.AddUint32(&p.stats.Hits, 1)
if !cn.IsStale(p.idleTimeout) { if !cn.IsStale(p.idleTimeout) {
return cn, nil return cn, false, nil
} }
_ = p.closeConn(cn, errConnStale) _ = p.closeConn(cn, errConnStale)
} }
@ -188,7 +188,7 @@ func (p *ConnPool) Get() (*Conn, error) {
newcn, err := p.NewConn() newcn, err := p.NewConn()
if err != nil { if err != nil {
<-p.queue <-p.queue
return nil, err return nil, false, err
} }
p.connsMu.Lock() p.connsMu.Lock()
@ -198,7 +198,7 @@ func (p *ConnPool) Get() (*Conn, error) {
p.conns = append(p.conns, newcn) p.conns = append(p.conns, newcn)
p.connsMu.Unlock() p.connsMu.Unlock()
return newcn, nil return newcn, true, nil
} }
func (p *ConnPool) Put(cn *Conn) error { func (p *ConnPool) Put(cn *Conn) error {

View File

@ -16,8 +16,8 @@ func (p *SingleConnPool) First() *Conn {
return p.cn return p.cn
} }
func (p *SingleConnPool) Get() (*Conn, error) { func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.cn, nil return p.cn, false, nil
} }
func (p *SingleConnPool) Put(cn *Conn) error { func (p *SingleConnPool) Put(cn *Conn) error {

View File

@ -30,23 +30,23 @@ func (p *StickyConnPool) First() *Conn {
return cn return cn
} }
func (p *StickyConnPool) Get() (*Conn, error) { func (p *StickyConnPool) Get() (*Conn, bool, error) {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
return nil, ErrClosed return nil, false, ErrClosed
} }
if p.cn != nil { if p.cn != nil {
return p.cn, nil return p.cn, false, nil
} }
cn, err := p.pool.Get() cn, _, err := p.pool.Get()
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
p.cn = cn p.cn = cn
return cn, nil return cn, true, nil
} }
func (p *StickyConnPool) put() (err error) { func (p *StickyConnPool) put() (err error) {

View File

@ -26,7 +26,7 @@ var _ = Describe("ConnPool", func() {
It("rate limits dial", func() { It("rate limits dial", func() {
var rateErr error var rateErr error
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
cn, err := connPool.Get() cn, _, err := connPool.Get()
if err != nil { if err != nil {
rateErr = err rateErr = err
break break
@ -40,13 +40,13 @@ var _ = Describe("ConnPool", func() {
It("should unblock client when conn is removed", func() { It("should unblock client when conn is removed", func() {
// Reserve one connection. // Reserve one connection.
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Reserve all other connections. // Reserve all other connections.
var cns []*pool.Conn var cns []*pool.Conn
for i := 0; i < 9; i++ { for i := 0; i < 9; i++ {
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn) cns = append(cns, cn)
} }
@ -57,7 +57,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover() defer GinkgoRecover()
started <- true started <- true
_, err := connPool.Get() _, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
done <- true done <- true
@ -113,7 +113,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections // add stale connections
idleConns = nil idleConns = nil
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.UsedAt = time.Now().Add(-2 * idleTimeout) cn.UsedAt = time.Now().Add(-2 * idleTimeout)
conns = append(conns, cn) conns = append(conns, cn)
@ -122,7 +122,7 @@ var _ = Describe("conns reaper", func() {
// add fresh connections // add fresh connections
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn) conns = append(conns, cn)
} }
@ -167,7 +167,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ { for j := 0; j < 3; j++ {
var freeCns []*pool.Conn var freeCns []*pool.Conn
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn) freeCns = append(freeCns, cn)
@ -176,7 +176,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3)) Expect(connPool.Len()).To(Equal(3))
Expect(connPool.FreeLen()).To(Equal(0)) Expect(connPool.FreeLen()).To(Equal(0))
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
conns = append(conns, cn) conns = append(conns, cn)
@ -224,7 +224,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred()) Expect(connPool.Put(cn)).NotTo(HaveOccurred())
@ -232,7 +232,7 @@ var _ = Describe("race", func() {
} }
}, func(id int) { }, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred()) Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred())
@ -248,7 +248,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) { perform(C, func(id int) {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
cn, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred()) Expect(connPool.Put(cn)).NotTo(HaveOccurred())

View File

@ -91,7 +91,7 @@ var _ = Describe("pool", func() {
}) })
It("should remove broken connections", func() { It("should remove broken connections", func() {
cn, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.NetConn = &badConn{}
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred()) Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())

View File

@ -18,12 +18,25 @@ type PubSub struct {
channels []string channels []string
patterns []string patterns []string
}
nsub int // number of active subscriptions func (c *PubSub) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c.base.conn()
if err != nil {
return nil, false, err
}
if isNew {
c.resubscribe()
}
return cn, isNew, nil
}
func (c *PubSub) putConn(cn *pool.Conn, err error) {
c.base.putConn(cn, err, true)
} }
func (c *PubSub) subscribe(redisCmd string, channels ...string) error { func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, err := c.base.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
return err return err
} }
@ -44,7 +57,6 @@ func (c *PubSub) Subscribe(channels ...string) error {
err := c.subscribe("SUBSCRIBE", channels...) err := c.subscribe("SUBSCRIBE", channels...)
if err == nil { if err == nil {
c.channels = appendIfNotExists(c.channels, channels...) c.channels = appendIfNotExists(c.channels, channels...)
c.nsub += len(channels)
} }
return err return err
} }
@ -54,43 +66,10 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
err := c.subscribe("PSUBSCRIBE", patterns...) err := c.subscribe("PSUBSCRIBE", patterns...)
if err == nil { if err == nil {
c.patterns = appendIfNotExists(c.patterns, patterns...) c.patterns = appendIfNotExists(c.patterns, patterns...)
c.nsub += len(patterns)
} }
return err return err
} }
func remove(ss []string, es ...string) []string {
if len(es) == 0 {
return ss[:0]
}
for _, e := range es {
for i, s := range ss {
if s == e {
ss = append(ss[:i], ss[i+1:]...)
break
}
}
}
return ss
}
func appendIfNotExists(ss []string, es ...string) []string {
for _, e := range es {
found := false
for _, s := range ss {
if s == e {
found = true
break
}
}
if !found {
ss = append(ss, e)
}
}
return ss
}
// Unsubscribes the client from the given channels, or from all of // Unsubscribes the client from the given channels, or from all of
// them if none is given. // them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error { func (c *PubSub) Unsubscribe(channels ...string) error {
@ -116,7 +95,7 @@ func (c *PubSub) Close() error {
} }
func (c *PubSub) Ping(payload string) error { func (c *PubSub) Ping(payload string) error {
cn, err := c.base.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
return err return err
} }
@ -198,11 +177,7 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
// is not received in time. This is low-level API and most clients // is not received in time. This is low-level API and most clients
// should use ReceiveMessage. // should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
if c.nsub == 0 { cn, _, err := c.conn()
c.resubscribe()
}
cn, err := c.base.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -274,12 +249,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) {
} }
} }
func (c *PubSub) putConn(cn *pool.Conn, err error) {
if !c.base.putConn(cn, err, true) {
c.nsub = 0
}
}
func (c *PubSub) resubscribe() { func (c *PubSub) resubscribe() {
if c.base.closed() { if c.base.closed() {
return return
@ -295,3 +264,31 @@ func (c *PubSub) resubscribe() {
} }
} }
} }
func remove(ss []string, es ...string) []string {
if len(es) == 0 {
return ss[:0]
}
for _, e := range es {
for i, s := range ss {
if s == e {
ss = append(ss[:i], ss[i+1:]...)
break
}
}
}
return ss
}
func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}

View File

@ -288,7 +288,7 @@ var _ = Describe("PubSub", func() {
}) })
expectReceiveMessageOnError := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, err := pubsub.Pool().Get() cn1, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn1.NetConn = &badConn{ cn1.NetConn = &badConn{
readErr: io.EOF, readErr: io.EOF,

View File

@ -27,18 +27,18 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB) return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB)
} }
func (c *baseClient) conn() (*pool.Conn, error) { func (c *baseClient) conn() (*pool.Conn, bool, error) {
cn, err := c.connPool.Get() cn, isNew, err := c.connPool.Get()
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
if !cn.Inited { if !cn.Inited {
if err := c.initConn(cn); err != nil { if err := c.initConn(cn); err != nil {
_ = c.connPool.Remove(cn, err) _ = c.connPool.Remove(cn, err)
return nil, err return nil, false, err
} }
} }
return cn, err return cn, isNew, nil
} }
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
@ -84,7 +84,7 @@ func (c *baseClient) Process(cmd Cmder) error {
cmd.reset() cmd.reset()
} }
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
cmd.setErr(err) cmd.setErr(err)
return err return err
@ -197,7 +197,7 @@ func (c *Client) pipelineExec(cmds []Cmder) error {
var retErr error var retErr error
failedCmds := cmds failedCmds := cmds
for i := 0; i <= c.opt.MaxRetries; i++ { for i := 0; i <= c.opt.MaxRetries; i++ {
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
setCmdsErr(failedCmds, err) setCmdsErr(failedCmds, err)
return err return err

View File

@ -144,7 +144,7 @@ var _ = Describe("Client", func() {
}) })
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.NetConn = &badConn{}
@ -156,7 +156,7 @@ var _ = Describe("Client", func() {
}) })
It("should update conn.UsedAt on read/write", func() { It("should update conn.UsedAt on read/write", func() {
cn, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero()) Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt createdAt := cn.UsedAt
@ -168,7 +168,7 @@ var _ = Describe("Client", func() {
err = client.Ping().Err() err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get() cn, _, err = client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil()) Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt.After(createdAt)).To(BeTrue()) Expect(cn.UsedAt.After(createdAt)).To(BeTrue())

View File

@ -318,7 +318,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) error {
for name, cmds := range cmdsMap { for name, cmds := range cmdsMap {
client := c.shards[name].Client client := c.shards[name].Client
cn, err := client.conn() cn, _, err := client.conn()
if err != nil { if err != nil {
setCmdsErr(cmds, err) setCmdsErr(cmds, err)
if retErr == nil { if retErr == nil {

2
tx.go
View File

@ -139,7 +139,7 @@ func (c *Tx) MultiExec(fn func() error) ([]Cmder, error) {
// Strip MULTI and EXEC commands. // Strip MULTI and EXEC commands.
retCmds := cmds[1 : len(cmds)-1] retCmds := cmds[1 : len(cmds)-1]
cn, err := c.conn() cn, _, err := c.conn()
if err != nil { if err != nil {
setCmdsErr(retCmds, err) setCmdsErr(retCmds, err)
return retCmds, err return retCmds, err

View File

@ -126,7 +126,7 @@ var _ = Describe("Tx", func() {
It("should recover from bad connection", func() { It("should recover from bad connection", func() {
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.NetConn = &badConn{}
@ -153,7 +153,7 @@ var _ = Describe("Tx", func() {
It("should recover from bad connection when there are no commands", func() { It("should recover from bad connection when there are no commands", func() {
// Put bad connection in the pool. // Put bad connection in the pool.
cn, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{} cn.NetConn = &badConn{}