From 393adc926170ffaf9eded50fd90a61b0b2b169f6 Mon Sep 17 00:00:00 2001 From: Joe Elliott Date: Tue, 10 Dec 2019 16:53:38 -0500 Subject: [PATCH] Refactor ~worked. All tests passing except one Signed-off-by: Joe Elliott --- api/client_test.go | 72 ------------------ api/prometheus/v1/api.go | 68 +++++++++++------ api/prometheus/v1/api_test.go | 133 ++++++++++++++++++++++++++++++++-- 3 files changed, 170 insertions(+), 103 deletions(-) diff --git a/api/client_test.go b/api/client_test.go index b3c95ee..47094fc 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -14,10 +14,7 @@ package api import ( - "context" - "encoding/json" "net/http" - "net/http/httptest" "net/url" "testing" ) @@ -114,72 +111,3 @@ func TestClientURL(t *testing.T) { } } } - -func TestDoGetFallback(t *testing.T) { - v := url.Values{"a": []string{"1", "2"}} - - type testResponse struct { - Values string - Method string - } - - // Start a local HTTP server. - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - req.ParseForm() - r := &testResponse{ - Values: req.Form.Encode(), - Method: req.Method, - } - - body, _ := json.Marshal(r) - - if req.Method == http.MethodPost { - if req.URL.Path == "/blockPost" { - http.Error(w, string(body), http.StatusMethodNotAllowed) - return - } - } - - w.Write(body) - })) - // Close the server when test finishes. - defer server.Close() - - u, err := url.Parse(server.URL) - if err != nil { - t.Fatal(err) - } - client := &httpClient{client: *(server.Client())} - - // Do a post, and ensure that the post succeeds. - _, b, _, err := DoGetFallback(client, context.TODO(), u, v) - if err != nil { - t.Fatalf("Error doing local request: %v", err) - } - resp := &testResponse{} - if err := json.Unmarshal(b, resp); err != nil { - t.Fatal(err) - } - if resp.Method != http.MethodPost { - t.Fatalf("Mismatch method") - } - if resp.Values != v.Encode() { - t.Fatalf("Mismatch in values") - } - - // Do a fallbcak to a get. - u.Path = "/blockPost" - _, b, _, err = DoGetFallback(client, context.TODO(), u, v) - if err != nil { - t.Fatalf("Error doing local request: %v", err) - } - if err := json.Unmarshal(b, resp); err != nil { - t.Fatal(err) - } - if resp.Method != http.MethodGet { - t.Fatalf("Mismatch method") - } - if resp.Values != v.Encode() { - t.Fatalf("Mismatch in values") - } -} diff --git a/api/prometheus/v1/api.go b/api/prometheus/v1/api.go index 154e4a8..9d1345d 100644 --- a/api/prometheus/v1/api.go +++ b/api/prometheus/v1/api.go @@ -517,11 +517,15 @@ func (qr *queryResult) UnmarshalJSON(b []byte) error { // // It is safe to use the returned API from multiple goroutines. func NewAPI(c api.Client) API { - return &httpAPI{client: c} + return &httpAPI{ + client: &apiClientImpl{ + client: c, + }, + } } type httpAPI struct { - client api.Client + client apiClient } func (h *httpAPI) Alerts(ctx context.Context) (AlertsResult, error) { @@ -532,7 +536,7 @@ func (h *httpAPI) Alerts(ctx context.Context) (AlertsResult, error) { return AlertsResult{}, err } - _, body, _, err := h.Do(ctx, req) + _, body, _, err := h.client.Do(ctx, req) if err != nil { return AlertsResult{}, err } @@ -549,7 +553,7 @@ func (h *httpAPI) AlertManagers(ctx context.Context) (AlertManagersResult, error return AlertManagersResult{}, err } - _, body, _, err := h.Do(ctx, req) + _, body, _, err := h.client.Do(ctx, req) if err != nil { return AlertManagersResult{}, err } @@ -566,7 +570,7 @@ func (h *httpAPI) CleanTombstones(ctx context.Context) error { return err } - _, _, _, err = h.Do(ctx, req) + _, _, _, err = h.client.Do(ctx, req) return err } @@ -578,7 +582,7 @@ func (h *httpAPI) Config(ctx context.Context) (ConfigResult, error) { return ConfigResult{}, err } - _, body, _, err := h.Do(ctx, req) + _, body, _, err := h.client.Do(ctx, req) if err != nil { return ConfigResult{}, err } @@ -605,7 +609,7 @@ func (h *httpAPI) DeleteSeries(ctx context.Context, matches []string, startTime return err } - _, _, _, err = h.Do(ctx, req) + _, _, _, err = h.client.Do(ctx, req) return err } @@ -617,7 +621,7 @@ func (h *httpAPI) Flags(ctx context.Context) (FlagsResult, error) { return FlagsResult{}, err } - _, body, _, err := h.Do(ctx, req) + _, body, _, err := h.client.Do(ctx, req) if err != nil { return FlagsResult{}, err } @@ -632,7 +636,7 @@ func (h *httpAPI) LabelNames(ctx context.Context) ([]string, api.Warnings, error if err != nil { return nil, nil, err } - _, body, w, err := h.Do(ctx, req) + _, body, w, err := h.client.Do(ctx, req) if err != nil { return nil, w, err } @@ -646,7 +650,7 @@ func (h *httpAPI) LabelValues(ctx context.Context, label string) (model.LabelVal if err != nil { return nil, nil, err } - _, body, w, err := h.Do(ctx, req) + _, body, w, err := h.client.Do(ctx, req) if err != nil { return nil, w, err } @@ -663,7 +667,7 @@ func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model. q.Set("time", formatTime(ts)) } - _, body, warnings, err := h.DoGetFallback(ctx, u, q) + _, body, warnings, err := h.client.DoGetFallback(ctx, u, q) if err != nil { return nil, warnings, err } @@ -681,7 +685,7 @@ func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model. q.Set("end", formatTime(r.End)) q.Set("step", strconv.FormatFloat(r.Step.Seconds(), 'f', -1, 64)) - _, body, warnings, err := h.DoGetFallback(ctx, u, q) + _, body, warnings, err := h.client.DoGetFallback(ctx, u, q) if err != nil { return nil, warnings, err } @@ -709,7 +713,7 @@ func (h *httpAPI) Series(ctx context.Context, matches []string, startTime time.T return nil, nil, err } - _, body, warnings, err := h.Do(ctx, req) + _, body, warnings, err := h.client.Do(ctx, req) if err != nil { return nil, warnings, err } @@ -731,7 +735,7 @@ func (h *httpAPI) Snapshot(ctx context.Context, skipHead bool) (SnapshotResult, return SnapshotResult{}, err } - _, body, _, err := h.Do(ctx, req) + _, body, _, err := h.client.Do(ctx, req) if err != nil { return SnapshotResult{}, err } @@ -748,7 +752,7 @@ func (h *httpAPI) Rules(ctx context.Context) (RulesResult, error) { return RulesResult{}, err } - _, body, _, err := h.Do(ctx, req) + _, body, _, err := h.client.Do(ctx, req) if err != nil { return RulesResult{}, err } @@ -765,7 +769,7 @@ func (h *httpAPI) Targets(ctx context.Context) (TargetsResult, error) { return TargetsResult{}, err } - _, body, _, err := h.Do(ctx, req) + _, body, _, err := h.client.Do(ctx, req) if err != nil { return TargetsResult{}, err } @@ -789,7 +793,7 @@ func (h *httpAPI) TargetsMetadata(ctx context.Context, matchTarget string, metri return nil, err } - _, body, _, err := h.Do(ctx, req) + _, body, _, err := h.client.Do(ctx, req) if err != nil { return nil, err } @@ -798,6 +802,18 @@ func (h *httpAPI) TargetsMetadata(ctx context.Context, matchTarget string, metri return res, json.Unmarshal(body, &res) } +// apiClient wraps a regular client and processes successful API responses. +// Successful also includes responses that errored at the API level. +type apiClient interface { + URL(ep string, args map[string]string) *url.URL + Do(context.Context, *http.Request) (*http.Response, []byte, api.Warnings, error) + DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, api.Warnings, error) +} + +type apiClientImpl struct { + client api.Client +} + type apiResponse struct { Status string `json:"status"` Data json.RawMessage `json:"data"` @@ -821,17 +837,21 @@ func errorTypeAndMsgFor(resp *http.Response) (ErrorType, string) { return ErrBadResponse, fmt.Sprintf("bad response code %d", resp.StatusCode) } -func (h *httpAPI) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { - resp, body, warnings, err := h.Do(ctx, req) +func (h *apiClientImpl) URL(ep string, args map[string]string) *url.URL { + return h.client.URL(ep, args) +} + +func (h *apiClientImpl) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { + resp, body, err := h.client.Do(ctx, req) if err != nil { - return resp, body, warnings, err + return resp, body, nil, err } code := resp.StatusCode if code/100 != 2 && !apiError(code) { errorType, errorMsg := errorTypeAndMsgFor(resp) - return resp, body, warnings, &Error{ + return resp, body, nil, &Error{ Type: errorType, Msg: errorMsg, Detail: string(body), @@ -842,7 +862,7 @@ func (h *httpAPI) Do(ctx context.Context, req *http.Request) (*http.Response, [] if http.StatusNoContent != code { if jsonErr := json.Unmarshal(body, &result); jsonErr != nil { - return resp, body, warnings, &Error{ + return resp, body, nil, &Error{ Type: ErrBadResponse, Msg: jsonErr.Error(), } @@ -863,12 +883,12 @@ func (h *httpAPI) Do(ctx context.Context, req *http.Request) (*http.Response, [] } } - return resp, []byte(result.Data), warnings, err + return resp, []byte(result.Data), result.Warnings, err } // DoGetFallback will attempt to do the request as-is, and on a 405 it will fallback to a GET request. -func (h *httpAPI) DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, api.Warnings, error) { +func (h *apiClientImpl) DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, api.Warnings, error) { req, err := http.NewRequest(http.MethodPost, u.String(), strings.NewReader(args.Encode())) if err != nil { return nil, nil, nil, err diff --git a/api/prometheus/v1/api_test.go b/api/prometheus/v1/api_test.go index 0572f86..374d6cd 100644 --- a/api/prometheus/v1/api_test.go +++ b/api/prometheus/v1/api_test.go @@ -17,8 +17,10 @@ import ( "context" "errors" "fmt" + "io/ioutil" "math" "net/http" + "net/http/httptest" "net/url" "reflect" "strings" @@ -92,14 +94,23 @@ func (c *apiTestClient) Do(ctx context.Context, req *http.Request) (*http.Respon return resp, b, test.inWarnings, test.inErr } +func (c *apiTestClient) DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, api.Warnings, error) { + req, err := http.NewRequest(http.MethodPost, u.String(), strings.NewReader(args.Encode())) + if err != nil { + return nil, nil, nil, err + } + return c.Do(ctx, req) +} + func TestAPIs(t *testing.T) { testTime := time.Now() - client := &apiTestClient{T: t} - + tc := &apiTestClient{ + T: t, + } promAPI := &httpAPI{ - client: client, + client: tc, } doAlertManagers := func() func() (interface{}, api.Warnings, error) { @@ -855,7 +866,7 @@ func TestAPIs(t *testing.T) { for i, test := range tests { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - client.curTest = test + tc.curTest = test res, warnings, err := test.do() @@ -907,7 +918,7 @@ func (c *testClient) URL(ep string, args map[string]string) *url.URL { return nil } -func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { +func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { if ctx == nil { c.Fatalf("context was not passed down") } @@ -934,7 +945,7 @@ func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response, StatusCode: test.code, } - return resp, b, test.expectedWarnings, nil + return resp, b, nil } func TestAPIClientDo(t *testing.T) { @@ -1065,7 +1076,9 @@ func TestAPIClientDo(t *testing.T) { ch: make(chan apiClientTest, 1), req: &http.Request{}, } - client := &apiClient{tc} + client := &apiClientImpl{ + client: tc, + } for i, test := range tests { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { @@ -1209,3 +1222,109 @@ func TestSamplesJsonSerialization(t *testing.T) { }) } } + +type httpTestClient struct { + client http.Client +} + +func (c *httpTestClient) URL(ep string, args map[string]string) *url.URL { + return nil +} + +func (c *httpTestClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { + resp, err := c.client.Do(req) + if err != nil { + return nil, nil, err + } + + var body []byte + done := make(chan struct{}) + go func() { + body, err = ioutil.ReadAll(resp.Body) + close(done) + }() + + select { + case <-ctx.Done(): + <-done + err = resp.Body.Close() + if err == nil { + err = ctx.Err() + } + case <-done: + } + + return resp, body, err +} + +func TestDoGetFallback(t *testing.T) { + v := url.Values{"a": []string{"1", "2"}} + + type testResponse struct { + Values string + Method string + } + + // Start a local HTTP server. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req.ParseForm() + r := &testResponse{ + Values: req.Form.Encode(), + Method: req.Method, + } + + body, _ := json.Marshal(r) + + if req.Method == http.MethodPost { + if req.URL.Path == "/blockPost" { + http.Error(w, string(body), http.StatusMethodNotAllowed) + return + } + } + + w.Write(body) + })) + // Close the server when test finishes. + defer server.Close() + + u, err := url.Parse(server.URL) + if err != nil { + t.Fatal(err) + } + client := &httpTestClient{client: *(server.Client())} + api := &apiClientImpl{ + client: client, + } + + // Do a post, and ensure that the post succeeds. + _, b, _, err := api.DoGetFallback(context.TODO(), u, v) + if err != nil { + t.Fatalf("Error doing local request: %v", err) + } + resp := &testResponse{} + if err := json.Unmarshal(b, resp); err != nil { + t.Fatal(err) + } + if resp.Method != http.MethodPost { + t.Fatalf("Mismatch method") + } + if resp.Values != v.Encode() { + t.Fatalf("Mismatch in values") + } + + // Do a fallbcak to a get. + u.Path = "/blockPost" + _, b, _, err = api.DoGetFallback(context.TODO(), u, v) + if err != nil { + t.Fatalf("Error doing local request: %v", err) + } + if err := json.Unmarshal(b, resp); err != nil { + t.Fatal(err) + } + if resp.Method != http.MethodGet { + t.Fatalf("Mismatch method") + } + if resp.Values != v.Encode() { + t.Fatalf("Mismatch in values") + } +}