PubSub conns don't share connection pool limit

This commit is contained in:
Vladimir Mihailenco 2017-04-17 15:43:58 +03:00
parent aeb22d6a37
commit 6499563e07
14 changed files with 180 additions and 153 deletions

View File

@ -1,6 +1,7 @@
package redis package redis
import ( import (
"net"
"time" "time"
"github.com/go-redis/redis/internal/pool" "github.com/go-redis/redis/internal/pool"
@ -10,8 +11,8 @@ func (c *baseClient) Pool() pool.Pooler {
return c.connPool return c.connPool
} }
func (c *PubSub) Pool() pool.Pooler { func (c *PubSub) SetNetConn(netConn net.Conn) {
return c.base.connPool c.cn = pool.NewConn(netConn)
} }
func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error) { func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error) {

View File

@ -1,7 +1,6 @@
package pool_test package pool_test
import ( import (
"errors"
"testing" "testing"
"time" "time"
@ -40,7 +39,6 @@ func BenchmarkPoolGetPut1000Conns(b *testing.B) {
func benchmarkPoolGetRemove(b *testing.B, poolSize int) { func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
connPool := pool.NewConnPool(dummyDialer, poolSize, time.Second, time.Hour, time.Hour) connPool := pool.NewConnPool(dummyDialer, poolSize, time.Second, time.Hour, time.Hour)
removeReason := errors.New("benchmark")
b.ResetTimer() b.ResetTimer()
@ -50,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
if err := connPool.Remove(cn, removeReason); err != nil { if err := connPool.Remove(cn); err != nil {
b.Fatal(err) b.Fatal(err)
} }
} }

View File

@ -2,7 +2,6 @@ package pool
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -11,11 +10,8 @@ import (
"github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal"
) )
var ( var ErrClosed = errors.New("redis: client is closed")
ErrClosed = errors.New("redis: client is closed") var ErrPoolTimeout = errors.New("redis: connection pool timeout")
ErrPoolTimeout = errors.New("redis: connection pool timeout")
errConnStale = errors.New("connection is stale")
)
var timers = sync.Pool{ var timers = sync.Pool{
New: func() interface{} { New: func() interface{} {
@ -36,12 +32,17 @@ type Stats struct {
} }
type Pooler interface { type Pooler interface {
NewConn() (*Conn, error)
CloseConn(*Conn) error
Get() (*Conn, bool, error) Get() (*Conn, bool, error)
Put(*Conn) error Put(*Conn) error
Remove(*Conn, error) error Remove(*Conn) error
Len() int Len() int
FreeLen() int FreeLen() int
Stats() *Stats Stats() *Stats
Close() error Close() error
} }
@ -87,11 +88,21 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout, idleCheckF
} }
func (p *ConnPool) NewConn() (*Conn, error) { func (p *ConnPool) NewConn() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
netConn, err := p.dial() netConn, err := p.dial()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(netConn), nil
cn := NewConn(netConn)
p.connsMu.Lock()
p.conns = append(p.conns, cn)
p.connsMu.Unlock()
return cn, nil
} }
func (p *ConnPool) PopFree() *Conn { func (p *ConnPool) PopFree() *Conn {
@ -164,7 +175,7 @@ func (p *ConnPool) Get() (*Conn, bool, error) {
} }
if cn.IsStale(p.idleTimeout) { if cn.IsStale(p.idleTimeout) {
p.remove(cn, errConnStale) p.CloseConn(cn)
continue continue
} }
@ -178,18 +189,13 @@ func (p *ConnPool) Get() (*Conn, bool, error) {
return nil, false, err return nil, false, err
} }
p.connsMu.Lock()
p.conns = append(p.conns, newcn)
p.connsMu.Unlock()
return newcn, true, nil return newcn, true, nil
} }
func (p *ConnPool) Put(cn *Conn) error { func (p *ConnPool) Put(cn *Conn) error {
if data := cn.Rd.PeekBuffered(); data != nil { if data := cn.Rd.PeekBuffered(); data != nil {
err := fmt.Errorf("connection has unread data: %q", data) internal.Logf("connection has unread data: %q", data)
internal.Logf(err.Error()) return p.Remove(cn)
return p.Remove(cn, err)
} }
p.freeConnsMu.Lock() p.freeConnsMu.Lock()
p.freeConns = append(p.freeConns, cn) p.freeConns = append(p.freeConns, cn)
@ -198,15 +204,13 @@ func (p *ConnPool) Put(cn *Conn) error {
return nil return nil
} }
func (p *ConnPool) Remove(cn *Conn, reason error) error { func (p *ConnPool) Remove(cn *Conn) error {
p.remove(cn, reason) _ = p.CloseConn(cn)
<-p.queue <-p.queue
return nil return nil
} }
func (p *ConnPool) remove(cn *Conn, reason error) { func (p *ConnPool) CloseConn(cn *Conn) error {
_ = p.closeConn(cn, reason)
p.connsMu.Lock() p.connsMu.Lock()
for i, c := range p.conns { for i, c := range p.conns {
if c == cn { if c == cn {
@ -215,6 +219,15 @@ func (p *ConnPool) remove(cn *Conn, reason error) {
} }
} }
p.connsMu.Unlock() p.connsMu.Unlock()
return p.closeConn(cn)
}
func (p *ConnPool) closeConn(cn *Conn) error {
if p.OnClose != nil {
_ = p.OnClose(cn)
}
return cn.Close()
} }
// Len returns total number of connections. // Len returns total number of connections.
@ -258,7 +271,7 @@ func (p *ConnPool) Close() error {
if cn == nil { if cn == nil {
continue continue
} }
if err := p.closeConn(cn, ErrClosed); err != nil && firstErr == nil { if err := p.closeConn(cn); err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
} }
@ -272,13 +285,6 @@ func (p *ConnPool) Close() error {
return firstErr return firstErr
} }
func (p *ConnPool) closeConn(cn *Conn, reason error) error {
if p.OnClose != nil {
_ = p.OnClose(cn)
}
return cn.Close()
}
func (p *ConnPool) reapStaleConn() bool { func (p *ConnPool) reapStaleConn() bool {
if len(p.freeConns) == 0 { if len(p.freeConns) == 0 {
return false return false
@ -289,7 +295,7 @@ func (p *ConnPool) reapStaleConn() bool {
return false return false
} }
p.remove(cn, errConnStale) p.CloseConn(cn)
p.freeConns = append(p.freeConns[:0], p.freeConns[1:]...) p.freeConns = append(p.freeConns[:0], p.freeConns[1:]...)
return true return true

View File

@ -12,6 +12,14 @@ func NewSingleConnPool(cn *Conn) *SingleConnPool {
} }
} }
func (p *SingleConnPool) NewConn() (*Conn, error) {
panic("not implemented")
}
func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented")
}
func (p *SingleConnPool) Get() (*Conn, bool, error) { func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.cn, false, nil return p.cn, false, nil
} }
@ -23,7 +31,7 @@ func (p *SingleConnPool) Put(cn *Conn) error {
return nil return nil
} }
func (p *SingleConnPool) Remove(cn *Conn, _ error) error { func (p *SingleConnPool) Remove(cn *Conn) error {
if p.cn != cn { if p.cn != cn {
panic("p.cn != cn") panic("p.cn != cn")
} }

View File

@ -1,9 +1,6 @@
package pool package pool
import ( import "sync"
"errors"
"sync"
)
type StickyConnPool struct { type StickyConnPool struct {
pool *ConnPool pool *ConnPool
@ -23,6 +20,14 @@ func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool {
} }
} }
func (p *StickyConnPool) NewConn() (*Conn, error) {
panic("not implemented")
}
func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented")
}
func (p *StickyConnPool) Get() (*Conn, bool, error) { func (p *StickyConnPool) Get() (*Conn, bool, error) {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
@ -58,20 +63,20 @@ func (p *StickyConnPool) Put(cn *Conn) error {
return nil return nil
} }
func (p *StickyConnPool) removeUpstream(reason error) error { func (p *StickyConnPool) removeUpstream() error {
err := p.pool.Remove(p.cn, reason) err := p.pool.Remove(p.cn)
p.cn = nil p.cn = nil
return err return err
} }
func (p *StickyConnPool) Remove(cn *Conn, reason error) error { func (p *StickyConnPool) Remove(cn *Conn) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
if p.closed { if p.closed {
return nil return nil
} }
return p.removeUpstream(reason) return p.removeUpstream()
} }
func (p *StickyConnPool) Len() int { func (p *StickyConnPool) Len() int {
@ -111,8 +116,7 @@ func (p *StickyConnPool) Close() error {
if p.reusable { if p.reusable {
err = p.putUpstream() err = p.putUpstream()
} else { } else {
reason := errors.New("redis: unreusable sticky connection") err = p.removeUpstream()
err = p.removeUpstream(reason)
} }
} }
return err return err

View File

@ -1,7 +1,6 @@
package pool_test package pool_test
import ( import (
"errors"
"testing" "testing"
"time" "time"
@ -59,7 +58,7 @@ var _ = Describe("ConnPool", func() {
// ok // ok
} }
err = connPool.Remove(cn, errors.New("test")) err = connPool.Remove(cn)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Check that Ping is unblocked. // Check that Ping is unblocked.
@ -169,7 +168,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(4)) Expect(connPool.Len()).To(Equal(4))
Expect(connPool.FreeLen()).To(Equal(0)) Expect(connPool.FreeLen()).To(Equal(0))
err = connPool.Remove(cn, errors.New("test")) err = connPool.Remove(cn)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(connPool.Len()).To(Equal(3)) Expect(connPool.Len()).To(Equal(3))
@ -219,7 +218,7 @@ var _ = Describe("race", func() {
cn, _, err := connPool.Get() cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
if err == nil { if err == nil {
Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred()) Expect(connPool.Remove(cn)).NotTo(HaveOccurred())
} }
} }
}) })

