diff --git a/prometheus/promhttp/instrument_client_test.go b/prometheus/promhttp/instrument_client_test.go index 59c0a2f..50d64bd 100644 --- a/prometheus/promhttp/instrument_client_test.go +++ b/prometheus/promhttp/instrument_client_test.go @@ -14,15 +14,18 @@ package promhttp import ( + "context" + "fmt" "log" "net/http" + "net/http/httptest" "testing" "time" "github.com/prometheus/client_golang/prometheus" ) -func TestClientMiddlewareAPI(t *testing.T) { +func makeInstrumentedClient() (*http.Client, *prometheus.Registry) { client := http.DefaultClient client.Timeout = 1 * time.Second @@ -92,12 +95,100 @@ func TestClientMiddlewareAPI(t *testing.T) { ), ), ) + return client, reg +} - resp, err := client.Get("http://google.com") +func TestClientMiddlewareAPI(t *testing.T) { + client, reg := makeInstrumentedClient() + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + resp, err := client.Get(backend.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + mfs, err := reg.Gather() + if err != nil { + t.Fatal(err) + } + if want, got := 3, len(mfs); want != got { + t.Fatalf("unexpected number of metric families gathered, want %d, got %d", want, got) + } + for _, mf := range mfs { + if len(mf.Metric) == 0 { + t.Errorf("metric family %s must not be empty", mf.GetName()) + } + } +} + +func TestClientMiddlewareAPIWithRequestContext(t *testing.T) { + client, reg := makeInstrumentedClient() + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + req, err := http.NewRequest("GET", backend.URL, nil) if err != nil { t.Fatalf("%v", err) } + + // Set a context with a long timeout. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + req = req.WithContext(ctx) + + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } defer resp.Body.Close() + + mfs, err := reg.Gather() + if err != nil { + t.Fatal(err) + } + if want, got := 3, len(mfs); want != got { + t.Fatalf("unexpected number of metric families gathered, want %d, got %d", want, got) + } + for _, mf := range mfs { + if len(mf.Metric) == 0 { + t.Errorf("metric family %s must not be empty", mf.GetName()) + } + } +} + +func TestClientMiddlewareAPIWithRequestContextTimeout(t *testing.T) { + client, _ := makeInstrumentedClient() + + // Slow testserver responding in 100ms. + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + req, err := http.NewRequest("GET", backend.URL, nil) + if err != nil { + t.Fatalf("%v", err) + } + + // Set a context with a short timeout. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + req = req.WithContext(ctx) + + _, err = client.Do(req) + if err == nil { + t.Fatal("did not get timeout error") + } + if want, got := fmt.Sprintf("Get %s: context deadline exceeded", backend.URL), err.Error(); want != got { + t.Fatalf("want error %q, got %q", want, got) + } } func ExampleInstrumentRoundTripperDuration() {