Simplify connection management with sticky connection pool. Fixes #260.

This commit is contained in:
Vladimir Mihailenco 2016-03-01 12:31:06 +02:00
parent 0382d1e980
commit 110e93a8e4
10 changed files with 140 additions and 90 deletions

View File

@ -33,14 +33,14 @@ func isNetworkError(err error) bool {
return ok return ok
} }
func isBadConn(cn *conn, ei error) bool { func isBadConn(err error) bool {
if cn.rd.Buffered() > 0 { if err == nil {
return true
}
if ei == nil {
return false return false
} }
if _, ok := ei.(redisError); ok { if _, ok := err.(redisError); ok {
return false
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return false return false
} }
return true return true

View File

@ -6,6 +6,10 @@ func (c *baseClient) Pool() pool {
return c.connPool return c.connPool
} }
func (c *PubSub) Pool() pool {
return c.base.connPool
}
var NewConnDialer = newConnDialer var NewConnDialer = newConnDialer
func (cn *conn) SetNetConn(netcn net.Conn) { func (cn *conn) SetNetConn(netcn net.Conn) {

View File

@ -7,7 +7,6 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"sync/atomic" "sync/atomic"
"syscall"
"testing" "testing"
"time" "time"
@ -243,10 +242,6 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
var (
errTimeout = syscall.ETIMEDOUT
)
type badConnError string type badConnError string
func (e badConnError) Error() string { return string(e) } func (e badConnError) Error() string { return string(e) }

View File

@ -45,18 +45,6 @@ func (c *Client) Multi() *Multi {
return multi return multi
} }
func (c *Multi) putConn(cn *conn, err error) {
if isBadConn(cn, err) {
// Close current connection.
c.base.connPool.(*stickyConnPool).Reset(err)
} else {
err := c.base.connPool.Put(cn)
if err != nil {
Logger.Printf("pool.Put failed: %s", err)
}
}
}
func (c *Multi) process(cmd Cmder) { func (c *Multi) process(cmd Cmder) {
if c.cmds == nil { if c.cmds == nil {
c.base.process(cmd) c.base.process(cmd)
@ -145,7 +133,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) {
} }
err = c.execCmds(cn, cmds) err = c.execCmds(cn, cmds)
c.putConn(cn, err) c.base.putConn(cn, err)
return retCmds, err return retCmds, err
} }

View File

@ -166,4 +166,31 @@ var _ = Describe("Multi", func() {
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
It("should recover from bad connection when there are no commands", func() {
// Put bad connection in the pool.
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())
{
tx, err := client.Watch("key")
Expect(err).To(MatchError("bad connection"))
Expect(tx).To(BeNil())
}
{
tx, err := client.Watch("key")
Expect(err).NotTo(HaveOccurred())
err = tx.Ping().Err()
Expect(err).NotTo(HaveOccurred())
err = tx.Close()
Expect(err).NotTo(HaveOccurred())
}
})
}) })

20
pool.go
View File

