Merge pull request #549 from go-redis/fix/pubsub-out-of-pool

PubSub conns don't share connection pool limit
This commit is contained in:
Vladimir Mihailenco 2017-04-17 17:07:07 +03:00 committed by GitHub
commit 4fdc3bb9f9
14 changed files with 205 additions and 160 deletions

View File

@ -1,6 +1,7 @@
package redis
import (
"net"
"time"
"github.com/go-redis/redis/internal/pool"
@ -10,8 +11,8 @@ func (c *baseClient) Pool() pool.Pooler {
return c.connPool
}
func (c *PubSub) Pool() pool.Pooler {
return c.base.connPool
func (c *PubSub) SetNetConn(netConn net.Conn) {
c.cn = pool.NewConn(netConn)
}
func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error) {

View File

@ -1,7 +1,6 @@
package pool_test
import (
"errors"
"testing"
"time"
@ -40,7 +39,6 @@ func BenchmarkPoolGetPut1000Conns(b *testing.B) {
func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
connPool := pool.NewConnPool(dummyDialer, poolSize, time.Second, time.Hour, time.Hour)
removeReason := errors.New("benchmark")
b.ResetTimer()
@ -50,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
if err != nil {
b.Fatal(err)
}
if err := connPool.Remove(cn, removeReason); err != nil {
if err := connPool.Remove(cn); err != nil {
b.Fatal(err)
}
}

View File

@ -2,7 +2,6 @@ package pool
import (
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
@ -11,11 +10,8 @@ import (
"github.com/go-redis/redis/internal"
)
var (
ErrClosed = errors.New("redis: client is closed")
ErrPoolTimeout = errors.New("redis: connection pool timeout")
errConnStale = errors.New("connection is stale")
)
var ErrClosed = errors.New("redis: client is closed")
var ErrPoolTimeout = errors.New("redis: connection pool timeout")
var timers = sync.Pool{
New: func() interface{} {
@ -36,12 +32,17 @@ type Stats struct {
}
type Pooler interface {
NewConn() (*Conn, error)
CloseConn(*Conn) error
Get() (*Conn, bool, error)
Put(*Conn) error
Remove(*Conn, error) error
Remove(*Conn) error
Len() int
FreeLen() int
Stats() *Stats
Close() error
}
@ -87,11 +88,21 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout, idleCheckF
}
func (p *ConnPool) NewConn() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
netConn, err := p.dial()
if err != nil {
return nil, err
}
return NewConn(netConn), nil
cn := NewConn(netConn)
p.connsMu.Lock()
p.conns = append(p.conns, cn)
p.connsMu.Unlock()
return cn, nil
}
func (p *ConnPool) PopFree() *Conn {
@ -164,7 +175,7 @@ func (p *ConnPool) Get() (*Conn, bool, error) {
}
if cn.IsStale(p.idleTimeout) {
p.remove(cn, errConnStale)
p.CloseConn(cn)
continue
}
@ -178,18 +189,13 @@ func (p *ConnPool) Get() (*Conn, bool, error) {
return nil, false, err
}
p.connsMu.Lock()
p.conns = append(p.conns, newcn)
p.connsMu.Unlock()
return newcn, true, nil
}
func (p *ConnPool) Put(cn *Conn) error {
if data := cn.Rd.PeekBuffered(); data != nil {
err := fmt.Errorf("connection has unread data: %q", data)
internal.Logf(err.Error())
return p.Remove(cn, err)
internal.Logf("connection has unread data: %q", data)
return p.Remove(cn)
}
p.freeConnsMu.Lock()
p.freeConns = append(p.freeConns, cn)
@ -198,15 +204,13 @@ func (p *ConnPool) Put(cn *Conn) error {
return nil
}
func (p *ConnPool) Remove(cn *Conn, reason error) error {
p.remove(cn, reason)
func (p *ConnPool) Remove(cn *Conn) error {
_ = p.CloseConn(cn)
<-p.queue
return nil
}
func (p *ConnPool) remove(cn *Conn, reason error) {
_ = p.closeConn(cn, reason)
func (p *ConnPool) CloseConn(cn *Conn) error {
p.connsMu.Lock()
for i, c := range p.conns {
if c == cn {
@ -215,6 +219,15 @@ func (p *ConnPool) remove(cn *Conn, reason error) {
}
}
p.connsMu.Unlock()
return p.closeConn(cn)
}
func (p *ConnPool) closeConn(cn *Conn) error {
if p.OnClose != nil {
_ = p.OnClose(cn)
}
return cn.Close()
}
// Len returns total number of connections.
@ -258,7 +271,7 @@ func (p *ConnPool) Close() error {
if cn == nil {
continue
}
if err := p.closeConn(cn, ErrClosed); err != nil && firstErr == nil {
if err := p.closeConn(cn); err != nil && firstErr == nil {
firstErr = err
}
}
@ -272,13 +285,6 @@ func (p *ConnPool) Close() error {
return firstErr
}
func (p *ConnPool) closeConn(cn *Conn, reason error) error {
if p.OnClose != nil {
_ = p.OnClose(cn)
}
return cn.Close()
}
func (p *ConnPool) reapStaleConn() bool {
if len(p.freeConns) == 0 {
return false
@ -289,7 +295,7 @@ func (p *ConnPool) reapStaleConn() bool {
return false
}
p.remove(cn, errConnStale)
p.CloseConn(cn)
p.freeConns = append(p.freeConns[:0], p.freeConns[1:]...)
return true

View File

@ -12,6 +12,14 @@ func NewSingleConnPool(cn *Conn) *SingleConnPool {
}
}
func (p *SingleConnPool) NewConn() (*Conn, error) {
panic("not implemented")
}
func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented")
}
func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.cn, false, nil
}
@ -23,7 +31,7 @@ func (p *SingleConnPool) Put(cn *Conn) error {
return nil
}
func (p *SingleConnPool) Remove(cn *Conn, _ error) error {
func (p *SingleConnPool) Remove(cn *Conn) error {
if p.cn != cn {
panic("p.cn != cn")
}

View File

@ -1,9 +1,6 @@
package pool
import (
"errors"
"sync"
)
import "sync"
type StickyConnPool struct {
pool *ConnPool
@ -23,6 +20,14 @@ func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool {
}
}
func (p *StickyConnPool) NewConn() (*Conn, error) {
panic("not implemented")
}
func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented")
}
func (p *StickyConnPool) Get() (*Conn, bool, error) {
p.mu.Lock()
defer p.mu.Unlock()
@ -58,20 +63,20 @@ func (p *StickyConnPool) Put(cn *Conn) error {
return nil
}
func (p *StickyConnPool) removeUpstream(reason error) error {
err := p.pool.Remove(p.cn, reason)
func (p *StickyConnPool) removeUpstream() error {
err := p.pool.Remove(p.cn)
p.cn = nil
return err
}
func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
func (p *StickyConnPool) Remove(cn *Conn) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return nil
}
return p.removeUpstream(reason)
return p.removeUpstream()
}
func (p *StickyConnPool) Len() int {
@ -111,8 +116,7 @@ func (p *StickyConnPool) Close() error {
if p.reusable {
err = p.putUpstream()
} else {
reason := errors.New("redis: unreusable sticky connection")
err = p.removeUpstream(reason)
err = p.removeUpstream()
}
}
return err

View File

@ -1,7 +1,6 @@
package pool_test
import (
"errors"
"testing"
"time"
@ -59,7 +58,7 @@ var _ = Describe("ConnPool", func() {
// ok
}
err = connPool.Remove(cn, errors.New("test"))
err = connPool.Remove(cn)
Expect(err).NotTo(HaveOccurred())
// Check that Ping is unblocked.
@ -169,7 +168,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(4))
Expect(connPool.FreeLen()).To(Equal(0))
err = connPool.Remove(cn, errors.New("test"))
err = connPool.Remove(cn)
Expect(err).NotTo(HaveOccurred())
Expect(connPool.Len()).To(Equal(3))
@ -219,7 +218,7 @@ var _ = Describe("race", func() {
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred())
Expect(connPool.Remove(cn)).NotTo(HaveOccurred())
}
}
})

View File

@ -47,7 +47,7 @@ type Options struct {
WriteTimeout time.Duration
// Maximum number of socket connections.
// Default is 100 connections.
// Default is 10 connections.
PoolSize int
// Amount of time client waits for connection if all connections
// are busy before returning an error.
@ -84,7 +84,7 @@ func (opt *Options) init() {
}
}
if opt.PoolSize == 0 {
opt.PoolSize = 100
opt.PoolSize = 10
}
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second

View File

@ -77,18 +77,6 @@ var _ = Describe("pool", func() {
Expect(pool.Len()).To(Equal(pool.FreeLen()))
})
It("respects max size on pubsub", func() {
connPool := client.Pool()
perform(1000, func(id int) {
pubsub := client.Subscribe("test")
Expect(pubsub.Close()).NotTo(HaveOccurred())
})
Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(connPool.Len()).To(BeNumerically("<=", 10))
})
It("removes broken connections", func() {
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())

165
pubsub.go
View File

@ -3,6 +3,7 @@ package redis
import (
"fmt"
"net"
"sync"
"time"
"github.com/go-redis/redis/internal"
@ -14,25 +15,82 @@ import (
// multiple goroutines.
type PubSub struct {
base baseClient
cmd *Cmd
mu sync.Mutex
cn *pool.Conn
closed bool
cmd *Cmd
subMu sync.Mutex
channels []string
patterns []string
}
func (c *PubSub) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c.base.conn()
func (c *PubSub) conn() (*pool.Conn, error) {
cn, isNew, err := c._conn()
if err != nil {
return nil, err
}
if isNew {
if err := c.resubscribe(); err != nil {
internal.Logf("resubscribe failed: %s", err)
}
}
return cn, nil
}
func (c *PubSub) resubscribe() error {
c.subMu.Lock()
channels := c.channels
patterns := c.patterns
c.subMu.Unlock()
var firstErr error
if len(channels) > 0 {
if err := c.subscribe("subscribe", channels...); err != nil && firstErr == nil {
firstErr = err
}
}
if len(patterns) > 0 {
if err := c.subscribe("psubscribe", patterns...); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
func (c *PubSub) _conn() (*pool.Conn, bool, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil, false, pool.ErrClosed
}
if c.cn != nil {
return c.cn, false, nil
}
cn, err := c.base.connPool.NewConn()
if err != nil {
return nil, false, err
}
if isNew {
c.resubscribe()
}
return cn, isNew, nil
c.cn = cn
return cn, true, nil
}
func (c *PubSub) putConn(cn *pool.Conn, err error) {
c.base.putConn(cn, err, true)
if internal.IsBadConn(err, true) {
c.mu.Lock()
if c.cn == cn {
_ = c.closeConn()
}
c.mu.Unlock()
}
}
func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
@ -43,11 +101,15 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
}
cmd := NewSliceCmd(args...)
cn, _, err := c.conn()
cn, isNew, err := c._conn()
if err != nil {
return err
}
if isNew {
return c.resubscribe()
}
cn.SetWriteTimeout(c.base.opt.WriteTimeout)
err = writeCmd(cn, cmd)
c.putConn(cn, err)
@ -56,46 +118,68 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
// Subscribes the client to the specified channels.
func (c *PubSub) Subscribe(channels ...string) error {
err := c.subscribe("SUBSCRIBE", channels...)
c.subMu.Lock()
c.channels = appendIfNotExists(c.channels, channels...)
return err
c.subMu.Unlock()
return c.subscribe("subscribe", channels...)
}
// Subscribes the client to the given patterns.
func (c *PubSub) PSubscribe(patterns ...string) error {
err := c.subscribe("PSUBSCRIBE", patterns...)
c.subMu.Lock()
c.patterns = appendIfNotExists(c.patterns, patterns...)
return err
c.subMu.Unlock()
return c.subscribe("psubscribe", patterns...)
}
// Unsubscribes the client from the given channels, or from all of
// them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error {
err := c.subscribe("UNSUBSCRIBE", channels...)
c.subMu.Lock()
c.channels = remove(c.channels, channels...)
return err
c.subMu.Unlock()
return c.subscribe("unsubscribe", channels...)
}
// Unsubscribes the client from the given patterns, or from all of
// them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error {
err := c.subscribe("PUNSUBSCRIBE", patterns...)
c.subMu.Lock()
c.patterns = remove(c.patterns, patterns...)
return err
c.subMu.Unlock()
return c.subscribe("punsubscribe", patterns...)
}
func (c *PubSub) Close() error {
return c.base.Close()
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return pool.ErrClosed
}
c.closed = true
if c.cn != nil {
_ = c.closeConn()
}
return nil
}
func (c *PubSub) closeConn() error {
err := c.base.connPool.CloseConn(c.cn)
c.cn = nil
return err
}
func (c *PubSub) Ping(payload ...string) error {
args := []interface{}{"PING"}
args := []interface{}{"ping"}
if len(payload) == 1 {
args = append(args, payload[0])
}
cmd := NewCmd(args...)
cn, _, err := c.conn()
cn, err := c.conn()
if err != nil {
return err
}
@ -188,7 +272,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
c.cmd = NewCmd()
}
cn, _, err := c.conn()
cn, err := c.conn()
if err != nil {
return nil, err
}
@ -259,19 +343,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) {
}
}
func (c *PubSub) resubscribe() {
if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil {
internal.Logf("Subscribe failed: %s", err)
}
}
if len(c.patterns) > 0 {
if err := c.PSubscribe(c.patterns...); err != nil {
internal.Logf("PSubscribe failed: %s", err)
}
}
}
// Channel returns a channel for concurrently receiving messages.
// The channel is closed with PubSub.
func (c *PubSub) Channel() <-chan *Message {
@ -292,6 +363,19 @@ func (c *PubSub) Channel() <-chan *Message {
return ch
}
func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}
func remove(ss []string, es ...string) []string {
if len(es) == 0 {
return ss[:0]
@ -306,16 +390,3 @@ func remove(ss []string, es ...string) []string {
}
return ss
}
func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}

View File

@ -274,18 +274,15 @@ var _ = Describe("PubSub", func() {
Eventually(done).Should(Receive())
stats := client.PoolStats()
Expect(stats.Requests).To(Equal(uint32(3)))
Expect(stats.Requests).To(Equal(uint32(2)))
Expect(stats.Hits).To(Equal(uint32(1)))
})
expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{
pubsub.SetNetConn(&badConn{
readErr: io.EOF,
writeErr: io.EOF,
})
pubsub.Pool().Put(cn)
done := make(chan bool, 1)
go func() {
@ -305,10 +302,6 @@ var _ = Describe("PubSub", func() {
Expect(msg.Payload).To(Equal("hello"))
Eventually(done).Should(Receive())
stats := client.PoolStats()
Expect(stats.Requests).To(Equal(uint32(4)))
Expect(stats.Hits).To(Equal(uint32(1)))
}
It("Subscribe should reconnect on ReceiveMessage error", func() {

View File

@ -136,35 +136,6 @@ var _ = Describe("races", func() {
})
})
It("should PubSub", func() {
connPool := client.Pool()
perform(C, func(id int) {
for i := 0; i < N; i++ {
pubsub := client.Subscribe(fmt.Sprintf("mychannel%d", id))
go func() {
defer GinkgoRecover()
time.Sleep(time.Millisecond)
err := pubsub.Close()
Expect(err).NotTo(HaveOccurred())
}()
_, err := pubsub.ReceiveMessage()
Expect(err.Error()).To(ContainSubstring("closed"))
val := "echo" + strconv.Itoa(i)
echo, err := client.Echo(val).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(val))
}
})
Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(connPool.Len()).To(BeNumerically("<=", 10))
})
It("should select db", func() {
err := client.Set("db", 1, 0).Err()
Expect(err).NotTo(HaveOccurred())

View File

@ -31,9 +31,10 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) {
if err != nil {
return nil, false, err
}
if !cn.Inited {
if err := c.initConn(cn); err != nil {
_ = c.connPool.Remove(cn, err)
_ = c.connPool.Remove(cn)
return nil, false, err
}
}
@ -42,7 +43,7 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) {
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
if internal.IsBadConn(err, allowTimeout) {
_ = c.connPool.Remove(cn, err)
_ = c.connPool.Remove(cn)
return false
}
@ -353,7 +354,7 @@ func (c *Client) pubSub() *PubSub {
return &PubSub{
base: baseClient{
opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false),
connPool: c.connPool,
},
}
}

View File

@ -95,7 +95,7 @@ var _ = Describe("Client", func() {
Expect(client.Close()).NotTo(HaveOccurred())
_, err := pubsub.Receive()
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError("redis: client is closed"))
Expect(pubsub.Close()).NotTo(HaveOccurred())
})
@ -217,6 +217,7 @@ var _ = Describe("Client", func() {
})
var _ = Describe("Client timeout", func() {
var opt *redis.Options
var client *redis.Client
AfterEach(func() {
@ -240,7 +241,13 @@ var _ = Describe("Client timeout", func() {
})
It("Subscribe timeouts", func() {
if opt.WriteTimeout == 0 {
return
}
pubsub := client.Subscribe()
defer pubsub.Close()
err := pubsub.Subscribe("_")
Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue())
@ -269,7 +276,7 @@ var _ = Describe("Client timeout", func() {
Context("read timeout", func() {
BeforeEach(func() {
opt := redisOptions()
opt = redisOptions()
opt.ReadTimeout = time.Nanosecond
opt.WriteTimeout = -1
client = redis.NewClient(opt)
@ -280,7 +287,7 @@ var _ = Describe("Client timeout", func() {
Context("write timeout", func() {
BeforeEach(func() {
opt := redisOptions()
opt = redisOptions()
opt.ReadTimeout = -1
opt.WriteTimeout = time.Nanosecond
client = redis.NewClient(opt)

View File

@ -2,7 +2,6 @@ package redis
import (
"errors"
"fmt"
"net"
"strings"
"sync"
@ -111,7 +110,7 @@ func (c *sentinelClient) PubSub() *PubSub {
return &PubSub{
base: baseClient{
opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false),
connPool: c.connPool,
},
}
}
@ -268,12 +267,11 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
break
}
if cn.RemoteAddr().String() != newMaster {
err := fmt.Errorf(
internal.Logf(
"sentinel: closing connection to the old master %s",
cn.RemoteAddr(),
)
internal.Logf(err.Error())
d.pool.Remove(cn, err)
d.pool.Remove(cn)
} else {
cnsToPut = append(cnsToPut, cn)
}