Merge pull request #97 from go-redis/fix/pool-closes-all-connections

Fix pool to close all connections when client is closed.
This commit is contained in:
Vladimir Mihailenco 2015-05-05 12:38:46 +03:00
commit d3a8d04b9c
6 changed files with 294 additions and 189 deletions

96
conn.go Normal file
View File

@ -0,0 +1,96 @@
package redis
import (
"net"
"time"
"gopkg.in/bufio.v1"
)
type conn struct {
netcn net.Conn
rd *bufio.Reader
buf []byte
usedAt time.Time
readTimeout time.Duration
writeTimeout time.Duration
}
func newConnDialer(opt *options) func() (*conn, error) {
return func() (*conn, error) {
netcn, err := opt.Dialer()
if err != nil {
return nil, err
}
cn := &conn{
netcn: netcn,
buf: make([]byte, 0, 64),
}
cn.rd = bufio.NewReader(cn)
return cn, cn.init(opt)
}
}
func (cn *conn) init(opt *options) error {
if opt.Password == "" && opt.DB == 0 {
return nil
}
// Use connection to connect to redis
pool := newSingleConnPool(nil, false)
pool.SetConn(cn)
// Client is not closed because we want to reuse underlying connection.
client := newClient(opt, pool)
if opt.Password != "" {
if err := client.Auth(opt.Password).Err(); err != nil {
return err
}
}
if opt.DB > 0 {
if err := client.Select(opt.DB).Err(); err != nil {
return err
}
}
return nil
}
func (cn *conn) writeCmds(cmds ...Cmder) error {
buf := cn.buf[:0]
for _, cmd := range cmds {
buf = appendArgs(buf, cmd.args())
}
_, err := cn.Write(buf)
return err
}
func (cn *conn) Read(b []byte) (int, error) {
if cn.readTimeout != 0 {
cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout))
} else {
cn.netcn.SetReadDeadline(zeroTime)
}
return cn.netcn.Read(b)
}
func (cn *conn) Write(b []byte) (int, error) {
if cn.writeTimeout != 0 {
cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout))
} else {
cn.netcn.SetWriteDeadline(zeroTime)
}
return cn.netcn.Write(b)
}
func (cn *conn) RemoteAddr() net.Addr {
return cn.netcn.RemoteAddr()
}
func (cn *conn) Close() error {
return cn.netcn.Close()
}

282
pool.go
View File

