diff --git a/internal/pool/conn_stack.go b/internal/pool/conn_stack.go index a26ab0ee..8b8f5057 100644 --- a/internal/pool/conn_stack.go +++ b/internal/pool/conn_stack.go @@ -32,13 +32,13 @@ func (s *connStack) ShiftStale(idleTimeout time.Duration) *Conn { select { case <-s.free: s.mu.Lock() - defer s.mu.Unlock() - if cn := s.cns[0]; cn.IsStale(idleTimeout) { copy(s.cns, s.cns[1:]) s.cns = s.cns[:len(s.cns)-1] + s.mu.Unlock() return cn } + s.mu.Unlock() s.free <- struct{}{} return nil diff --git a/internal/pool/main_test.go b/internal/pool/main_test.go new file mode 100644 index 00000000..43afe3fa --- /dev/null +++ b/internal/pool/main_test.go @@ -0,0 +1,35 @@ +package pool_test + +import ( + "net" + "sync" + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestGinkgoSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "pool") +} + +func perform(n int, cbs ...func(int)) { + var wg sync.WaitGroup + for _, cb := range cbs { + for i := 0; i < n; i++ { + wg.Add(1) + go func(cb func(int), i int) { + defer GinkgoRecover() + defer wg.Done() + + cb(i) + }(cb, i) + } + } + wg.Wait() +} + +func dummyDialer() (net.Conn, error) { + return &net.TCPConn{}, nil +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 932146ea..c3281c3d 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -12,7 +12,7 @@ import ( "gopkg.in/bsm/ratelimit.v1" ) -var Logger = log.New(os.Stderr, "pg: ", log.LstdFlags) +var Logger = log.New(os.Stderr, "redis: ", log.LstdFlags) var ( ErrClosed = errors.New("redis: client is closed") @@ -108,9 +108,9 @@ func (p *ConnPool) First() *Conn { } // wait waits for free non-idle connection. It returns nil on timeout. -func (p *ConnPool) wait() *Conn { +func (p *ConnPool) wait(timeout time.Duration) *Conn { for { - cn := p.freeConns.PopWithTimeout(p.poolTimeout) + cn := p.freeConns.PopWithTimeout(timeout) if cn != nil && cn.IsStale(p.idleTimeout) { var err error cn, err = p.replace(cn) @@ -175,7 +175,7 @@ func (p *ConnPool) Get() (*Conn, error) { // Otherwise, wait for the available connection. atomic.AddUint32(&p.stats.Waits, 1) - if cn := p.wait(); cn != nil { + if cn := p.wait(p.poolTimeout); cn != nil { return cn, nil } @@ -270,8 +270,8 @@ func (p *ConnPool) Close() (retErr error) { } // Wait for app to free connections, but don't close them immediately. - for i := 0; i < p.Len(); i++ { - if cn := p.wait(); cn == nil { + for i := 0; i < p.Len()-p.FreeLen(); i++ { + if cn := p.wait(3 * time.Second); cn == nil { break } } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 1c591924..d7a29883 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -12,19 +12,93 @@ import ( "gopkg.in/redis.v3/internal/pool" ) -func TestGinkgoSuite(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "pool") -} +var _ = Describe("ConnPool", func() { + var connPool *pool.ConnPool + + BeforeEach(func() { + pool.SetIdleCheckFrequency(time.Second) + connPool = pool.NewConnPool(dummyDialer, 10, time.Hour, time.Second) + }) + + AfterEach(func() { + connPool.Close() + }) + + It("rate limits dial", func() { + var rateErr error + for i := 0; i < 1000; i++ { + cn, err := connPool.Get() + if err != nil { + rateErr = err + break + } + + _ = connPool.Replace(cn, errors.New("test")) + } + + Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`)) + }) + + It("should unblock client when conn is removed", func() { + // Reserve one connection. + cn, err := connPool.Get() + Expect(err).NotTo(HaveOccurred()) + + // Reserve all other connections. + var cns []*pool.Conn + for i := 0; i < 9; i++ { + cn, err := connPool.Get() + Expect(err).NotTo(HaveOccurred()) + cns = append(cns, cn) + } + + started := make(chan bool, 1) + done := make(chan bool, 1) + go func() { + defer GinkgoRecover() + + started <- true + _, err := connPool.Get() + Expect(err).NotTo(HaveOccurred()) + done <- true + + err = connPool.Put(cn) + Expect(err).NotTo(HaveOccurred()) + }() + <-started + + // Check that Get is blocked. + select { + case <-done: + Fail("Get is not blocked") + default: + // ok + } + + err = connPool.Replace(cn, errors.New("test")) + Expect(err).NotTo(HaveOccurred()) + + // Check that Ping is unblocked. + select { + case <-done: + // ok + case <-time.After(time.Second): + Fail("Get is not unblocked") + } + + for _, cn := range cns { + err = connPool.Put(cn) + Expect(err).NotTo(HaveOccurred()) + } + }) +}) var _ = Describe("conns reapser", func() { var connPool *pool.ConnPool BeforeEach(func() { - dial := func() (net.Conn, error) { - return &net.TCPConn{}, nil - } - connPool = pool.NewConnPool(dial, 10, 0, time.Minute) + pool.SetIdleCheckFrequency(time.Hour) + connPool = pool.NewConnPool(dummyDialer, 10, 0, time.Minute) // add stale connections for i := 0; i < 3; i++ { @@ -49,6 +123,10 @@ var _ = Describe("conns reapser", func() { Expect(n).To(Equal(3)) }) + AfterEach(func() { + connPool.Close() + }) + It("reaps stale connections", func() { Expect(connPool.Len()).To(Equal(3)) Expect(connPool.FreeLen()).To(Equal(3)) @@ -92,3 +170,47 @@ var _ = Describe("conns reapser", func() { } }) }) + +var _ = Describe("race", func() { + var connPool *pool.ConnPool + + var C, N = 10, 1000 + if testing.Short() { + C = 4 + N = 100 + } + + BeforeEach(func() { + pool.SetIdleCheckFrequency(time.Second) + connPool = pool.NewConnPool(dummyDialer, 10, time.Second, time.Second) + }) + + AfterEach(func() { + connPool.Close() + }) + + It("does not happend", func() { + perform(C, func(id int) { + for i := 0; i < N; i++ { + cn, err := connPool.Get() + if err == nil { + connPool.Put(cn) + } + } + }, func(id int) { + for i := 0; i < N; i++ { + cn, err := connPool.Get() + if err == nil { + connPool.Replace(cn, errors.New("test")) + } + } + }, func(id int) { + for i := 0; i < N; i++ { + cn, err := connPool.Get() + if err == nil { + connPool.Remove(cn, errors.New("test")) + } + } + }) + }) +}) diff --git a/pool_test.go b/pool_test.go index bf1ae4aa..ec5730c8 100644 --- a/pool_test.go +++ b/pool_test.go @@ -1,9 +1,6 @@ package redis_test import ( - "errors" - "time" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -131,65 +128,4 @@ var _ = Describe("pool", func() { Expect(stats.Waits).To(Equal(uint32(0))) Expect(stats.Timeouts).To(Equal(uint32(0))) }) - - It("should unblock client when connection is removed", func() { - pool := client.Pool() - - // Reserve one connection. - cn, err := pool.Get() - Expect(err).NotTo(HaveOccurred()) - - // Reserve the rest of connections. - for i := 0; i < 9; i++ { - _, err := pool.Get() - Expect(err).NotTo(HaveOccurred()) - } - - var ping *redis.StatusCmd - started := make(chan bool, 1) - done := make(chan bool, 1) - go func() { - started <- true - ping = client.Ping() - done <- true - }() - <-started - - // Check that Ping is blocked. - select { - case <-done: - panic("Ping is not blocked") - default: - // ok - } - - err = pool.Replace(cn, errors.New("test")) - Expect(err).NotTo(HaveOccurred()) - - // Check that Ping is unblocked. - select { - case <-done: - // ok - case <-time.After(time.Second): - panic("Ping is not unblocked") - } - Expect(ping.Err()).NotTo(HaveOccurred()) - }) - - It("should rate limit dial", func() { - pool := client.Pool() - - var rateErr error - for i := 0; i < 1000; i++ { - cn, err := pool.Get() - if err != nil { - rateErr = err - break - } - - _ = pool.Replace(cn, errors.New("test")) - } - - Expect(rateErr).To(MatchError(`redis: you open connections too fast (last_error="test")`)) - }) })