mirror of https://github.com/go-redis/redis.git
Merge pull request #380 from go-redis/fix/pubsub-resubscribe
Simplify resubscribing in PubSub.
This commit is contained in:
commit
a7d1d0b9ac
|
@ -516,7 +516,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
|
|||
}
|
||||
}
|
||||
|
||||
cn, err := node.Client.conn()
|
||||
cn, _, err := node.Client.conn()
|
||||
if err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
setRetErr(err)
|
||||
|
|
|
@ -16,7 +16,7 @@ func benchmarkPoolGetPut(b *testing.B, poolSize int) {
|
|||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
|
|||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ type PoolStats struct {
|
|||
}
|
||||
|
||||
type Pooler interface {
|
||||
Get() (*Conn, error)
|
||||
Get() (*Conn, bool, error)
|
||||
Put(*Conn) error
|
||||
Remove(*Conn, error) error
|
||||
Len() int
|
||||
|
@ -152,9 +152,9 @@ func (p *ConnPool) popFree() *Conn {
|
|||
}
|
||||
|
||||
// Get returns existed connection from the pool or creates a new one.
|
||||
func (p *ConnPool) Get() (*Conn, error) {
|
||||
func (p *ConnPool) Get() (*Conn, bool, error) {
|
||||
if p.Closed() {
|
||||
return nil, ErrClosed
|
||||
return nil, false, ErrClosed
|
||||
}
|
||||
|
||||
atomic.AddUint32(&p.stats.Requests, 1)
|
||||
|
@ -170,7 +170,7 @@ func (p *ConnPool) Get() (*Conn, error) {
|
|||
case <-timer.C:
|
||||
timers.Put(timer)
|
||||
atomic.AddUint32(&p.stats.Timeouts, 1)
|
||||
return nil, ErrPoolTimeout
|
||||
return nil, false, ErrPoolTimeout
|
||||
}
|
||||
|
||||
p.freeConnsMu.Lock()
|
||||
|
@ -180,7 +180,7 @@ func (p *ConnPool) Get() (*Conn, error) {
|
|||
if cn != nil {
|
||||
atomic.AddUint32(&p.stats.Hits, 1)
|
||||
if !cn.IsStale(p.idleTimeout) {
|
||||
return cn, nil
|
||||
return cn, false, nil
|
||||
}
|
||||
_ = p.closeConn(cn, errConnStale)
|
||||
}
|
||||
|
@ -188,7 +188,7 @@ func (p *ConnPool) Get() (*Conn, error) {
|
|||
newcn, err := p.NewConn()
|
||||
if err != nil {
|
||||
<-p.queue
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
p.connsMu.Lock()
|
||||
|
@ -198,7 +198,7 @@ func (p *ConnPool) Get() (*Conn, error) {
|
|||
p.conns = append(p.conns, newcn)
|
||||
p.connsMu.Unlock()
|
||||
|
||||
return newcn, nil
|
||||
return newcn, true, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) Put(cn *Conn) error {
|
||||
|
|
|
@ -16,8 +16,8 @@ func (p *SingleConnPool) First() *Conn {
|
|||
return p.cn
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Get() (*Conn, error) {
|
||||
return p.cn, nil
|
||||
func (p *SingleConnPool) Get() (*Conn, bool, error) {
|
||||
return p.cn, false, nil
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Put(cn *Conn) error {
|
||||
|
|
|
@ -30,23 +30,23 @@ func (p *StickyConnPool) First() *Conn {
|
|||
return cn
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) Get() (*Conn, error) {
|
||||
func (p *StickyConnPool) Get() (*Conn, bool, error) {
|
||||
defer p.mx.Unlock()
|
||||
p.mx.Lock()
|
||||
|
||||
if p.closed {
|
||||
return nil, ErrClosed
|
||||
return nil, false, ErrClosed
|
||||
}
|
||||
if p.cn != nil {
|
||||
return p.cn, nil
|
||||
return p.cn, false, nil
|
||||
}
|
||||
|
||||
cn, err := p.pool.Get()
|
||||
cn, _, err := p.pool.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
p.cn = cn
|
||||
return cn, nil
|
||||
return cn, true, nil
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) put() (err error) {
|
||||
|
|
|
@ -26,7 +26,7 @@ var _ = Describe("ConnPool", func() {
|
|||
It("rate limits dial", func() {
|
||||
var rateErr error
|
||||
for i := 0; i < 1000; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
if err != nil {
|
||||
rateErr = err
|
||||
break
|
||||
|
@ -40,13 +40,13 @@ var _ = Describe("ConnPool", func() {
|
|||
|
||||
It("should unblock client when conn is removed", func() {
|
||||
// Reserve one connection.
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Reserve all other connections.
|
||||
var cns []*pool.Conn
|
||||
for i := 0; i < 9; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cns = append(cns, cn)
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ var _ = Describe("ConnPool", func() {
|
|||
defer GinkgoRecover()
|
||||
|
||||
started <- true
|
||||
_, err := connPool.Get()
|
||||
_, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
done <- true
|
||||
|
||||
|
@ -113,7 +113,7 @@ var _ = Describe("conns reaper", func() {
|
|||
// add stale connections
|
||||
idleConns = nil
|
||||
for i := 0; i < 3; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cn.UsedAt = time.Now().Add(-2 * idleTimeout)
|
||||
conns = append(conns, cn)
|
||||
|
@ -122,7 +122,7 @@ var _ = Describe("conns reaper", func() {
|
|||
|
||||
// add fresh connections
|
||||
for i := 0; i < 3; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
conns = append(conns, cn)
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ var _ = Describe("conns reaper", func() {
|
|||
for j := 0; j < 3; j++ {
|
||||
var freeCns []*pool.Conn
|
||||
for i := 0; i < 3; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn).NotTo(BeNil())
|
||||
freeCns = append(freeCns, cn)
|
||||
|
@ -176,7 +176,7 @@ var _ = Describe("conns reaper", func() {
|
|||
Expect(connPool.Len()).To(Equal(3))
|
||||
Expect(connPool.FreeLen()).To(Equal(0))
|
||||
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn).NotTo(BeNil())
|
||||
conns = append(conns, cn)
|
||||
|
@ -224,7 +224,7 @@ var _ = Describe("race", func() {
|
|||
|
||||
perform(C, func(id int) {
|
||||
for i := 0; i < N; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if err == nil {
|
||||
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
|
||||
|
@ -232,7 +232,7 @@ var _ = Describe("race", func() {
|
|||
}
|
||||
}, func(id int) {
|
||||
for i := 0; i < N; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if err == nil {
|
||||
Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred())
|
||||
|
@ -248,7 +248,7 @@ var _ = Describe("race", func() {
|
|||
|
||||
perform(C, func(id int) {
|
||||
for i := 0; i < N; i++ {
|
||||
cn, err := connPool.Get()
|
||||
cn, _, err := connPool.Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if err == nil {
|
||||
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
|
||||
|
|
|
@ -91,7 +91,7 @@ var _ = Describe("pool", func() {
|
|||
})
|
||||
|
||||
It("should remove broken connections", func() {
|
||||
cn, err := client.Pool().Get()
|
||||
cn, _, err := client.Pool().Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cn.NetConn = &badConn{}
|
||||
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
|
||||
|
|
93
pubsub.go
93
pubsub.go
|
@ -18,12 +18,25 @@ type PubSub struct {
|
|||
|
||||
channels []string
|
||||
patterns []string
|
||||
}
|
||||
|
||||
nsub int // number of active subscriptions
|
||||
func (c *PubSub) conn() (*pool.Conn, bool, error) {
|
||||
cn, isNew, err := c.base.conn()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if isNew {
|
||||
c.resubscribe()
|
||||
}
|
||||
return cn, isNew, nil
|
||||
}
|
||||
|
||||
func (c *PubSub) putConn(cn *pool.Conn, err error) {
|
||||
c.base.putConn(cn, err, true)
|
||||
}
|
||||
|
||||
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
|
||||
cn, err := c.base.conn()
|
||||
cn, _, err := c.conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -44,7 +57,6 @@ func (c *PubSub) Subscribe(channels ...string) error {
|
|||
err := c.subscribe("SUBSCRIBE", channels...)
|
||||
if err == nil {
|
||||
c.channels = appendIfNotExists(c.channels, channels...)
|
||||
c.nsub += len(channels)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -54,43 +66,10 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
|
|||
err := c.subscribe("PSUBSCRIBE", patterns...)
|
||||
if err == nil {
|
||||
c.patterns = appendIfNotExists(c.patterns, patterns...)
|
||||
c.nsub += len(patterns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func remove(ss []string, es ...string) []string {
|
||||
if len(es) == 0 {
|
||||
return ss[:0]
|
||||
}
|
||||
for _, e := range es {
|
||||
for i, s := range ss {
|
||||
if s == e {
|
||||
ss = append(ss[:i], ss[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
func appendIfNotExists(ss []string, es ...string) []string {
|
||||
for _, e := range es {
|
||||
found := false
|
||||
for _, s := range ss {
|
||||
if s == e {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
ss = append(ss, e)
|
||||
}
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
// Unsubscribes the client from the given channels, or from all of
|
||||
// them if none is given.
|
||||
func (c *PubSub) Unsubscribe(channels ...string) error {
|
||||
|
@ -116,7 +95,7 @@ func (c *PubSub) Close() error {
|
|||
}
|
||||
|
||||
func (c *PubSub) Ping(payload string) error {
|
||||
cn, err := c.base.conn()
|
||||
cn, _, err := c.conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -198,11 +177,7 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
|
|||
// is not received in time. This is low-level API and most clients
|
||||
// should use ReceiveMessage.
|
||||
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
|
||||
if c.nsub == 0 {
|
||||
c.resubscribe()
|
||||
}
|
||||
|
||||
cn, err := c.base.conn()
|
||||
cn, _, err := c.conn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -274,12 +249,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *PubSub) putConn(cn *pool.Conn, err error) {
|
||||
if !c.base.putConn(cn, err, true) {
|
||||
c.nsub = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PubSub) resubscribe() {
|
||||
if c.base.closed() {
|
||||
return
|
||||
|
@ -295,3 +264,31 @@ func (c *PubSub) resubscribe() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func remove(ss []string, es ...string) []string {
|
||||
if len(es) == 0 {
|
||||
return ss[:0]
|
||||
}
|
||||
for _, e := range es {
|
||||
for i, s := range ss {
|
||||
if s == e {
|
||||
ss = append(ss[:i], ss[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
func appendIfNotExists(ss []string, es ...string) []string {
|
||||
loop:
|
||||
for _, e := range es {
|
||||
for _, s := range ss {
|
||||
if s == e {
|
||||
continue loop
|
||||
}
|
||||
}
|
||||
ss = append(ss, e)
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
|
|
@ -288,7 +288,7 @@ var _ = Describe("PubSub", func() {
|
|||
})
|
||||
|
||||
expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
|
||||
cn1, err := pubsub.Pool().Get()
|
||||
cn1, _, err := pubsub.Pool().Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cn1.NetConn = &badConn{
|
||||
readErr: io.EOF,
|
||||
|
|
14
redis.go
14
redis.go
|
@ -27,18 +27,18 @@ func (c *baseClient) String() string {
|
|||
return fmt.Sprintf("Redis<%s db:%d>", c.opt.Addr, c.opt.DB)
|
||||
}
|
||||
|
||||
func (c *baseClient) conn() (*pool.Conn, error) {
|
||||
cn, err := c.connPool.Get()
|
||||
func (c *baseClient) conn() (*pool.Conn, bool, error) {
|
||||
cn, isNew, err := c.connPool.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
if !cn.Inited {
|
||||
if err := c.initConn(cn); err != nil {
|
||||
_ = c.connPool.Remove(cn, err)
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
return cn, err
|
||||
return cn, isNew, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
|
||||
|
@ -84,7 +84,7 @@ func (c *baseClient) Process(cmd Cmder) error {
|
|||
cmd.reset()
|
||||
}
|
||||
|
||||
cn, err := c.conn()
|
||||
cn, _, err := c.conn()
|
||||
if err != nil {
|
||||
cmd.setErr(err)
|
||||
return err
|
||||
|
@ -197,7 +197,7 @@ func (c *Client) pipelineExec(cmds []Cmder) error {
|
|||
var retErr error
|
||||
failedCmds := cmds
|
||||
for i := 0; i <= c.opt.MaxRetries; i++ {
|
||||
cn, err := c.conn()
|
||||
cn, _, err := c.conn()
|
||||
if err != nil {
|
||||
setCmdsErr(failedCmds, err)
|
||||
return err
|
||||
|
|
|
@ -144,7 +144,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
// Put bad connection in the pool.
|
||||
cn, err := client.Pool().Get()
|
||||
cn, _, err := client.Pool().Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
cn.NetConn = &badConn{}
|
||||
|
@ -156,7 +156,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("should update conn.UsedAt on read/write", func() {
|
||||
cn, err := client.Pool().Get()
|
||||
cn, _, err := client.Pool().Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn.UsedAt).NotTo(BeZero())
|
||||
createdAt := cn.UsedAt
|
||||
|
@ -168,7 +168,7 @@ var _ = Describe("Client", func() {
|
|||
err = client.Ping().Err()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
cn, err = client.Pool().Get()
|
||||
cn, _, err = client.Pool().Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cn).NotTo(BeNil())
|
||||
Expect(cn.UsedAt.After(createdAt)).To(BeTrue())
|
||||
|
|
2
ring.go
2
ring.go
|
@ -318,7 +318,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) error {
|
|||
|
||||
for name, cmds := range cmdsMap {
|
||||
client := c.shards[name].Client
|
||||
cn, err := client.conn()
|
||||
cn, _, err := client.conn()
|
||||
if err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
if retErr == nil {
|
||||
|
|
2
tx.go
2
tx.go
|
@ -139,7 +139,7 @@ func (c *Tx) MultiExec(fn func() error) ([]Cmder, error) {
|
|||
// Strip MULTI and EXEC commands.
|
||||
retCmds := cmds[1 : len(cmds)-1]
|
||||
|
||||
cn, err := c.conn()
|
||||
cn, _, err := c.conn()
|
||||
if err != nil {
|
||||
setCmdsErr(retCmds, err)
|
||||
return retCmds, err
|
||||
|
|
|
@ -126,7 +126,7 @@ var _ = Describe("Tx", func() {
|
|||
|
||||
It("should recover from bad connection", func() {
|
||||
// Put bad connection in the pool.
|
||||
cn, err := client.Pool().Get()
|
||||
cn, _, err := client.Pool().Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
cn.NetConn = &badConn{}
|
||||
|
@ -153,7 +153,7 @@ var _ = Describe("Tx", func() {
|
|||
|
||||
It("should recover from bad connection when there are no commands", func() {
|
||||
// Put bad connection in the pool.
|
||||
cn, err := client.Pool().Get()
|
||||
cn, _, err := client.Pool().Get()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
cn.NetConn = &badConn{}
|
||||
|
|
Loading…
Reference in New Issue