Merge pull request #97 from go-redis/fix/pool-closes-all-connections

Fix pool to close all connections when client is closed.
This commit is contained in:
Vladimir Mihailenco 2015-05-05 12:38:46 +03:00
commit d3a8d04b9c
6 changed files with 294 additions and 189 deletions

96
conn.go Normal file
View File

@ -0,0 +1,96 @@
package redis
import (
"net"
"time"
"gopkg.in/bufio.v1"
)
type conn struct {
netcn net.Conn
rd *bufio.Reader
buf []byte
usedAt time.Time
readTimeout time.Duration
writeTimeout time.Duration
}
func newConnDialer(opt *options) func() (*conn, error) {
return func() (*conn, error) {
netcn, err := opt.Dialer()
if err != nil {
return nil, err
}
cn := &conn{
netcn: netcn,
buf: make([]byte, 0, 64),
}
cn.rd = bufio.NewReader(cn)
return cn, cn.init(opt)
}
}
func (cn *conn) init(opt *options) error {
if opt.Password == "" && opt.DB == 0 {
return nil
}
// Use connection to connect to redis
pool := newSingleConnPool(nil, false)
pool.SetConn(cn)
// Client is not closed because we want to reuse underlying connection.
client := newClient(opt, pool)
if opt.Password != "" {
if err := client.Auth(opt.Password).Err(); err != nil {
return err
}
}
if opt.DB > 0 {
if err := client.Select(opt.DB).Err(); err != nil {
return err
}
}
return nil
}
func (cn *conn) writeCmds(cmds ...Cmder) error {
buf := cn.buf[:0]
for _, cmd := range cmds {
buf = appendArgs(buf, cmd.args())
}
_, err := cn.Write(buf)
return err
}
func (cn *conn) Read(b []byte) (int, error) {
if cn.readTimeout != 0 {
cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout))
} else {
cn.netcn.SetReadDeadline(zeroTime)
}
return cn.netcn.Read(b)
}
func (cn *conn) Write(b []byte) (int, error) {
if cn.writeTimeout != 0 {
cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout))
} else {
cn.netcn.SetWriteDeadline(zeroTime)
}
return cn.netcn.Write(b)
}
func (cn *conn) RemoteAddr() net.Addr {
return cn.netcn.RemoteAddr()
}
func (cn *conn) Close() error {
return cn.netcn.Close()
}

288
pool.go
View File