View File

@ -84,7 +84,7 @@ func (opt *Options) init() {
} }
} }
if opt.PoolSize == 0 { if opt.PoolSize == 0 {
opt.PoolSize = 100 opt.PoolSize = 10
} }
if opt.DialTimeout == 0 { if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second opt.DialTimeout = 5 * time.Second

View File

@ -77,18 +77,6 @@ var _ = Describe("pool", func() {
Expect(pool.Len()).To(Equal(pool.FreeLen())) Expect(pool.Len()).To(Equal(pool.FreeLen()))
}) })
It("respects max size on pubsub", func() {
connPool := client.Pool()
perform(1000, func(id int) {
pubsub := client.Subscribe("test")
Expect(pubsub.Close()).NotTo(HaveOccurred())
})
Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(connPool.Len()).To(BeNumerically("<=", 10))
})
It("removes broken connections", func() { It("removes broken connections", func() {
cn, _, err := client.Pool().Get() cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())

135
pubsub.go
View File

@ -3,6 +3,7 @@ package redis
import ( import (
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
"github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal"
@ -14,25 +15,72 @@ import (
// multiple goroutines. // multiple goroutines.
type PubSub struct { type PubSub struct {
base baseClient base baseClient
cmd *Cmd
mu sync.Mutex
cn *pool.Conn
closed bool
cmd *Cmd
channels []string channels []string
patterns []string patterns []string
} }
func (c *PubSub) conn() (*pool.Conn, bool, error) { func (c *PubSub) conn() (*pool.Conn, error) {
cn, isNew, err := c.base.conn() cn, isNew, err := c._conn()
if err != nil { if err != nil {
return nil, false, err return nil, err
} }
if isNew { if isNew {
c.resubscribe() c.resubscribe()
} }
return cn, isNew, nil
return cn, nil
}
func (c *PubSub) resubscribe() {
if len(c.channels) > 0 {
if err := c.subscribe("subscribe", c.channels...); err != nil {
internal.Logf("Subscribe failed: %s", err)
}
}
if len(c.patterns) > 0 {
if err := c.subscribe("psubscribe", c.patterns...); err != nil {
internal.Logf("PSubscribe failed: %s", err)
}
}
}
func (c *PubSub) _conn() (*pool.Conn, bool, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil, false, pool.ErrClosed
}
if c.cn != nil {
return c.cn, false, nil
}
cn, err := c.base.connPool.NewConn()
if err != nil {
return nil, false, err
}
c.cn = cn
return cn, true, nil
} }
func (c *PubSub) putConn(cn *pool.Conn, err error) { func (c *PubSub) putConn(cn *pool.Conn, err error) {
c.base.putConn(cn, err, true) if internal.IsBadConn(err, true) {
c.mu.Lock()
if c.cn == cn {
_ = c.closeConn()
}
c.mu.Unlock()
}
} }
func (c *PubSub) subscribe(redisCmd string, channels ...string) error { func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
@ -43,7 +91,7 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
} }
cmd := NewSliceCmd(args...) cmd := NewSliceCmd(args...)
cn, _, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
return err return err
} }
@ -56,14 +104,14 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
// Subscribes the client to the specified channels. // Subscribes the client to the specified channels.
func (c *PubSub) Subscribe(channels ...string) error { func (c *PubSub) Subscribe(channels ...string) error {
err := c.subscribe("SUBSCRIBE", channels...) err := c.subscribe("subscribe", channels...)
c.channels = appendIfNotExists(c.channels, channels...) c.channels = appendIfNotExists(c.channels, channels...)
return err return err
} }
// Subscribes the client to the given patterns. // Subscribes the client to the given patterns.
func (c *PubSub) PSubscribe(patterns ...string) error { func (c *PubSub) PSubscribe(patterns ...string) error {
err := c.subscribe("PSUBSCRIBE", patterns...) err := c.subscribe("psubscribe", patterns...)
c.patterns = appendIfNotExists(c.patterns, patterns...) c.patterns = appendIfNotExists(c.patterns, patterns...)
return err return err
} }
@ -71,7 +119,7 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
// Unsubscribes the client from the given channels, or from all of // Unsubscribes the client from the given channels, or from all of
// them if none is given. // them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error { func (c *PubSub) Unsubscribe(channels ...string) error {
err := c.subscribe("UNSUBSCRIBE", channels...) err := c.subscribe("unsubscribe", channels...)
c.channels = remove(c.channels, channels...) c.channels = remove(c.channels, channels...)
return err return err
} }
@ -79,23 +127,41 @@ func (c *PubSub) Unsubscribe(channels ...string) error {
// Unsubscribes the client from the given patterns, or from all of // Unsubscribes the client from the given patterns, or from all of
// them if none is given. // them if none is given.
func (c *PubSub) PUnsubscribe(patterns ...string) error { func (c *PubSub) PUnsubscribe(patterns ...string) error {
err := c.subscribe("PUNSUBSCRIBE", patterns...) err := c.subscribe("punsubscribe", patterns...)
c.patterns = remove(c.patterns, patterns...) c.patterns = remove(c.patterns, patterns...)
return err return err
} }
func (c *PubSub) Close() error { func (c *PubSub) Close() error {
return c.base.Close() c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return pool.ErrClosed
}
c.closed = true
if c.cn != nil {
_ = c.closeConn()
}
return nil
}
func (c *PubSub) closeConn() error {
err := c.base.connPool.CloseConn(c.cn)
c.cn = nil
return err
} }
func (c *PubSub) Ping(payload ...string) error { func (c *PubSub) Ping(payload ...string) error {
args := []interface{}{"PING"} args := []interface{}{"ping"}
if len(payload) == 1 { if len(payload) == 1 {
args = append(args, payload[0]) args = append(args, payload[0])
} }
cmd := NewCmd(args...) cmd := NewCmd(args...)
cn, _, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
return err return err
} }
@ -188,7 +254,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
c.cmd = NewCmd() c.cmd = NewCmd()
} }
cn, _, err := c.conn() cn, err := c.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -259,19 +325,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) {
} }
} }
func (c *PubSub) resubscribe() {
if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil {
internal.Logf("Subscribe failed: %s", err)
}
}
if len(c.patterns) > 0 {
if err := c.PSubscribe(c.patterns...); err != nil {
internal.Logf("PSubscribe failed: %s", err)
}
}
}
// Channel returns a channel for concurrently receiving messages. // Channel returns a channel for concurrently receiving messages.
// The channel is closed with PubSub. // The channel is closed with PubSub.
func (c *PubSub) Channel() <-chan *Message { func (c *PubSub) Channel() <-chan *Message {
@ -292,6 +345,19 @@ func (c *PubSub) Channel() <-chan *Message {
return ch return ch
} }
func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}
func remove(ss []string, es ...string) []string { func remove(ss []string, es ...string) []string {
if len(es) == 0 { if len(es) == 0 {
return ss[:0] return ss[:0]
@ -306,16 +372,3 @@ func remove(ss []string, es ...string) []string {
} }
return ss return ss
} }
func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}