@ -246,13 +246,14 @@ func (p *connPool) Get() (cn *conn, isNew bool, err error) {
// Try to create a new one. // Try to create a new one.
if p.conns.Reserve() { if p.conns.Reserve() {
isNew = true
cn, err = p.new() cn, err = p.new()
if err != nil { if err != nil {
p.conns.Remove(nil) p.conns.Remove(nil)
return return
} }
p.conns.Add(cn) p.conns.Add(cn)
isNew = true
return return
} }
@ -481,13 +482,13 @@ func (p *stickyConnPool) Put(cn *conn) error {
return nil return nil
} }
func (p *stickyConnPool) remove(reason error) (err error) { func (p *stickyConnPool) remove(reason error) error {
err = p.pool.Remove(p.cn, reason) err := p.pool.Remove(p.cn, reason)
p.cn = nil p.cn = nil
return err return err
} }
func (p *stickyConnPool) Remove(cn *conn, _ error) error { func (p *stickyConnPool) Remove(cn *conn, reason error) error {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()
if p.closed { if p.closed {
@ -499,7 +500,7 @@ func (p *stickyConnPool) Remove(cn *conn, _ error) error {
if cn != nil && p.cn != cn { if cn != nil && p.cn != cn {
panic("p.cn != cn") panic("p.cn != cn")
} }
return nil return p.remove(reason)
} }
func (p *stickyConnPool) Len() int { func (p *stickyConnPool) Len() int {
@ -522,15 +523,6 @@ func (p *stickyConnPool) FreeLen() int {
func (p *stickyConnPool) Stats() *PoolStats { return nil } func (p *stickyConnPool) Stats() *PoolStats { return nil }
func (p *stickyConnPool) Reset(reason error) (err error) {
p.mx.Lock()
if p.cn != nil {
err = p.remove(reason)
}
p.mx.Unlock()
return err
}
func (p *stickyConnPool) Close() error { func (p *stickyConnPool) Close() error {
defer p.mx.Unlock() defer p.mx.Unlock()
p.mx.Lock() p.mx.Lock()

View File

@ -17,16 +17,18 @@ func (c *Client) Publish(channel, message string) *IntCmd {
// http://redis.io/topics/pubsub. It's NOT safe for concurrent use by // http://redis.io/topics/pubsub. It's NOT safe for concurrent use by
// multiple goroutines. // multiple goroutines.
type PubSub struct { type PubSub struct {
*baseClient base *baseClient
channels []string channels []string
patterns []string patterns []string
nsub int // number of active subscriptions
} }
// Deprecated. Use Subscribe/PSubscribe instead. // Deprecated. Use Subscribe/PSubscribe instead.
func (c *Client) PubSub() *PubSub { func (c *Client) PubSub() *PubSub {
return &PubSub{ return &PubSub{
baseClient: &baseClient{ base: &baseClient{
opt: c.opt, opt: c.opt,
connPool: newStickyConnPool(c.connPool, false), connPool: newStickyConnPool(c.connPool, false),
}, },
@ -46,7 +48,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
} }
func (c *PubSub) subscribe(cmd string, channels ...string) error { func (c *PubSub) subscribe(cmd string, channels ...string) error {
cn, _, err := c.conn() cn, _, err := c.base.conn()
if err != nil { if err != nil {
return err return err
} }
@ -65,6 +67,7 @@ func (c *PubSub) Subscribe(channels ...string) error {
err := c.subscribe("SUBSCRIBE", channels...) err := c.subscribe("SUBSCRIBE", channels...)
if err == nil { if err == nil {
c.channels = append(c.channels, channels...) c.channels = append(c.channels, channels...)
c.nsub += len(channels)
} }
return err return err
} }
@ -74,6 +77,7 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
err := c.subscribe("PSUBSCRIBE", patterns...) err := c.subscribe("PSUBSCRIBE", patterns...)
if err == nil { if err == nil {
c.patterns = append(c.patterns, patterns...) c.patterns = append(c.patterns, patterns...)
c.nsub += len(patterns)
} }
return err return err
} }
@ -113,8 +117,12 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error {
return err return err
} }
func (c *PubSub) Close() error {
return c.base.Close()
}
func (c *PubSub) Ping(payload string) error { func (c *PubSub) Ping(payload string) error {
cn, _, err := c.conn() cn, _, err := c.base.conn()
if err != nil { if err != nil {
return err return err
} }
@ -178,7 +186,7 @@ func (p *Pong) String() string {
return "Pong" return "Pong"
} }
func newMessage(reply []interface{}) (interface{}, error) { func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
switch kind := reply[0].(string); kind { switch kind := reply[0].(string); kind {
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
return &Subscription{ return &Subscription{
@ -210,7 +218,11 @@ func newMessage(reply []interface{}) (interface{}, error) {
// is not received in time. This is low-level API and most clients // is not received in time. This is low-level API and most clients
// should use ReceiveMessage. // should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
cn, _, err := c.conn() if c.nsub == 0 {
c.resubscribe()
}
cn, _, err := c.base.conn()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -222,7 +234,8 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newMessage(cmd.Val())
return c.newMessage(cmd.Val())
} }
// Receive returns a message as a Subscription, Message, PMessage, // Receive returns a message as a Subscription, Message, PMessage,
@ -232,22 +245,6 @@ func (c *PubSub) Receive() (interface{}, error) {
return c.ReceiveTimeout(0) return c.ReceiveTimeout(0)
} }
func (c *PubSub) reconnect(reason error) {
// Close current connection.
c.connPool.(*stickyConnPool).Reset(reason)
if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil {
Logger.Printf("Subscribe failed: %s", err)
}
}
if len(c.patterns) > 0 {
if err := c.PSubscribe(c.patterns...); err != nil {
Logger.Printf("PSubscribe failed: %s", err)
}
}
}
// ReceiveMessage returns a message or error. It automatically // ReceiveMessage returns a message or error. It automatically
// reconnects to Redis in case of network errors. // reconnects to Redis in case of network errors.
func (c *PubSub) ReceiveMessage() (*Message, error) { func (c *PubSub) ReceiveMessage() (*Message, error) {
@ -259,10 +256,8 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
return nil, err return nil, err
} }
goodConn := errNum == 0
errNum++ errNum++
if errNum < 3 {
if goodConn {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
err := c.Ping("") err := c.Ping("")
if err == nil { if err == nil {
@ -270,16 +265,16 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
} }
Logger.Printf("PubSub.Ping failed: %s", err) Logger.Printf("PubSub.Ping failed: %s", err)
} }
} } else {
// 3 consequent errors - connection is bad
if errNum > 2 { // and/or Redis Server is down.
// Sleep to not exceed max number of open connections.
time.Sleep(time.Second) time.Sleep(time.Second)
} }
c.reconnect(err)
continue continue
} }
// Reset error number. // Reset error number, because we received a message.
errNum = 0 errNum = 0
switch msg := msgi.(type) { switch msg := msgi.(type) {
@ -300,3 +295,22 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
} }
} }
} }
func (c *PubSub) putConn(cn *conn, err error) {
if !c.base.putConn(cn, err) {
c.nsub = 0
}
}
func (c *PubSub) resubscribe() {
if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil {
Logger.Printf("Subscribe failed: %s", err)
}
}
if len(c.patterns) > 0 {
if err := c.PSubscribe(c.patterns...); err != nil {
Logger.Printf("PSubscribe failed: %s", err)
}
}
}

View File

@ -1,6 +1,7 @@
package redis_test package redis_test
import ( import (
"io"
"net" "net"
"sync" "sync"
"time" "time"
@ -230,18 +231,41 @@ var _ = Describe("PubSub", func() {
Expect(pong.Payload).To(Equal("hello")) Expect(pong.Payload).To(Equal("hello"))
}) })
It("should ReceiveMessage", func() { It("should multi-ReceiveMessage", func() {
pubsub, err := client.Subscribe("mychannel") pubsub, err := client.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer pubsub.Close() defer pubsub.Close()
var wg sync.WaitGroup err = client.Publish("mychannel", "hello").Err()
wg.Add(1) Expect(err).NotTo(HaveOccurred())
err = client.Publish("mychannel", "world").Err()
Expect(err).NotTo(HaveOccurred())
msg, err := pubsub.ReceiveMessage()
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello"))
msg, err = pubsub.ReceiveMessage()
Expect(err).NotTo(HaveOccurred())
Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("world"))
})
It("should ReceiveMessage after timeout", func() {
pubsub, err := client.Subscribe("mychannel")
Expect(err).NotTo(HaveOccurred())
defer pubsub.Close()
done := make(chan bool, 1)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer func() {
done <- true
}()
time.Sleep(readTimeout + 100*time.Millisecond) time.Sleep(5*time.Second + 100*time.Millisecond)
n, err := client.Publish("mychannel", "hello").Result() n, err := client.Publish("mychannel", "hello").Result()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(n).To(Equal(int64(1))) Expect(n).To(Equal(int64(1)))
@ -252,22 +276,23 @@ var _ = Describe("PubSub", func() {
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
wg.Wait() Eventually(done).Should(Receive())
}) })
expectReceiveMessage := func(pubsub *redis.PubSub) { expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, _, err := pubsub.Pool().Get() cn1, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
cn1.SetNetConn(&badConn{ cn1.SetNetConn(&badConn{
readErr: errTimeout, readErr: io.EOF,
writeErr: errTimeout, writeErr: io.EOF,
}) })
var wg sync.WaitGroup done := make(chan bool, 1)
wg.Add(1)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer func() {
done <- true
}()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
err := client.Publish("mychannel", "hello").Err() err := client.Publish("mychannel", "hello").Err()
@ -279,7 +304,7 @@ var _ = Describe("PubSub", func() {
Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Channel).To(Equal("mychannel"))
Expect(msg.Payload).To(Equal("hello")) Expect(msg.Payload).To(Equal("hello"))
wg.Wait() Eventually(done).Should(Receive())
} }
It("Subscribe should reconnect on ReceiveMessage error", func() { It("Subscribe should reconnect on ReceiveMessage error", func() {
@ -287,7 +312,7 @@ var _ = Describe("PubSub", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer pubsub.Close() defer pubsub.Close()
expectReceiveMessage(pubsub) expectReceiveMessageOnError(pubsub)
}) })
It("PSubscribe should reconnect on ReceiveMessage error", func() { It("PSubscribe should reconnect on ReceiveMessage error", func() {
@ -295,7 +320,7 @@ var _ = Describe("PubSub", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer pubsub.Close() defer pubsub.Close()
expectReceiveMessage(pubsub) expectReceiveMessageOnError(pubsub)
}) })
It("should return on Close", func() { It("should return on Close", func() {

View File

@ -23,15 +23,20 @@ func (c *baseClient) conn() (*conn, bool, error) {
return c.connPool.Get() return c.connPool.Get()
} }
func (c *baseClient) putConn(cn *conn, err error) { func (c *baseClient) putConn(cn *conn, err error) bool {
if isBadConn(cn, err) { if isBadConn(err) {
err = c.connPool.Remove(cn, err) err = c.connPool.Remove(cn, err)
} else {
err = c.connPool.Put(cn)
}
if err != nil { if err != nil {
Logger.Printf("pool.Put failed: %s", err) log.Printf("pool.Remove failed: %s", err)
} }
return false
}
err = c.connPool.Put(cn)
if err != nil {
log.Printf("pool.Put failed: %s", err)
}
return true
} }
func (c *baseClient) process(cmd Cmder) { func (c *baseClient) process(cmd Cmder) {

View File

@ -88,7 +88,7 @@ func newSentinel(opt *Options) *sentinelClient {
func (c *sentinelClient) PubSub() *PubSub { func (c *sentinelClient) PubSub() *PubSub {
return &PubSub{ return &PubSub{
baseClient: &baseClient{ base: &baseClient{
opt: c.opt, opt: c.opt,
connPool: newStickyConnPool(c.connPool, false), connPool: newStickyConnPool(c.connPool, false),
}, },