@ -4,13 +4,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"log" "log"
"net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"gopkg.in/bsm/ratelimit.v1" "gopkg.in/bsm/ratelimit.v1"
"gopkg.in/bufio.v1"
) )
var ( var (
@ -28,103 +26,132 @@ type pool interface {
Put(*conn) error Put(*conn) error
Remove(*conn) error Remove(*conn) error
Len() int Len() int
Size() int FreeLen() int
Close() error Close() error
} }
//------------------------------------------------------------------------------ type connList struct {
cns []*conn
type conn struct { mx sync.Mutex
netcn net.Conn len int32 // atomic
rd *bufio.Reader size int32
buf []byte
usedAt time.Time
readTimeout time.Duration
writeTimeout time.Duration
} }
func newConnFunc(dial func() (net.Conn, error)) func() (*conn, error) { func newConnList(size int) *connList {
return func() (*conn, error) { return &connList{
netcn, err := dial() cns: make([]*conn, 0, size),
if err != nil { size: int32(size),
return nil, err
}
cn := &conn{
netcn: netcn,
buf: make([]byte, 0, 64),
}
cn.rd = bufio.NewReader(cn)
return cn, nil
} }
} }
func (cn *conn) writeCmds(cmds ...Cmder) error { func (l *connList) Len() int {
buf := cn.buf[:0] return int(atomic.LoadInt32(&l.len))
for _, cmd := range cmds {
buf = appendArgs(buf, cmd.args())
} }
_, err := cn.Write(buf) // Reserve reserves place in the list and returns true on success. The
return err // caller must add or remove connection if place was reserved.
func (l *connList) Reserve() bool {
len := atomic.AddInt32(&l.len, 1)
reserved := len <= l.size
if !reserved {
atomic.AddInt32(&l.len, -1)
}
return reserved
} }
func (cn *conn) Read(b []byte) (int, error) { // Add adds connection to the list. The caller must reserve place first.
if cn.readTimeout != 0 { func (l *connList) Add(cn *conn) {
cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout)) l.mx.Lock()
} else { l.cns = append(l.cns, cn)
cn.netcn.SetReadDeadline(zeroTime) l.mx.Unlock()
}
return cn.netcn.Read(b)
} }
func (cn *conn) Write(b []byte) (int, error) { func (l *connList) Remove(cn *conn) error {
if cn.writeTimeout != 0 { defer l.mx.Unlock()
cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout)) l.mx.Lock()
} else {
cn.netcn.SetWriteDeadline(zeroTime) if cn == nil {
} atomic.AddInt32(&l.len, -1)
return cn.netcn.Write(b) return nil
} }
func (cn *conn) RemoteAddr() net.Addr { for i, c := range l.cns {
return cn.netcn.RemoteAddr() if c == cn {
l.cns = append(l.cns[:i], l.cns[i+1:]...)
atomic.AddInt32(&l.len, -1)
return cn.Close()
}
} }
func (cn *conn) Close() error { panic("conn not found in the list")
return cn.netcn.Close()
} }
func (cn *conn) isIdle(timeout time.Duration) bool { func (l *connList) Replace(cn, newcn *conn) error {
return timeout > 0 && time.Since(cn.usedAt) > timeout defer l.mx.Unlock()
l.mx.Lock()
for i, c := range l.cns {
if c == cn {
l.cns[i] = newcn
return cn.Close()
}
} }
//------------------------------------------------------------------------------ panic("conn not found in the list")
}
func (l *connList) Close() (retErr error) {
l.mx.Lock()
for _, c := range l.cns {
if err := c.Close(); err != nil {
retErr = err
}
}
l.cns = nil
atomic.StoreInt32(&l.len, 0)
l.mx.Unlock()
return retErr
}
type connPoolOptions struct {
Dialer func() (*conn, error)
PoolSize int
PoolTimeout time.Duration
IdleTimeout time.Duration
IdleCheckFrequency time.Duration
}
type connPool struct { type connPool struct {
dial func() (*conn, error)
rl *ratelimit.RateLimiter rl *ratelimit.RateLimiter
opt *connPoolOptions
opt *options conns *connList
freeConns chan *conn freeConns chan *conn
size int32 _closed int32
closed int32
lastDialErr error lastDialErr error
} }
func newConnPool(dial func() (*conn, error), opt *options) *connPool { func newConnPool(opt *connPoolOptions) *connPool {
return &connPool{ p := &connPool{
dial: dial,
rl: ratelimit.New(2*opt.PoolSize, time.Second), rl: ratelimit.New(2*opt.PoolSize, time.Second),
opt: opt, opt: opt,
conns: newConnList(opt.PoolSize),
freeConns: make(chan *conn, opt.PoolSize), freeConns: make(chan *conn, opt.PoolSize),
} }
if p.opt.IdleTimeout > 0 && p.opt.IdleCheckFrequency > 0 {
go p.reaper()
}
return p
} }
func (p *connPool) isClosed() bool { return atomic.LoadInt32(&p.closed) > 0 } func (p *connPool) closed() bool {
return atomic.LoadInt32(&p._closed) == 1
}
func (p *connPool) isIdle(cn *conn) bool {
return p.opt.IdleTimeout > 0 && time.Since(cn.usedAt) > p.opt.IdleTimeout
}
// First returns first non-idle connection from the pool or nil if // First returns first non-idle connection from the pool or nil if
// there are no connections. // there are no connections.
@ -132,8 +159,8 @@ func (p *connPool) First() *conn {
for { for {
select { select {
case cn := <-p.freeConns: case cn := <-p.freeConns:
if cn.isIdle(p.opt.IdleTimeout) { if p.isIdle(cn) {
p.Remove(cn) p.conns.Remove(cn)
continue continue
} }
return cn return cn
@ -150,7 +177,7 @@ func (p *connPool) wait(timeout time.Duration) *conn {
for { for {
select { select {
case cn := <-p.freeConns: case cn := <-p.freeConns:
if cn.isIdle(p.opt.IdleTimeout) { if p.isIdle(cn) {
p.Remove(cn) p.Remove(cn)
continue continue
} }
@ -172,52 +199,19 @@ func (p *connPool) new() (*conn, error) {
return nil, err return nil, err
} }
cn, err := p.dial() cn, err := p.opt.Dialer()
if err != nil { if err != nil {
p.lastDialErr = err p.lastDialErr = err
return nil, err return nil, err
} }
if err := p.initConn(cn); err != nil {
cn.Close()
return nil, err
}
return cn, nil return cn, nil
} }
// Initialize connection
func (p *connPool) initConn(cn *conn) error {
if p.opt.Password == "" && p.opt.DB == 0 {
return nil
}
// Use connection to connect to redis
pool := newSingleConnPool(p, false)
pool.SetConn(cn)
// Client is not closed because we want to reuse underlying connection.
client := newClient(p.opt, pool)
if p.opt.Password != "" {
if err := client.Auth(p.opt.Password).Err(); err != nil {
return err
}
}
if p.opt.DB > 0 {
if err := client.Select(p.opt.DB).Err(); err != nil {
return err
}
}
return nil
}
// Get returns existed connection from the pool or creates a new one // Get returns existed connection from the pool or creates a new one
// if needed. // if needed.
func (p *connPool) Get() (*conn, error) { func (p *connPool) Get() (*conn, error) {
if p.isClosed() { if p.closed() {
return nil, errClosed return nil, errClosed
} }
@ -226,16 +220,16 @@ func (p *connPool) Get() (*conn, error) {
return cn, nil return cn, nil
} }
// Try to create a new one // Try to create a new one.
if ref := atomic.AddInt32(&p.size, 1); int(ref) <= p.opt.PoolSize { if p.conns.Reserve() {
cn, err := p.new() cn, err := p.new()
if err != nil { if err != nil {
atomic.AddInt32(&p.size, -1) // Undo ref increment p.conns.Remove(nil)
return nil, err return nil, err
} }
p.conns.Add(cn)
return cn, nil return cn, nil
} }
atomic.AddInt32(&p.size, -1)
// Otherwise, wait for the available connection // Otherwise, wait for the available connection
if cn := p.wait(p.opt.PoolTimeout); cn != nil { if cn := p.wait(p.opt.PoolTimeout); cn != nil {
@ -259,49 +253,53 @@ func (p *connPool) Put(cn *conn) error {
} }
func (p *connPool) Remove(cn *conn) error { func (p *connPool) Remove(cn *conn) error {
if p.isClosed() { if p.closed() {
atomic.AddInt32(&p.size, -1) // Close already closed all connections.
return cn.Close()
}
// Replace existing connection with new one and unblock `wait`.
newcn, err := p.new()
if err != nil {
atomic.AddInt32(&p.size, -1)
} else {
p.Put(newcn)
}
return cn.Close()
}
// Len returns number of idle connections.
func (p *connPool) Len() int {
return len(p.freeConns)
}
// Size returns number of connections in the pool.
func (p *connPool) Size() int {
return int(atomic.LoadInt32(&p.size))
}
func (p *connPool) Close() (retErr error) {
if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) {
return nil return nil
} }
// Wait until pool has no connections // Replace existing connection with new one and unblock waiter.
for p.Size() > 0 { newcn, err := p.new()
cn := p.wait(p.opt.PoolTimeout) if err != nil {
if cn == nil { return p.conns.Remove(cn)
break
}
if err := p.Remove(cn); err != nil {
retErr = err
} }
p.freeConns <- newcn
return p.conns.Replace(cn, newcn)
} }
return retErr // Len returns total number of connections.
func (p *connPool) Len() int {
return p.conns.Len()
}
// FreeLen returns number of free connections.
func (p *connPool) FreeLen() int {
return len(p.freeConns)
}
func (p *connPool) Close() error {
if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) {
return errClosed
}
return p.conns.Close()
}
func (p *connPool) reaper() {
ticker := time.NewTicker(p.opt.IdleCheckFrequency)
defer ticker.Stop()
for _ = range ticker.C {
if p.closed() {
break
}
// pool.First removes idle connections from the pool and
// returns first non-idle connection. So just put returned
// connection back.
if cn := p.First(); cn != nil {
p.Put(cn)
}
}
} }
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
@ -404,7 +402,7 @@ func (p *singleConnPool) Len() int {
return 1 return 1
} }
func (p *singleConnPool) Size() int { func (p *singleConnPool) FreeLen() int {
defer p.cnMtx.Unlock() defer p.cnMtx.Unlock()
p.cnMtx.Lock() p.cnMtx.Lock()
if p.cn == nil { if p.cn == nil {

View File

@ -48,9 +48,9 @@ var _ = Describe("Pool", func() {
}) })
pool := client.Pool() pool := client.Pool()
Expect(pool.Size()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(BeNumerically("<=", 10)) Expect(pool.Len()).To(BeNumerically("<=", 10))
Expect(pool.Size()).To(Equal(pool.Len())) Expect(pool.FreeLen()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(Equal(pool.FreeLen()))
}) })
It("should respect max on multi", func() { It("should respect max on multi", func() {
@ -70,9 +70,9 @@ var _ = Describe("Pool", func() {
}) })
pool := client.Pool() pool := client.Pool()
Expect(pool.Size()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(BeNumerically("<=", 10)) Expect(pool.Len()).To(BeNumerically("<=", 10))
Expect(pool.Size()).To(Equal(pool.Len())) Expect(pool.FreeLen()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(Equal(pool.FreeLen()))
}) })
It("should respect max on pipelines", func() { It("should respect max on pipelines", func() {
@ -88,9 +88,9 @@ var _ = Describe("Pool", func() {
}) })
pool := client.Pool() pool := client.Pool()
Expect(pool.Size()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(BeNumerically("<=", 10)) Expect(pool.Len()).To(BeNumerically("<=", 10))
Expect(pool.Size()).To(Equal(pool.Len())) Expect(pool.FreeLen()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(Equal(pool.FreeLen()))
}) })
It("should respect max on pubsub", func() { It("should respect max on pubsub", func() {
@ -101,8 +101,8 @@ var _ = Describe("Pool", func() {
}) })
pool := client.Pool() pool := client.Pool()
Expect(pool.Size()).To(Equal(10))
Expect(pool.Len()).To(Equal(10)) Expect(pool.Len()).To(Equal(10))
Expect(pool.FreeLen()).To(Equal(10))
}) })
It("should remove broken connections", func() { It("should remove broken connections", func() {
@ -120,8 +120,8 @@ var _ = Describe("Pool", func() {
Expect(val).To(Equal("PONG")) Expect(val).To(Equal("PONG"))
pool := client.Pool() pool := client.Pool()
Expect(pool.Size()).To(Equal(1))
Expect(pool.Len()).To(Equal(1)) Expect(pool.Len()).To(Equal(1))
Expect(pool.FreeLen()).To(Equal(1))
}) })
It("should reuse connections", func() { It("should reuse connections", func() {
@ -132,8 +132,8 @@ var _ = Describe("Pool", func() {
} }
pool := client.Pool() pool := client.Pool()
Expect(pool.Size()).To(Equal(1))
Expect(pool.Len()).To(Equal(1)) Expect(pool.Len()).To(Equal(1))
Expect(pool.FreeLen()).To(Equal(1))
}) })
It("should unblock client when connection is removed", func() { It("should unblock client when connection is removed", func() {

View File

@ -67,19 +67,6 @@ func (c *baseClient) Close() error {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
type options struct {
Password string
DB int64
DialTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
PoolSize int
PoolTimeout time.Duration
IdleTimeout time.Duration
}
type Options struct { type Options struct {
// The network type, either "tcp" or "unix". // The network type, either "tcp" or "unix".
// Default: "tcp" // Default: "tcp"
@ -120,6 +107,15 @@ type Options struct {
IdleTimeout time.Duration IdleTimeout time.Duration
} }
func (opt *Options) getDialer() func() (net.Conn, error) {
if opt.Dialer == nil {
return func() (net.Conn, error) {
return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout())
}
}
return opt.Dialer
}
func (opt *Options) getNetwork() string { func (opt *Options) getNetwork() string {
if opt.Network == "" { if opt.Network == "" {
return "tcp" return "tcp"
@ -150,15 +146,39 @@ func (opt *Options) getPoolTimeout() time.Duration {
func (opt *Options) options() *options { func (opt *Options) options() *options {
return &options{ return &options{
Dialer: opt.getDialer(),
PoolSize: opt.getPoolSize(),
PoolTimeout: opt.getPoolTimeout(),
IdleTimeout: opt.IdleTimeout,
DB: opt.DB, DB: opt.DB,
Password: opt.Password, Password: opt.Password,
DialTimeout: opt.getDialTimeout(), DialTimeout: opt.getDialTimeout(),
ReadTimeout: opt.ReadTimeout, ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout, WriteTimeout: opt.WriteTimeout,
}
}
PoolSize: opt.getPoolSize(), type options struct {
PoolTimeout: opt.getPoolTimeout(), Dialer func() (net.Conn, error)
PoolSize int
PoolTimeout time.Duration
IdleTimeout time.Duration
Password string
DB int64
DialTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
}
func (opt *options) connPoolOptions() *connPoolOptions {
return &connPoolOptions{
Dialer: newConnDialer(opt),
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout, IdleTimeout: opt.IdleTimeout,
} }
} }
@ -180,11 +200,6 @@ func newClient(opt *options, pool pool) *Client {
func NewClient(clOpt *Options) *Client { func NewClient(clOpt *Options) *Client {
opt := clOpt.options() opt := clOpt.options()
dialer := clOpt.Dialer pool := newConnPool(opt.connPoolOptions())
if dialer == nil { return newClient(opt, pool)
dialer = func() (net.Conn, error) {
return net.DialTimeout(clOpt.getNetwork(), clOpt.Addr, opt.DialTimeout)
}
}
return newClient(opt, newConnPool(newConnFunc(dialer), opt))
} }

View File

@ -34,7 +34,7 @@ var _ = Describe("Client", func() {
}) })
AfterEach(func() { AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred()) client.Close()
}) })
It("should ping", func() { It("should ping", func() {

View File

@ -104,14 +104,9 @@ type sentinelClient struct {
func newSentinel(clOpt *Options) *sentinelClient { func newSentinel(clOpt *Options) *sentinelClient {
opt := clOpt.options() opt := clOpt.options()
opt.Password = ""
opt.DB = 0
dialer := func() (net.Conn, error) {
return net.DialTimeout("tcp", clOpt.Addr, opt.DialTimeout)
}
base := &baseClient{ base := &baseClient{
opt: opt, opt: opt,
connPool: newConnPool(newConnFunc(dialer), opt), connPool: newConnPool(opt.connPoolOptions()),
} }
return &sentinelClient{ return &sentinelClient{
baseClient: base, baseClient: base,
@ -163,7 +158,8 @@ func (d *sentinelFailover) dial() (net.Conn, error) {
func (d *sentinelFailover) Pool() pool { func (d *sentinelFailover) Pool() pool {
d.poolOnce.Do(func() { d.poolOnce.Do(func() {
d.pool = newConnPool(newConnFunc(d.dial), d.opt) d.opt.Dialer = d.dial
d.pool = newConnPool(d.opt.connPoolOptions())
}) })
return d.pool return d.pool
} }