diff --git a/example_test.go b/example_test.go index bb98fdd..c93e0b7 100644 --- a/example_test.go +++ b/example_test.go @@ -254,7 +254,7 @@ func ExamplePubSub_Receive() { for i := 0; i < 2; i++ { // ReceiveTimeout is a low level API. Use ReceiveMessage instead. - msgi, err := pubsub.ReceiveTimeout(500 * time.Millisecond) + msgi, err := pubsub.ReceiveTimeout(time.Second) if err != nil { panic(err) } diff --git a/internal/pool/conn_stack.go b/internal/pool/conn_stack.go index 047bc85..a26ab0e 100644 --- a/internal/pool/conn_stack.go +++ b/internal/pool/conn_stack.go @@ -28,17 +28,19 @@ func (s *connStack) Push(cn *Conn) { s.free <- struct{}{} } -func (s *connStack) ShiftStale(timeout time.Duration) *Conn { +func (s *connStack) ShiftStale(idleTimeout time.Duration) *Conn { select { case <-s.free: s.mu.Lock() defer s.mu.Unlock() - if cn := s.cns[0]; cn.IsStale(timeout) { + if cn := s.cns[0]; cn.IsStale(idleTimeout) { copy(s.cns, s.cns[1:]) s.cns = s.cns[:len(s.cns)-1] return cn } + + s.free <- struct{}{} return nil default: return nil diff --git a/internal/pool/pool.go b/internal/pool/pool.go index d5c91db..243ebea 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -83,6 +83,15 @@ func (p *ConnPool) isIdle(cn *Conn) bool { return p.idleTimeout > 0 && time.Since(cn.UsedAt) > p.idleTimeout } +func (p *ConnPool) Add(cn *Conn) bool { + if !p.conns.Reserve() { + return false + } + p.conns.Add(cn) + p.Put(cn) + return true +} + // First returns first non-idle connection from the pool or nil if // there are no connections. func (p *ConnPool) First() *Conn { @@ -216,6 +225,12 @@ func (p *ConnPool) Replace(cn *Conn, reason error) error { return nil } +func (p *ConnPool) Remove(cn *Conn, reason error) error { + p.storeLastErr(reason.Error()) + _ = cn.Close() + return p.conns.Remove(cn) +} + // Len returns total number of connections. func (p *ConnPool) Len() int { return p.conns.Len() @@ -253,6 +268,20 @@ func (p *ConnPool) Close() (retErr error) { return retErr } +func (p *ConnPool) ReapStaleConns() (n int, err error) { + for { + cn := p.freeConns.ShiftStale(p.idleTimeout) + if cn == nil { + break + } + if err = p.Remove(cn, errors.New("connection is stale")); err != nil { + return + } + n++ + } + return +} + func (p *ConnPool) reaper() { ticker := time.NewTicker(time.Minute) defer ticker.Stop() @@ -261,13 +290,11 @@ func (p *ConnPool) reaper() { if p.closed() { break } - - for { - cn := p.freeConns.ShiftStale(p.idleTimeout) - if cn == nil { - break - } - _ = p.conns.Remove(cn) + n, err := p.ReapStaleConns() + if err != nil { + Logger.Printf("ReapStaleConns failed: %s", err) + } else if n > 0 { + Logger.Printf("removed %d stale connections", n) } } } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go new file mode 100644 index 0000000..07d3a52 --- /dev/null +++ b/internal/pool/pool_test.go @@ -0,0 +1,93 @@ +package pool_test + +import ( + "errors" + "net" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "gopkg.in/redis.v3/internal/pool" +) + +func TestGinkgoSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "pool") +} + +var _ = Describe("conns reapser", func() { + var connPool *pool.ConnPool + + BeforeEach(func() { + dial := func() (net.Conn, error) { + return &net.TCPConn{}, nil + } + connPool = pool.NewConnPool(dial, 10, 0, time.Minute) + + // add stale connections + for i := 0; i < 3; i++ { + cn := pool.NewConn(&net.TCPConn{}) + cn.UsedAt = time.Now().Add(-2 * time.Minute) + Expect(connPool.Add(cn)).To(BeTrue()) + } + + // add fresh connections + for i := 0; i < 3; i++ { + cn := pool.NewConn(&net.TCPConn{}) + Expect(connPool.Add(cn)).To(BeTrue()) + } + + Expect(connPool.Len()).To(Equal(6)) + Expect(connPool.FreeLen()).To(Equal(6)) + + n, err := connPool.ReapStaleConns() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(3)) + }) + + It("reaps stale connections", func() { + Expect(connPool.Len()).To(Equal(3)) + Expect(connPool.FreeLen()).To(Equal(3)) + }) + + It("pool is functional", func() { + for j := 0; j < 3; j++ { + var freeCns []*pool.Conn + for i := 0; i < 3; i++ { + cn := connPool.First() + Expect(cn).NotTo(BeNil()) + freeCns = append(freeCns, cn) + } + + Expect(connPool.Len()).To(Equal(3)) + Expect(connPool.FreeLen()).To(Equal(0)) + + cn := connPool.First() + Expect(cn).To(BeNil()) + + cn, isNew, err := connPool.Get() + Expect(err).NotTo(HaveOccurred()) + Expect(isNew).To(BeTrue()) + Expect(cn).NotTo(BeNil()) + + Expect(connPool.Len()).To(Equal(4)) + Expect(connPool.FreeLen()).To(Equal(0)) + + err = connPool.Remove(cn, errors.New("test")) + Expect(err).NotTo(HaveOccurred()) + + Expect(connPool.Len()).To(Equal(3)) + Expect(connPool.FreeLen()).To(Equal(0)) + + for _, cn := range freeCns { + err := connPool.Put(cn) + Expect(err).NotTo(HaveOccurred()) + } + + Expect(connPool.Len()).To(Equal(3)) + Expect(connPool.FreeLen()).To(Equal(3)) + } + }) +})