diff --git a/internal/pool/pool.go b/internal/pool/pool.go index da1e381f..16e508c0 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -185,7 +185,7 @@ func (p *ConnPool) Get() (*Conn, error) { if !cn.IsStale(p.idleTimeout) { return cn, nil } - _ = cn.Close() + _ = p.closeConn(cn, errConnStale) } newcn, err := p.NewConn() @@ -196,7 +196,7 @@ func (p *ConnPool) Get() (*Conn, error) { p.connsMu.Lock() if cn != nil { - p.remove(cn, errConnStale) + p.removeConn(cn) } p.conns = append(p.conns, newcn) p.connsMu.Unlock() @@ -218,16 +218,20 @@ func (p *ConnPool) Put(cn *Conn) error { } func (p *ConnPool) Remove(cn *Conn, reason error) error { - _ = cn.Close() - p.connsMu.Lock() p.remove(cn, reason) - p.connsMu.Unlock() p.queue <- struct{}{} return nil } func (p *ConnPool) remove(cn *Conn, reason error) { - p.storeLastErr(reason.Error()) + _ = p.closeConn(cn, reason) + + p.connsMu.Lock() + p.removeConn(cn) + p.connsMu.Unlock() +} + +func (p *ConnPool) removeConn(cn *Conn) { for i, c := range p.conns { if c == cn { p.conns = append(p.conns[:i], p.conns[i+1:]...) @@ -272,13 +276,12 @@ func (p *ConnPool) Close() (retErr error) { } p.connsMu.Lock() - // Close all connections. for _, cn := range p.conns { if cn == nil { continue } - if err := p.closeConn(cn); err != nil && retErr == nil { + if err := p.closeConn(cn, ErrClosed); err != nil && retErr == nil { retErr = err } } @@ -292,41 +295,48 @@ func (p *ConnPool) Close() (retErr error) { return retErr } -func (p *ConnPool) closeConn(cn *Conn) error { +func (p *ConnPool) closeConn(cn *Conn, reason error) error { + p.storeLastErr(reason.Error()) if p.OnClose != nil { _ = p.OnClose(cn) } return cn.Close() } -func (p *ConnPool) ReapStaleConns() (n int, err error) { - <-p.queue - p.freeConnsMu.Lock() - +func (p *ConnPool) reapStaleConn() bool { if len(p.freeConns) == 0 { + return false + } + + cn := p.freeConns[0] + if !cn.IsStale(p.idleTimeout) { + return false + } + + p.remove(cn, errConnStale) + p.freeConns = append(p.freeConns[:0], p.freeConns[1:]...) + + return true +} + +func (p *ConnPool) ReapStaleConns() (int, error) { + var n int + for { + <-p.queue + p.freeConnsMu.Lock() + + reaped := p.reapStaleConn() + p.freeConnsMu.Unlock() p.queue <- struct{}{} - return - } - var idx int - var cn *Conn - for idx, cn = range p.freeConns { - if !cn.IsStale(p.idleTimeout) { + if reaped { + n++ + } else { break } - p.connsMu.Lock() - p.remove(cn, errConnStale) - p.connsMu.Unlock() - n++ } - if idx > 0 { - p.freeConns = append(p.freeConns[:0], p.freeConns[idx:]...) - } - - p.freeConnsMu.Unlock() - p.queue <- struct{}{} - return + return n, nil } func (p *ConnPool) reaper(frequency time.Duration) { diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index ceec428f..2c8908fe 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -94,20 +94,31 @@ var _ = Describe("ConnPool", func() { }) var _ = Describe("conns reaper", func() { + const idleTimeout = time.Minute + var connPool *pool.ConnPool + var idleConns, closedConns []*pool.Conn BeforeEach(func() { connPool = pool.NewConnPool( - dummyDialer, 10, time.Second, time.Millisecond, time.Hour) + dummyDialer, 10, time.Second, idleTimeout, time.Hour) + + closedConns = nil + connPool.OnClose = func(cn *pool.Conn) error { + closedConns = append(closedConns, cn) + return nil + } var cns []*pool.Conn // add stale connections + idleConns = nil for i := 0; i < 3; i++ { cn, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) - cn.UsedAt = time.Now().Add(-2 * time.Minute) + cn.UsedAt = time.Now().Add(-2 * idleTimeout) cns = append(cns, cn) + idleConns = append(idleConns, cn) } // add fresh connections @@ -139,6 +150,17 @@ var _ = Describe("conns reaper", func() { Expect(connPool.FreeLen()).To(Equal(3)) }) + It("does not reap fresh connections", func() { + n, err := connPool.ReapStaleConns() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(0)) + }) + + It("stale connections are closed", func() { + Expect(closedConns).To(HaveLen(3)) + Expect(closedConns).To(ConsistOf(idleConns)) + }) + It("pool is functional", func() { for j := 0; j < 3; j++ { var freeCns []*pool.Conn