Simplify resubscribing in PubSub.

This commit is contained in:
Vladimir Mihailenco 2016-09-29 12:07:04 +00:00
parent 833b0c68df
commit e57ac63b6e
14 changed files with 90 additions and 93 deletions

View File

@ -516,7 +516,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
}
}
cn, err := node.Client.conn()
cn, _, err := node.Client.conn()
if err != nil {
setCmdsErr(cmds, err)
setRetErr(err)

View File

@ -16,7 +16,7 @@ func benchmarkPoolGetPut(b *testing.B, poolSize int) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
if err != nil {
b.Fatal(err)
}
@ -48,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
if err != nil {
b.Fatal(err)
}

View File

@ -36,7 +36,7 @@ type PoolStats struct {
}
type Pooler interface {
Get() (*Conn, error)
Get() (*Conn, bool, error)
Put(*Conn) error
Remove(*Conn, error) error
Len() int
@ -152,9 +152,9 @@ func (p *ConnPool) popFree() *Conn {
}
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (*Conn, error) {
func (p *ConnPool) Get() (*Conn, bool, error) {
if p.Closed() {
return nil, ErrClosed
return nil, false, ErrClosed
}
atomic.AddUint32(&p.stats.Requests, 1)
@ -170,7 +170,7 @@ func (p *ConnPool) Get() (*Conn, error) {
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return nil, ErrPoolTimeout
return nil, false, ErrPoolTimeout
}
p.freeConnsMu.Lock()
@ -180,7 +180,7 @@ func (p *ConnPool) Get() (*Conn, error) {
if cn != nil {
atomic.AddUint32(&p.stats.Hits, 1)
if !cn.IsStale(p.idleTimeout) {
return cn, nil
return cn, false, nil
}
_ = p.closeConn(cn, errConnStale)
}
@ -188,7 +188,7 @@ func (p *ConnPool) Get() (*Conn, error) {
newcn, err := p.NewConn()
if err != nil {
<-p.queue
return nil, err
return nil, false, err
}
p.connsMu.Lock()
@ -198,7 +198,7 @@ func (p *ConnPool) Get() (*Conn, error) {
p.conns = append(p.conns, newcn)
p.connsMu.Unlock()
return newcn, nil
return newcn, true, nil
}
func (p *ConnPool) Put(cn *Conn) error {

View File

@ -16,8 +16,8 @@ func (p *SingleConnPool) First() *Conn {
return p.cn
}
func (p *SingleConnPool) Get() (*Conn, error) {
return p.cn, nil
func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.cn, false, nil
}
func (p *SingleConnPool) Put(cn *Conn) error {

View File

@ -30,23 +30,23 @@ func (p *StickyConnPool) First() *Conn {
return cn
}
func (p *StickyConnPool) Get() (*Conn, error) {
func (p *StickyConnPool) Get() (*Conn, bool, error) {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
return nil, ErrClosed
return nil, false, ErrClosed
}
if p.cn != nil {
return p.cn, nil
return p.cn, false, nil
}
cn, err := p.pool.Get()
cn, _, err := p.pool.Get()
if err != nil {
return nil, err
return nil, false, err
}
p.cn = cn
return cn, nil
return cn, true, nil
}
func (p *StickyConnPool) put() (err error) {

View File

@ -26,7 +26,7 @@ var _ = Describe("ConnPool", func() {
It("rate limits dial", func() {
var rateErr error
for i := 0; i < 1000; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
if err != nil {
rateErr = err
break
@ -40,13 +40,13 @@ var _ = Describe("ConnPool", func() {
It("should unblock client when conn is removed", func() {
// Reserve one connection.
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
// Reserve all other connections.
var cns []*pool.Conn
for i := 0; i < 9; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
@ -57,7 +57,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover()
started <- true
_, err := connPool.Get()
_, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
done <- true
@ -113,7 +113,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections
idleConns = nil
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
cn.UsedAt = time.Now().Add(-2 * idleTimeout)
conns = append(conns, cn)
@ -122,7 +122,7 @@ var _ = Describe("conns reaper", func() {
// add fresh connections
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn)
}
@ -167,7 +167,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ {
var freeCns []*pool.Conn
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn)
@ -176,7 +176,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3))
Expect(connPool.FreeLen()).To(Equal(0))
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
conns = append(conns, cn)
@ -224,7 +224,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
@ -232,7 +232,7 @@ var _ = Describe("race", func() {
}
}, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred())
@ -248,7 +248,7 @@ var _ = Describe("race", func() {
perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred())

View File

@ -91,7 +91,7 @@ var _ = Describe("pool", func() {
})
It("should remove broken connections", func() {
cn, err := client.Pool().Get()
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())

View File

@ -18,12 +18,25 @@ type PubSub struct {
channels []string
patterns []string
}
nsub int // number of active subscriptions
func (c *PubSub) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c.base.conn()
if err != nil {
return nil, false, err
}
if isNew {
c.resubscribe()
}
return cn, isNew, nil
}
func (c *PubSub) putConn(cn *pool.Conn, err error) {
c.base.putConn(cn, err, true)
}
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, err := c.base.conn()
cn, _, err := c.conn()
if err != nil {
return err
}
@ -44,7 +57,6 @@ func (c *PubSub) Subscribe(channels ...string) error {
err := c.subscribe("SUBSCRIBE", channels...)
if err == nil {
c.channels = appendIfNotExists(c.channels, channels...)
c.nsub += len(channels)
}
return err
}
@ -54,43 +66,10 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
err := c.subscribe("PSUBSCRIBE", patterns...)
if err == nil {
c.patterns = appendIfNotExists(c.patterns, patterns...)
c.nsub += len(patterns)
}
return err
}
func remove(ss []string, es ...string) []string {
if len(es) == 0 {
return ss[:0]
}
for _, e := range es {
for i, s := range ss {
if s == e {
ss = append(ss[:i], ss[i+1:]...)
break
}
}
}
return ss
}
func appendIfNotExists(ss []string, es ...string) []string {
for _, e := range es {
found := false
for _, s := range ss {
if s == e {
found = true
break
}
}
if !found {
ss = append(ss, e)
}
}
return ss
}
// Unsubscribes the client from the given channels, or from all of
// them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error {
@ -116,7 +95,7 @@ func (c *PubSub) Close() error {
}
func (c *PubSub) Ping(payload string) error {
cn, err := c.base.conn()
cn, _, err := c.conn()
if err != nil {
return err
}
@ -198,11 +177,7 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
// is not received in time. This is low-level API and most clients
// should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
if c.nsub == 0 {
c.resubscribe()
}
cn, err := c.base.conn()
cn, _, err := c.conn()
if err != nil {
return nil, err
}
@ -274,12 +249,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) {
}
}
func (c *PubSub) putConn(cn *pool.Conn, err error) {
if !c.base.putConn(cn, err, true) {
c.nsub = 0
}
}
func (c *PubSub) resubscribe() {
if c.base.closed() {
return
@ -295,3 +264,31 @@ func (c *PubSub) resubscribe() {
}
}
}
func remove(ss []string, es ...string) []string {
if len(es) == 0 {
return ss[:0]
}
for _, e := range es {
for i, s := range ss {
if s == e {
ss = append(ss[:i], ss[i+1:]...)
break
}
}
}
return ss
}
func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}

View File

@ -288,7 +288,7 @@ var _ = Describe("PubSub", func() {
})
expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, err := pubsub.Pool().Get()
cn1, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn1.NetConn = &badConn{
readErr: io.EOF,

View File

@ -27,18 +27,18 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB)
}
func (c *baseClient) conn() (*pool.Conn, error) {
cn, err := c.connPool.Get()
func (c *baseClient) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c.connPool.Get()
if err != nil {
return nil, err
return nil, false, err
}
if !cn.Inited {
if err := c.initConn(cn); err != nil {
_ = c.connPool.Remove(cn, err)
return nil, err
return nil, false, err
}
}
return cn, err
return cn, isNew, nil
}
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
@ -84,7 +84,7 @@ func (c *baseClient) Process(cmd Cmder) error {
cmd.reset()
}
cn, err := c.conn()
cn, _, err := c.conn()
if err != nil {
cmd.setErr(err)
return err
@ -197,7 +197,7 @@ func (c *Client) pipelineExec(cmds []Cmder) error {
var retErr error
failedCmds := cmds
for i := 0; i <= c.opt.MaxRetries; i++ {
cn, err := c.conn()
cn, _, err := c.conn()
if err != nil {
setCmdsErr(failedCmds, err)
return err

View File

@ -144,7 +144,7 @@ var _ = Describe("Client", func() {
})
// Put bad connection in the pool.
cn, err := client.Pool().Get()
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
@ -156,7 +156,7 @@ var _ = Describe("Client", func() {
})
It("should update conn.UsedAt on read/write", func() {
cn, err := client.Pool().Get()
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn.UsedAt).NotTo(BeZero())
createdAt := cn.UsedAt
@ -168,7 +168,7 @@ var _ = Describe("Client", func() {
err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred())
cn, err = client.Pool().Get()
cn, _, err = client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
Expect(cn.UsedAt.After(createdAt)).To(BeTrue())

View File

@ -318,7 +318,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) error {
for name, cmds := range cmdsMap {
client := c.shards[name].Client
cn, err := client.conn()
cn, _, err := client.conn()
if err != nil {
setCmdsErr(cmds, err)
if retErr == nil {

2
tx.go
View File

@ -139,7 +139,7 @@ func (c *Tx) MultiExec(fn func() error) ([]Cmder, error) {
// Strip MULTI and EXEC commands.
retCmds := cmds[1 : len(cmds)-1]
cn, err := c.conn()
cn, _, err := c.conn()
if err != nil {
setCmdsErr(retCmds, err)
return retCmds, err

View File

@ -126,7 +126,7 @@ var _ = Describe("Tx", func() {
It("should recover from bad connection", func() {
// Put bad connection in the pool.
cn, err := client.Pool().Get()
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
@ -153,7 +153,7 @@ var _ = Describe("Tx", func() {
It("should recover from bad connection when there are no commands", func() {
// Put bad connection in the pool.
cn, err := client.Pool().Get()
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}