diff --git a/internal/pool/export_test.go b/internal/pool/export_test.go index de7a644..75dd4ad 100644 --- a/internal/pool/export_test.go +++ b/internal/pool/export_test.go @@ -1,6 +1,8 @@ package pool -import "time" +import ( + "time" +) func (cn *Conn) SetCreatedAt(tm time.Time) { cn.createdAt = tm diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 91b55e4..44a4e77 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -121,9 +121,10 @@ func (p *ConnPool) checkMinIdleConns() { for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns { p.poolSize++ p.idleConnsLen++ + go func() { err := p.addIdleConn() - if err != nil { + if err != nil && err != ErrClosed { p.connsMu.Lock() p.poolSize-- p.idleConnsLen-- @@ -140,9 +141,16 @@ func (p *ConnPool) addIdleConn() error { } p.connsMu.Lock() + defer p.connsMu.Unlock() + + // It is not allowed to add new connections to the closed connection pool. + if p.closed() { + _ = cn.Close() + return ErrClosed + } + p.conns = append(p.conns, cn) p.idleConns = append(p.idleConns, cn) - p.connsMu.Unlock() return nil } @@ -157,6 +165,14 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { } p.connsMu.Lock() + defer p.connsMu.Unlock() + + // It is not allowed to add new connections to the closed connection pool. + if p.closed() { + _ = cn.Close() + return nil, ErrClosed + } + p.conns = append(p.conns, cn) if pooled { // If pool is full remove the cn on next Put. @@ -166,7 +182,6 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { p.poolSize++ } } - p.connsMu.Unlock() return cn, nil } @@ -237,9 +252,13 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { for { p.connsMu.Lock() - cn := p.popIdle() + cn, err := p.popIdle() p.connsMu.Unlock() + if err != nil { + return nil, err + } + if cn == nil { break } @@ -308,10 +327,13 @@ func (p *ConnPool) freeTurn() { <-p.queue } -func (p *ConnPool) popIdle() *Conn { +func (p *ConnPool) popIdle() (*Conn, error) { + if p.closed() { + return nil, ErrClosed + } n := len(p.idleConns) if n == 0 { - return nil + return nil, nil } var cn *Conn @@ -326,7 +348,7 @@ func (p *ConnPool) popIdle() *Conn { } p.idleConnsLen-- p.checkMinIdleConns() - return cn + return cn, nil } func (p *ConnPool) Put(ctx context.Context, cn *Conn) { diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 795aef3..fcb1288 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -2,6 +2,7 @@ package pool_test import ( "context" + "net" "sync" "testing" "time" @@ -30,6 +31,43 @@ var _ = Describe("ConnPool", func() { connPool.Close() }) + It("should safe close", func() { + const minIdleConns = 10 + + var ( + wg sync.WaitGroup + closedChan = make(chan struct{}) + ) + wg.Add(minIdleConns) + connPool = pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + wg.Done() + <-closedChan + return &net.TCPConn{}, nil + }, + PoolSize: 10, + PoolTimeout: time.Hour, + IdleTimeout: time.Millisecond, + IdleCheckFrequency: time.Millisecond, + MinIdleConns: minIdleConns, + }) + wg.Wait() + Expect(connPool.Close()).NotTo(HaveOccurred()) + close(closedChan) + + // We wait for 1 second and believe that checkMinIdleConns has been executed. + time.Sleep(time.Second) + + Expect(connPool.Stats()).To(Equal(&pool.Stats{ + Hits: 0, + Misses: 0, + Timeouts: 0, + TotalConns: 0, + IdleConns: 0, + StaleConns: 0, + })) + }) + It("should unblock client when conn is removed", func() { // Reserve one connection. cn, err := connPool.Get(ctx) diff --git a/internal/proto/reader.go b/internal/proto/reader.go index 10d1b42..0e6ca77 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -19,7 +19,7 @@ const ( //------------------------------------------------------------------------------ -const Nil = RedisError("redis: nil") +const Nil = RedisError("redis: nil") // nolint:errname type RedisError string