View File

@ -274,18 +274,15 @@ var _ = Describe("PubSub", func() {
Eventually(done).Should(Receive()) Eventually(done).Should(Receive())
stats := client.PoolStats() stats := client.PoolStats()
Expect(stats.Requests).To(Equal(uint32(3))) Expect(stats.Requests).To(Equal(uint32(2)))
Expect(stats.Hits).To(Equal(uint32(1))) Expect(stats.Hits).To(Equal(uint32(1)))
}) })
expectReceiveMessageOnError := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn, _, err := pubsub.Pool().Get() pubsub.SetNetConn(&badConn{
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{
readErr: io.EOF, readErr: io.EOF,
writeErr: io.EOF, writeErr: io.EOF,
}) })
pubsub.Pool().Put(cn)
done := make(chan bool, 1) done := make(chan bool, 1)
go func() { go func() {
@ -305,10 +302,6 @@ var _ = Describe("PubSub", func() {
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
Eventually(done).Should(Receive()) Eventually(done).Should(Receive())
stats := client.PoolStats()
Expect(stats.Requests).To(Equal(uint32(4)))
Expect(stats.Hits).To(Equal(uint32(1)))
} }
It("Subscribe should reconnect on ReceiveMessage error", func() { It("Subscribe should reconnect on ReceiveMessage error", func() {

View File

@ -136,35 +136,6 @@ var _ = Describe("races", func() {
}) })
}) })
It("should PubSub", func() {
connPool := client.Pool()
perform(C, func(id int) {
for i := 0; i < N; i++ {
pubsub := client.Subscribe(fmt.Sprintf("mychannel%d", id))
go func() {
defer GinkgoRecover()
time.Sleep(time.Millisecond)
err := pubsub.Close()
Expect(err).NotTo(HaveOccurred())
}()
_, err := pubsub.ReceiveMessage()
Expect(err.Error()).To(ContainSubstring("closed"))
val := "echo" + strconv.Itoa(i)
echo, err := client.Echo(val).Result()
Expect(err).NotTo(HaveOccurred())
Expect(echo).To(Equal(val))
}
})
Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(connPool.Len()).To(BeNumerically("<=", 10))
})
It("should select db", func() { It("should select db", func() {
err := client.Set("db", 1, 0).Err() err := client.Set("db", 1, 0).Err()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())

View File

@ -31,9 +31,10 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) {
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
if !cn.Inited { if !cn.Inited {
if err := c.initConn(cn); err != nil { if err := c.initConn(cn); err != nil {
_ = c.connPool.Remove(cn, err) _ = c.connPool.Remove(cn)
return nil, false, err return nil, false, err
} }
} }
@ -42,7 +43,7 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) {
func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool {
if internal.IsBadConn(err, allowTimeout) { if internal.IsBadConn(err, allowTimeout) {
_ = c.connPool.Remove(cn, err) _ = c.connPool.Remove(cn)
return false return false
} }
@ -353,7 +354,7 @@ func (c *Client) pubSub() *PubSub {
return &PubSub{ return &PubSub{
base: baseClient{ base: baseClient{
opt: c.opt, opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false), connPool: c.connPool,
}, },
} }
} }

View File

@ -95,7 +95,7 @@ var _ = Describe("Client", func() {
Expect(client.Close()).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred())
_, err := pubsub.Receive() _, err := pubsub.Receive()
Expect(err).To(HaveOccurred()) Expect(err).To(MatchError("redis: client is closed"))
Expect(pubsub.Close()).NotTo(HaveOccurred()) Expect(pubsub.Close()).NotTo(HaveOccurred())
}) })
@ -217,6 +217,7 @@ var _ = Describe("Client", func() {
}) })
var _ = Describe("Client timeout", func() { var _ = Describe("Client timeout", func() {
var opt *redis.Options
var client *redis.Client var client *redis.Client
AfterEach(func() { AfterEach(func() {
@ -240,7 +241,13 @@ var _ = Describe("Client timeout", func() {
}) })
It("Subscribe timeouts", func() { It("Subscribe timeouts", func() {
if opt.WriteTimeout == 0 {
return
}
pubsub := client.Subscribe() pubsub := client.Subscribe()
defer pubsub.Close()
err := pubsub.Subscribe("_") err := pubsub.Subscribe("_")
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(err.(net.Error).Timeout()).To(BeTrue())
@ -269,7 +276,7 @@ var _ = Describe("Client timeout", func() {
Context("read timeout", func() { Context("read timeout", func() {
BeforeEach(func() { BeforeEach(func() {
opt := redisOptions() opt = redisOptions()
opt.ReadTimeout = time.Nanosecond opt.ReadTimeout = time.Nanosecond
opt.WriteTimeout = -1 opt.WriteTimeout = -1
client = redis.NewClient(opt) client = redis.NewClient(opt)
@ -280,7 +287,7 @@ var _ = Describe("Client timeout", func() {
Context("write timeout", func() { Context("write timeout", func() {
BeforeEach(func() { BeforeEach(func() {
opt := redisOptions() opt = redisOptions()
opt.ReadTimeout = -1 opt.ReadTimeout = -1
opt.WriteTimeout = time.Nanosecond opt.WriteTimeout = time.Nanosecond
client = redis.NewClient(opt) client = redis.NewClient(opt)

View File

@ -2,7 +2,6 @@ package redis
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
@ -111,7 +110,7 @@ func (c *sentinelClient) PubSub() *PubSub {
return &PubSub{ return &PubSub{
base: baseClient{ base: baseClient{
opt: c.opt, opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), false), connPool: c.connPool,
}, },
} }
} }
@ -268,12 +267,11 @@ func (d *sentinelFailover) closeOldConns(newMaster string) {
break break
} }
if cn.RemoteAddr().String() != newMaster { if cn.RemoteAddr().String() != newMaster {
err := fmt.Errorf( internal.Logf(
"sentinel: closing connection to the old master %s", "sentinel: closing connection to the old master %s",
cn.RemoteAddr(), cn.RemoteAddr(),
) )
internal.Logf(err.Error()) d.pool.Remove(cn)
d.pool.Remove(cn, err)
} else { } else {
cnsToPut = append(cnsToPut, cn) cnsToPut = append(cnsToPut, cn)
} }