diff --git a/prometheus/go_collector_test.go b/prometheus/go_collector_test.go index f93dcdc..7ab3968 100644 --- a/prometheus/go_collector_test.go +++ b/prometheus/go_collector_test.go @@ -21,28 +21,40 @@ import ( dto "github.com/prometheus/client_model/go" ) -func TestGoCollector(t *testing.T) { +func TestGoCollectorGoroutines(t *testing.T) { var ( - c = NewGoCollector() - ch = make(chan Metric) - waitc = make(chan struct{}) - closec = make(chan struct{}) - old = -1 + c = NewGoCollector() + metricCh = make(chan Metric) + waitCh = make(chan struct{}) + endGoroutineCh = make(chan struct{}) + endCollectionCh = make(chan struct{}) + old = -1 ) - defer close(closec) + defer func() { + close(endGoroutineCh) + // Drain the collect channel to prevent goroutine leak. + for { + select { + case <-metricCh: + case <-endCollectionCh: + return + } + } + }() go func() { - c.Collect(ch) + c.Collect(metricCh) go func(c <-chan struct{}) { <-c - }(closec) - <-waitc - c.Collect(ch) + }(endGoroutineCh) + <-waitCh + c.Collect(metricCh) + close(endCollectionCh) }() for { select { - case m := <-ch: + case m := <-metricCh: // m can be Gauge or Counter, // currently just test the go_goroutines Gauge // and ignore others. @@ -57,7 +69,7 @@ func TestGoCollector(t *testing.T) { if old == -1 { old = int(pb.GetGauge().GetValue()) - close(waitc) + close(waitCh) continue } @@ -65,43 +77,47 @@ func TestGoCollector(t *testing.T) { // TODO: This is flaky in highly concurrent situations. t.Errorf("want 1 new goroutine, got %d", diff) } - - // GoCollector performs three sends per call. - // On line 27 we need to receive three more sends - // to shut down cleanly. - <-ch - <-ch - <-ch - return case <-time.After(1 * time.Second): t.Fatalf("expected collect timed out") } + break } } -func TestGCCollector(t *testing.T) { +func TestGoCollectorGC(t *testing.T) { var ( - c = NewGoCollector() - ch = make(chan Metric) - waitc = make(chan struct{}) - closec = make(chan struct{}) - oldGC uint64 - oldPause float64 + c = NewGoCollector() + metricCh = make(chan Metric) + waitCh = make(chan struct{}) + endCollectionCh = make(chan struct{}) + oldGC uint64 + oldPause float64 ) - defer close(closec) go func() { - c.Collect(ch) + c.Collect(metricCh) // force GC runtime.GC() - <-waitc - c.Collect(ch) + <-waitCh + c.Collect(metricCh) + close(endCollectionCh) + }() + + defer func() { + // Drain the collect channel to prevent goroutine leak. + for { + select { + case <-metricCh: + case <-endCollectionCh: + return + } + } }() first := true for { select { - case metric := <-ch: + case metric := <-metricCh: pb := &dto.Metric{} metric.Write(pb) if pb.GetSummary() == nil { @@ -119,7 +135,7 @@ func TestGCCollector(t *testing.T) { first = false oldGC = *pb.GetSummary().SampleCount oldPause = *pb.GetSummary().SampleSum - close(waitc) + close(waitCh) continue } if diff := *pb.GetSummary().SampleCount - oldGC; diff != 1 { @@ -128,9 +144,9 @@ func TestGCCollector(t *testing.T) { if diff := *pb.GetSummary().SampleSum - oldPause; diff <= 0 { t.Errorf("want moar pause, got %f", diff) } - return case <-time.After(1 * time.Second): t.Fatalf("expected collect timed out") } + break } }