@ -4,13 +4,11 @@ import (
"errors"
"fmt"
"log"
"net"
"sync"
"sync/atomic"
"time"
"gopkg.in/bsm/ratelimit.v1"
"gopkg.in/bufio.v1"
)
var (
@ -28,103 +26,132 @@ type pool interface {
Put(*conn) error
Remove(*conn) error
Len() int
Size() int
FreeLen() int
Close() error
}
//------------------------------------------------------------------------------
type conn struct {
netcn net.Conn
rd *bufio.Reader
buf []byte
usedAt time.Time
readTimeout time.Duration
writeTimeout time.Duration
type connList struct {
cns []*conn
mx sync.Mutex
len int32 // atomic
size int32
}
func newConnFunc(dial func() (net.Conn, error)) func() (*conn, error) {
return func() (*conn, error) {
netcn, err := dial()
if err != nil {
return nil, err
}
cn := &conn{
netcn: netcn,
buf: make([]byte, 0, 64),
}
cn.rd = bufio.NewReader(cn)
return cn, nil
func newConnList(size int) *connList {
return &connList{
cns: make([]*conn, 0, size),
size: int32(size),
}
}
func (cn *conn) writeCmds(cmds ...Cmder) error {
buf := cn.buf[:0]
for _, cmd := range cmds {
buf = appendArgs(buf, cmd.args())
func (l *connList) Len() int {
return int(atomic.LoadInt32(&l.len))
}
// Reserve reserves place in the list and returns true on success. The
// caller must add or remove connection if place was reserved.
func (l *connList) Reserve() bool {
len := atomic.AddInt32(&l.len, 1)
reserved := len <= l.size
if !reserved {
atomic.AddInt32(&l.len, -1)
}
return reserved
}
// Add adds connection to the list. The caller must reserve place first.
func (l *connList) Add(cn *conn) {
l.mx.Lock()
l.cns = append(l.cns, cn)
l.mx.Unlock()
}
func (l *connList) Remove(cn *conn) error {
defer l.mx.Unlock()
l.mx.Lock()
if cn == nil {
atomic.AddInt32(&l.len, -1)
return nil
}
_, err := cn.Write(buf)
return err
}
func (cn *conn) Read(b []byte) (int, error) {
if cn.readTimeout != 0 {
cn.netcn.SetReadDeadline(time.Now().Add(cn.readTimeout))
} else {
cn.netcn.SetReadDeadline(zeroTime)
for i, c := range l.cns {
if c == cn {
l.cns = append(l.cns[:i], l.cns[i+1:]...)
atomic.AddInt32(&l.len, -1)
return cn.Close()
}
return cn.netcn.Read(b)
}
func (cn *conn) Write(b []byte) (int, error) {
if cn.writeTimeout != 0 {
cn.netcn.SetWriteDeadline(time.Now().Add(cn.writeTimeout))
} else {
cn.netcn.SetWriteDeadline(zeroTime)
}
return cn.netcn.Write(b)
panic("conn not found in the list")
}
func (cn *conn) RemoteAddr() net.Addr {
return cn.netcn.RemoteAddr()
func (l *connList) Replace(cn, newcn *conn) error {
defer l.mx.Unlock()
l.mx.Lock()
for i, c := range l.cns {
if c == cn {
l.cns[i] = newcn
return cn.Close()
}
}
panic("conn not found in the list")
}
func (cn *conn) Close() error {
return cn.netcn.Close()
func (l *connList) Close() (retErr error) {
l.mx.Lock()
for _, c := range l.cns {
if err := c.Close(); err != nil {
retErr = err
}
}
l.cns = nil
atomic.StoreInt32(&l.len, 0)
l.mx.Unlock()
return retErr
}
func (cn *conn) isIdle(timeout time.Duration) bool {
return timeout > 0 && time.Since(cn.usedAt) > timeout
type connPoolOptions struct {
Dialer func() (*conn, error)
PoolSize int
PoolTimeout time.Duration
IdleTimeout time.Duration
IdleCheckFrequency time.Duration
}
//------------------------------------------------------------------------------
type connPool struct {
dial func() (*conn, error)
rl *ratelimit.RateLimiter
opt *options
opt *connPoolOptions
conns *connList
freeConns chan *conn
size int32
closed int32
_closed int32
lastDialErr error
}
func newConnPool(dial func() (*conn, error), opt *options) *connPool {
return &connPool{
dial: dial,
func newConnPool(opt *connPoolOptions) *connPool {
p := &connPool{
rl: ratelimit.New(2*opt.PoolSize, time.Second),
opt: opt,
conns: newConnList(opt.PoolSize),
freeConns: make(chan *conn, opt.PoolSize),
}
if p.opt.IdleTimeout > 0 && p.opt.IdleCheckFrequency > 0 {
go p.reaper()
}
return p
}
func (p *connPool) isClosed() bool { return atomic.LoadInt32(&p.closed) > 0 }
func (p *connPool) closed() bool {
return atomic.LoadInt32(&p._closed) == 1
}
func (p *connPool) isIdle(cn *conn) bool {
return p.opt.IdleTimeout > 0 && time.Since(cn.usedAt) > p.opt.IdleTimeout
}
// First returns first non-idle connection from the pool or nil if
// there are no connections.
@ -132,8 +159,8 @@ func (p *connPool) First() *conn {
for {
select {
case cn := <-p.freeConns:
if cn.isIdle(p.opt.IdleTimeout) {
p.Remove(cn)
if p.isIdle(cn) {
p.conns.Remove(cn)
continue
}
return cn
@ -150,7 +177,7 @@ func (p *connPool) wait(timeout time.Duration) *conn {
for {
select {
case cn := <-p.freeConns:
if cn.isIdle(p.opt.IdleTimeout) {
if p.isIdle(cn) {
p.Remove(cn)
continue
}
@ -172,52 +199,19 @@ func (p *connPool) new() (*conn, error) {
return nil, err
}
cn, err := p.dial()
cn, err := p.opt.Dialer()
if err != nil {
p.lastDialErr = err
return nil, err
}
if err := p.initConn(cn); err != nil {
cn.Close()
return nil, err
}
return cn, nil
}
// Initialize connection
func (p *connPool) initConn(cn *conn) error {
if p.opt.Password == "" && p.opt.DB == 0 {
return nil
}
// Use connection to connect to redis
pool := newSingleConnPool(p, false)
pool.SetConn(cn)
// Client is not closed because we want to reuse underlying connection.
client := newClient(p.opt, pool)
if p.opt.Password != "" {
if err := client.Auth(p.opt.Password).Err(); err != nil {
return err
}
}
if p.opt.DB > 0 {
if err := client.Select(p.opt.DB).Err(); err != nil {
return err
}
}
return nil
}
// Get returns existed connection from the pool or creates a new one
// if needed.
func (p *connPool) Get() (*conn, error) {
if p.isClosed() {
if p.closed() {
return nil, errClosed
}
@ -226,16 +220,16 @@ func (p *connPool) Get() (*conn, error) {
return cn, nil
}
// Try to create a new one
if ref := atomic.AddInt32(&p.size, 1); int(ref) <= p.opt.PoolSize {
// Try to create a new one.
if p.conns.Reserve() {
cn, err := p.new()
if err != nil {
atomic.AddInt32(&p.size, -1) // Undo ref increment
p.conns.Remove(nil)
return nil, err
}
p.conns.Add(cn)
return cn, nil
}
atomic.AddInt32(&p.size, -1)
// Otherwise, wait for the available connection
if cn := p.wait(p.opt.PoolTimeout); cn != nil {
@ -259,49 +253,53 @@ func (p *connPool) Put(cn *conn) error {
}
func (p *connPool) Remove(cn *conn) error {
if p.isClosed() {
atomic.AddInt32(&p.size, -1)
return cn.Close()
}
// Replace existing connection with new one and unblock `wait`.
newcn, err := p.new()
if err != nil {
atomic.AddInt32(&p.size, -1)
} else {
p.Put(newcn)
}
return cn.Close()
}
// Len returns number of idle connections.
func (p *connPool) Len() int {
return len(p.freeConns)
}
// Size returns number of connections in the pool.
func (p *connPool) Size() int {
return int(atomic.LoadInt32(&p.size))
}
func (p *connPool) Close() (retErr error) {
if !atomic.CompareAndSwapInt32(&p.closed, 0, 1) {
if p.closed() {
// Close already closed all connections.
return nil
}
// Wait until pool has no connections
for p.Size() > 0 {
cn := p.wait(p.opt.PoolTimeout)
if cn == nil {
// Replace existing connection with new one and unblock waiter.
newcn, err := p.new()
if err != nil {
return p.conns.Remove(cn)
}
p.freeConns <- newcn
return p.conns.Replace(cn, newcn)
}
// Len returns total number of connections.
func (p *connPool) Len() int {
return p.conns.Len()
}
// FreeLen returns number of free connections.
func (p *connPool) FreeLen() int {
return len(p.freeConns)
}
func (p *connPool) Close() error {
if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) {
return errClosed
}
return p.conns.Close()
}
func (p *connPool) reaper() {
ticker := time.NewTicker(p.opt.IdleCheckFrequency)
defer ticker.Stop()
for _ = range ticker.C {
if p.closed() {
break
}
if err := p.Remove(cn); err != nil {
retErr = err
}
}
return retErr
// pool.First removes idle connections from the pool and
// returns first non-idle connection. So just put returned
// connection back.
if cn := p.First(); cn != nil {
p.Put(cn)
}
}
}
//------------------------------------------------------------------------------
@ -404,7 +402,7 @@ func (p *singleConnPool) Len() int {
return 1
}
func (p *singleConnPool) Size() int {
func (p *singleConnPool) FreeLen() int {
defer p.cnMtx.Unlock()
p.cnMtx.Lock()
if p.cn == nil {

View File

@ -48,9 +48,9 @@ var _ = Describe("Pool", func() {
})
pool := client.Pool()
Expect(pool.Size()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(BeNumerically("<=", 10))
Expect(pool.Size()).To(Equal(pool.Len()))
Expect(pool.FreeLen()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(Equal(pool.FreeLen()))
})
It("should respect max on multi", func() {
@ -70,9 +70,9 @@ var _ = Describe("Pool", func() {
})
pool := client.Pool()
Expect(pool.Size()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(BeNumerically("<=", 10))
Expect(pool.Size()).To(Equal(pool.Len()))
Expect(pool.FreeLen()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(Equal(pool.FreeLen()))
})
It("should respect max on pipelines", func() {
@ -88,9 +88,9 @@ var _ = Describe("Pool", func() {
})
pool := client.Pool()
Expect(pool.Size()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(BeNumerically("<=", 10))
Expect(pool.Size()).To(Equal(pool.Len()))
Expect(pool.FreeLen()).To(BeNumerically("<=", 10))
Expect(pool.Len()).To(Equal(pool.FreeLen()))
})
It("should respect max on pubsub", func() {
@ -101,8 +101,8 @@ var _ = Describe("Pool", func() {
})
pool := client.Pool()
Expect(pool.Size()).To(Equal(10))
Expect(pool.Len()).To(Equal(10))
Expect(pool.FreeLen()).To(Equal(10))
})
It("should remove broken connections", func() {
@ -120,8 +120,8 @@ var _ = Describe("Pool", func() {
Expect(val).To(Equal("PONG"))
pool := client.Pool()
Expect(pool.Size()).To(Equal(1))
Expect(pool.Len()).To(Equal(1))
Expect(pool.FreeLen()).To(Equal(1))
})
It("should reuse connections", func() {
@ -132,8 +132,8 @@ var _ = Describe("Pool", func() {
}
pool := client.Pool()
Expect(pool.Size()).To(Equal(1))
Expect(pool.Len()).To(Equal(1))
Expect(pool.FreeLen()).To(Equal(1))
})
It("should unblock client when connection is removed", func() {

View File

@ -67,19 +67,6 @@ func (c *baseClient) Close() error {
//------------------------------------------------------------------------------
type options struct {
Password string
DB int64
DialTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
PoolSize int
PoolTimeout time.Duration
IdleTimeout time.Duration
}
type Options struct {
// The network type, either "tcp" or "unix".
// Default: "tcp"
@ -120,6 +107,15 @@ type Options struct {
IdleTimeout time.Duration
}
func (opt *Options) getDialer() func() (net.Conn, error) {
if opt.Dialer == nil {
return func() (net.Conn, error) {
return net.DialTimeout(opt.getNetwork(), opt.Addr, opt.getDialTimeout())
}
}
return opt.Dialer
}
func (opt *Options) getNetwork() string {
if opt.Network == "" {
return "tcp"
@ -150,15 +146,39 @@ func (opt *Options) getPoolTimeout() time.Duration {
func (opt *Options) options() *options {
return &options{
Dialer: opt.getDialer(),
PoolSize: opt.getPoolSize(),
PoolTimeout: opt.getPoolTimeout(),
IdleTimeout: opt.IdleTimeout,
DB: opt.DB,
Password: opt.Password,
DialTimeout: opt.getDialTimeout(),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
}
}
PoolSize: opt.getPoolSize(),
PoolTimeout: opt.getPoolTimeout(),
type options struct {
Dialer func() (net.Conn, error)
PoolSize int
PoolTimeout time.Duration
IdleTimeout time.Duration
Password string
DB int64
DialTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
}
func (opt *options) connPoolOptions() *connPoolOptions {
return &connPoolOptions{
Dialer: newConnDialer(opt),
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,
}
}
@ -180,11 +200,6 @@ func newClient(opt *options, pool pool) *Client {
func NewClient(clOpt *Options) *Client {
opt := clOpt.options()
dialer := clOpt.Dialer
if dialer == nil {
dialer = func() (net.Conn, error) {
return net.DialTimeout(clOpt.getNetwork(), clOpt.Addr, opt.DialTimeout)
}
}
return newClient(opt, newConnPool(newConnFunc(dialer), opt))
pool := newConnPool(opt.connPoolOptions())
return newClient(opt, pool)
}

View File

@ -34,7 +34,7 @@ var _ = Describe("Client", func() {
})
AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
client.Close()
})
It("should ping", func() {

View File

@ -104,14 +104,9 @@ type sentinelClient struct {
func newSentinel(clOpt *Options) *sentinelClient {
opt := clOpt.options()
opt.Password = ""
opt.DB = 0
dialer := func() (net.Conn, error) {
return net.DialTimeout("tcp", clOpt.Addr, opt.DialTimeout)
}
base := &baseClient{
opt: opt,
connPool: newConnPool(newConnFunc(dialer), opt),
connPool: newConnPool(opt.connPoolOptions()),
}
return &sentinelClient{
baseClient: base,
@ -163,7 +158,8 @@ func (d *sentinelFailover) dial() (net.Conn, error) {
func (d *sentinelFailover) Pool() pool {
d.poolOnce.Do(func() {
d.pool = newConnPool(newConnFunc(d.dial), d.opt)
d.opt.Dialer = d.dial
d.pool = newConnPool(d.opt.connPoolOptions())
})
return d.pool
}