Merge pull request #699 from joe-elliott/http-client-warnings

Return Prometheus Warnings
This commit is contained in:
Björn Rabenstein 2019-12-13 01:07:50 +01:00 committed by GitHub
commit e7776d2c54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 232 additions and 166 deletions

View File

@ -25,8 +25,6 @@ import (
"time" "time"
) )
type Warnings []string
// DefaultRoundTripper is used if no RoundTripper is set in Config. // DefaultRoundTripper is used if no RoundTripper is set in Config.
var DefaultRoundTripper http.RoundTripper = &http.Transport{ var DefaultRoundTripper http.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
@ -57,32 +55,7 @@ func (cfg *Config) roundTripper() http.RoundTripper {
// Client is the interface for an API client. // Client is the interface for an API client.
type Client interface { type Client interface {
URL(ep string, args map[string]string) *url.URL URL(ep string, args map[string]string) *url.URL
Do(context.Context, *http.Request) (*http.Response, []byte, Warnings, error) Do(context.Context, *http.Request) (*http.Response, []byte, error)
}
// DoGetFallback will attempt to do the request as-is, and on a 405 it will fallback to a GET request.
func DoGetFallback(c Client, ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, Warnings, error) {
req, err := http.NewRequest(http.MethodPost, u.String(), strings.NewReader(args.Encode()))
if err != nil {
return nil, nil, nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, body, warnings, err := c.Do(ctx, req)
if resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
u.RawQuery = args.Encode()
req, err = http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, nil, warnings, err
}
} else {
if err != nil {
return resp, body, warnings, err
}
return resp, body, warnings, nil
}
return c.Do(ctx, req)
} }
// NewClient returns a new Client. // NewClient returns a new Client.
@ -120,7 +93,7 @@ func (c *httpClient) URL(ep string, args map[string]string) *url.URL {
return &u return &u
} }
func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, Warnings, error) { func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) {
if ctx != nil { if ctx != nil {
req = req.WithContext(ctx) req = req.WithContext(ctx)
} }
@ -132,7 +105,7 @@ func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response,
}() }()
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
var body []byte var body []byte
@ -152,5 +125,5 @@ func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response,
case <-done: case <-done:
} }
return resp, body, nil, err return resp, body, err
} }

View File

@ -14,10 +14,7 @@
package api package api
import ( import (
"context"
"encoding/json"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"testing" "testing"
) )
@ -114,72 +111,3 @@ func TestClientURL(t *testing.T) {
} }
} }
} }
func TestDoGetFallback(t *testing.T) {
v := url.Values{"a": []string{"1", "2"}}
type testResponse struct {
Values string
Method string
}
// Start a local HTTP server.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.ParseForm()
r := &testResponse{
Values: req.Form.Encode(),
Method: req.Method,
}
body, _ := json.Marshal(r)
if req.Method == http.MethodPost {
if req.URL.Path == "/blockPost" {
http.Error(w, string(body), http.StatusMethodNotAllowed)
return
}
}
w.Write(body)
}))
// Close the server when test finishes.
defer server.Close()
u, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
client := &httpClient{client: *(server.Client())}
// Do a post, and ensure that the post succeeds.
_, b, _, err := DoGetFallback(client, context.TODO(), u, v)
if err != nil {
t.Fatalf("Error doing local request: %v", err)
}
resp := &testResponse{}
if err := json.Unmarshal(b, resp); err != nil {
t.Fatal(err)
}
if resp.Method != http.MethodPost {
t.Fatalf("Mismatch method")
}
if resp.Values != v.Encode() {
t.Fatalf("Mismatch in values")
}
// Do a fallbcak to a get.
u.Path = "/blockPost"
_, b, _, err = DoGetFallback(client, context.TODO(), u, v)
if err != nil {
t.Fatalf("Error doing local request: %v", err)
}
if err := json.Unmarshal(b, resp); err != nil {
t.Fatal(err)
}
if resp.Method != http.MethodGet {
t.Fatalf("Mismatch method")
}
if resp.Values != v.Encode() {
t.Fatalf("Mismatch in values")
}
}

View File

@ -21,7 +21,9 @@ import (
"fmt" "fmt"
"math" "math"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings"
"time" "time"
"unsafe" "unsafe"
@ -228,15 +230,15 @@ type API interface {
// Flags returns the flag values that Prometheus was launched with. // Flags returns the flag values that Prometheus was launched with.
Flags(ctx context.Context) (FlagsResult, error) Flags(ctx context.Context) (FlagsResult, error)
// LabelNames returns all the unique label names present in the block in sorted order. // LabelNames returns all the unique label names present in the block in sorted order.
LabelNames(ctx context.Context) ([]string, api.Warnings, error) LabelNames(ctx context.Context) ([]string, Warnings, error)
// LabelValues performs a query for the values of the given label. // LabelValues performs a query for the values of the given label.
LabelValues(ctx context.Context, label string) (model.LabelValues, api.Warnings, error) LabelValues(ctx context.Context, label string) (model.LabelValues, Warnings, error)
// Query performs a query for the given time. // Query performs a query for the given time.
Query(ctx context.Context, query string, ts time.Time) (model.Value, api.Warnings, error) Query(ctx context.Context, query string, ts time.Time) (model.Value, Warnings, error)
// QueryRange performs a query for the given range. // QueryRange performs a query for the given range.
QueryRange(ctx context.Context, query string, r Range) (model.Value, api.Warnings, error) QueryRange(ctx context.Context, query string, r Range) (model.Value, Warnings, error)
// Series finds series by label matchers. // Series finds series by label matchers.
Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, api.Warnings, error) Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, Warnings, error)
// Snapshot creates a snapshot of all current data into snapshots/<datetime>-<rand> // Snapshot creates a snapshot of all current data into snapshots/<datetime>-<rand>
// under the TSDB's data directory and returns the directory as response. // under the TSDB's data directory and returns the directory as response.
Snapshot(ctx context.Context, skipHead bool) (SnapshotResult, error) Snapshot(ctx context.Context, skipHead bool) (SnapshotResult, error)
@ -515,11 +517,15 @@ func (qr *queryResult) UnmarshalJSON(b []byte) error {
// //
// It is safe to use the returned API from multiple goroutines. // It is safe to use the returned API from multiple goroutines.
func NewAPI(c api.Client) API { func NewAPI(c api.Client) API {
return &httpAPI{client: apiClient{c}} return &httpAPI{
client: &apiClientImpl{
client: c,
},
}
} }
type httpAPI struct { type httpAPI struct {
client api.Client client apiClient
} }
func (h *httpAPI) Alerts(ctx context.Context) (AlertsResult, error) { func (h *httpAPI) Alerts(ctx context.Context) (AlertsResult, error) {
@ -624,7 +630,7 @@ func (h *httpAPI) Flags(ctx context.Context) (FlagsResult, error) {
return res, json.Unmarshal(body, &res) return res, json.Unmarshal(body, &res)
} }
func (h *httpAPI) LabelNames(ctx context.Context) ([]string, api.Warnings, error) { func (h *httpAPI) LabelNames(ctx context.Context) ([]string, Warnings, error) {
u := h.client.URL(epLabels, nil) u := h.client.URL(epLabels, nil)
req, err := http.NewRequest(http.MethodGet, u.String(), nil) req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil { if err != nil {
@ -638,7 +644,7 @@ func (h *httpAPI) LabelNames(ctx context.Context) ([]string, api.Warnings, error
return labelNames, w, json.Unmarshal(body, &labelNames) return labelNames, w, json.Unmarshal(body, &labelNames)
} }
func (h *httpAPI) LabelValues(ctx context.Context, label string) (model.LabelValues, api.Warnings, error) { func (h *httpAPI) LabelValues(ctx context.Context, label string) (model.LabelValues, Warnings, error) {
u := h.client.URL(epLabelValues, map[string]string{"name": label}) u := h.client.URL(epLabelValues, map[string]string{"name": label})
req, err := http.NewRequest(http.MethodGet, u.String(), nil) req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil { if err != nil {
@ -652,7 +658,7 @@ func (h *httpAPI) LabelValues(ctx context.Context, label string) (model.LabelVal
return labelValues, w, json.Unmarshal(body, &labelValues) return labelValues, w, json.Unmarshal(body, &labelValues)
} }
func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model.Value, api.Warnings, error) { func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model.Value, Warnings, error) {
u := h.client.URL(epQuery, nil) u := h.client.URL(epQuery, nil)
q := u.Query() q := u.Query()
@ -661,7 +667,7 @@ func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model.
q.Set("time", formatTime(ts)) q.Set("time", formatTime(ts))
} }
_, body, warnings, err := api.DoGetFallback(h.client, ctx, u, q) _, body, warnings, err := h.client.DoGetFallback(ctx, u, q)
if err != nil { if err != nil {
return nil, warnings, err return nil, warnings, err
} }
@ -670,7 +676,7 @@ func (h *httpAPI) Query(ctx context.Context, query string, ts time.Time) (model.
return model.Value(qres.v), warnings, json.Unmarshal(body, &qres) return model.Value(qres.v), warnings, json.Unmarshal(body, &qres)
} }
func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model.Value, api.Warnings, error) { func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model.Value, Warnings, error) {
u := h.client.URL(epQueryRange, nil) u := h.client.URL(epQueryRange, nil)
q := u.Query() q := u.Query()
@ -679,7 +685,7 @@ func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model.
q.Set("end", formatTime(r.End)) q.Set("end", formatTime(r.End))
q.Set("step", strconv.FormatFloat(r.Step.Seconds(), 'f', -1, 64)) q.Set("step", strconv.FormatFloat(r.Step.Seconds(), 'f', -1, 64))
_, body, warnings, err := api.DoGetFallback(h.client, ctx, u, q) _, body, warnings, err := h.client.DoGetFallback(ctx, u, q)
if err != nil { if err != nil {
return nil, warnings, err return nil, warnings, err
} }
@ -689,7 +695,7 @@ func (h *httpAPI) QueryRange(ctx context.Context, query string, r Range) (model.
return model.Value(qres.v), warnings, json.Unmarshal(body, &qres) return model.Value(qres.v), warnings, json.Unmarshal(body, &qres)
} }
func (h *httpAPI) Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, api.Warnings, error) { func (h *httpAPI) Series(ctx context.Context, matches []string, startTime time.Time, endTime time.Time) ([]model.LabelSet, Warnings, error) {
u := h.client.URL(epSeries, nil) u := h.client.URL(epSeries, nil)
q := u.Query() q := u.Query()
@ -796,10 +802,19 @@ func (h *httpAPI) TargetsMetadata(ctx context.Context, matchTarget string, metri
return res, json.Unmarshal(body, &res) return res, json.Unmarshal(body, &res)
} }
// Warnings is an array of non critical errors
type Warnings []string
// apiClient wraps a regular client and processes successful API responses. // apiClient wraps a regular client and processes successful API responses.
// Successful also includes responses that errored at the API level. // Successful also includes responses that errored at the API level.
type apiClient struct { type apiClient interface {
api.Client URL(ep string, args map[string]string) *url.URL
Do(context.Context, *http.Request) (*http.Response, []byte, Warnings, error)
DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, Warnings, error)
}
type apiClientImpl struct {
client api.Client
} }
type apiResponse struct { type apiResponse struct {
@ -825,17 +840,21 @@ func errorTypeAndMsgFor(resp *http.Response) (ErrorType, string) {
return ErrBadResponse, fmt.Sprintf("bad response code %d", resp.StatusCode) return ErrBadResponse, fmt.Sprintf("bad response code %d", resp.StatusCode)
} }
func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { func (h *apiClientImpl) URL(ep string, args map[string]string) *url.URL {
resp, body, warnings, err := c.Client.Do(ctx, req) return h.client.URL(ep, args)
}
func (h *apiClientImpl) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, Warnings, error) {
resp, body, err := h.client.Do(ctx, req)
if err != nil { if err != nil {
return resp, body, warnings, err return resp, body, nil, err
} }
code := resp.StatusCode code := resp.StatusCode
if code/100 != 2 && !apiError(code) { if code/100 != 2 && !apiError(code) {
errorType, errorMsg := errorTypeAndMsgFor(resp) errorType, errorMsg := errorTypeAndMsgFor(resp)
return resp, body, warnings, &Error{ return resp, body, nil, &Error{
Type: errorType, Type: errorType,
Msg: errorMsg, Msg: errorMsg,
Detail: string(body), Detail: string(body),
@ -846,7 +865,7 @@ func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, [
if http.StatusNoContent != code { if http.StatusNoContent != code {
if jsonErr := json.Unmarshal(body, &result); jsonErr != nil { if jsonErr := json.Unmarshal(body, &result); jsonErr != nil {
return resp, body, warnings, &Error{ return resp, body, nil, &Error{
Type: ErrBadResponse, Type: ErrBadResponse,
Msg: jsonErr.Error(), Msg: jsonErr.Error(),
} }
@ -867,10 +886,35 @@ func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, [
} }
} }
return resp, []byte(result.Data), warnings, err return resp, []byte(result.Data), result.Warnings, err
} }
// DoGetFallback will attempt to do the request as-is, and on a 405 it will fallback to a GET request.
func (h *apiClientImpl) DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, Warnings, error) {
req, err := http.NewRequest(http.MethodPost, u.String(), strings.NewReader(args.Encode()))
if err != nil {
return nil, nil, nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, body, warnings, err := h.Do(ctx, req)
if resp != nil && resp.StatusCode == http.StatusMethodNotAllowed {
u.RawQuery = args.Encode()
req, err = http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, nil, warnings, err
}
} else {
if err != nil {
return resp, body, warnings, err
}
return resp, body, warnings, nil
}
return h.Do(ctx, req)
}
func formatTime(t time.Time) string { func formatTime(t time.Time) string {
return strconv.FormatFloat(float64(t.Unix())+float64(t.Nanosecond())/1e9, 'f', -1, 64) return strconv.FormatFloat(float64(t.Unix())+float64(t.Nanosecond())/1e9, 'f', -1, 64)
} }

View File

@ -17,8 +17,10 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"math" "math"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
@ -28,12 +30,10 @@ import (
json "github.com/json-iterator/go" json "github.com/json-iterator/go"
"github.com/prometheus/common/model" "github.com/prometheus/common/model"
"github.com/prometheus/client_golang/api"
) )
type apiTest struct { type apiTest struct {
do func() (interface{}, api.Warnings, error) do func() (interface{}, Warnings, error)
inWarnings []string inWarnings []string
inErr error inErr error
inStatusCode int inStatusCode int
@ -43,7 +43,7 @@ type apiTest struct {
reqParam url.Values reqParam url.Values
reqMethod string reqMethod string
res interface{} res interface{}
warnings api.Warnings warnings Warnings
err error err error
} }
@ -64,7 +64,7 @@ func (c *apiTestClient) URL(ep string, args map[string]string) *url.URL {
return u return u
} }
func (c *apiTestClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { func (c *apiTestClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, Warnings, error) {
test := c.curTest test := c.curTest
@ -92,102 +92,111 @@ func (c *apiTestClient) Do(ctx context.Context, req *http.Request) (*http.Respon
return resp, b, test.inWarnings, test.inErr return resp, b, test.inWarnings, test.inErr
} }
func (c *apiTestClient) DoGetFallback(ctx context.Context, u *url.URL, args url.Values) (*http.Response, []byte, Warnings, error) {
req, err := http.NewRequest(http.MethodPost, u.String(), strings.NewReader(args.Encode()))
if err != nil {
return nil, nil, nil, err
}
return c.Do(ctx, req)
}
func TestAPIs(t *testing.T) { func TestAPIs(t *testing.T) {
testTime := time.Now() testTime := time.Now()
client := &apiTestClient{T: t} tc := &apiTestClient{
T: t,
}
promAPI := &httpAPI{ promAPI := &httpAPI{
client: client, client: tc,
} }
doAlertManagers := func() func() (interface{}, api.Warnings, error) { doAlertManagers := func() func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
v, err := promAPI.AlertManagers(context.Background()) v, err := promAPI.AlertManagers(context.Background())
return v, nil, err return v, nil, err
} }
} }
doCleanTombstones := func() func() (interface{}, api.Warnings, error) { doCleanTombstones := func() func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
return nil, nil, promAPI.CleanTombstones(context.Background()) return nil, nil, promAPI.CleanTombstones(context.Background())
} }
} }
doConfig := func() func() (interface{}, api.Warnings, error) { doConfig := func() func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
v, err := promAPI.Config(context.Background()) v, err := promAPI.Config(context.Background())
return v, nil, err return v, nil, err
} }
} }
doDeleteSeries := func(matcher string, startTime time.Time, endTime time.Time) func() (interface{}, api.Warnings, error) { doDeleteSeries := func(matcher string, startTime time.Time, endTime time.Time) func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
return nil, nil, promAPI.DeleteSeries(context.Background(), []string{matcher}, startTime, endTime) return nil, nil, promAPI.DeleteSeries(context.Background(), []string{matcher}, startTime, endTime)
} }
} }
doFlags := func() func() (interface{}, api.Warnings, error) { doFlags := func() func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
v, err := promAPI.Flags(context.Background()) v, err := promAPI.Flags(context.Background())
return v, nil, err return v, nil, err
} }
} }
doLabelNames := func(label string) func() (interface{}, api.Warnings, error) { doLabelNames := func(label string) func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
return promAPI.LabelNames(context.Background()) return promAPI.LabelNames(context.Background())
} }
} }
doLabelValues := func(label string) func() (interface{}, api.Warnings, error) { doLabelValues := func(label string) func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
return promAPI.LabelValues(context.Background(), label) return promAPI.LabelValues(context.Background(), label)
} }
} }
doQuery := func(q string, ts time.Time) func() (interface{}, api.Warnings, error) { doQuery := func(q string, ts time.Time) func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
return promAPI.Query(context.Background(), q, ts) return promAPI.Query(context.Background(), q, ts)
} }
} }
doQueryRange := func(q string, rng Range) func() (interface{}, api.Warnings, error) { doQueryRange := func(q string, rng Range) func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
return promAPI.QueryRange(context.Background(), q, rng) return promAPI.QueryRange(context.Background(), q, rng)
} }
} }
doSeries := func(matcher string, startTime time.Time, endTime time.Time) func() (interface{}, api.Warnings, error) { doSeries := func(matcher string, startTime time.Time, endTime time.Time) func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
return promAPI.Series(context.Background(), []string{matcher}, startTime, endTime) return promAPI.Series(context.Background(), []string{matcher}, startTime, endTime)
} }
} }
doSnapshot := func(skipHead bool) func() (interface{}, api.Warnings, error) { doSnapshot := func(skipHead bool) func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
v, err := promAPI.Snapshot(context.Background(), skipHead) v, err := promAPI.Snapshot(context.Background(), skipHead)
return v, nil, err return v, nil, err
} }
} }
doRules := func() func() (interface{}, api.Warnings, error) { doRules := func() func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
v, err := promAPI.Rules(context.Background()) v, err := promAPI.Rules(context.Background())
return v, nil, err return v, nil, err
} }
} }
doTargets := func() func() (interface{}, api.Warnings, error) { doTargets := func() func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
v, err := promAPI.Targets(context.Background()) v, err := promAPI.Targets(context.Background())
return v, nil, err return v, nil, err
} }
} }
doTargetsMetadata := func(matchTarget string, metric string, limit string) func() (interface{}, api.Warnings, error) { doTargetsMetadata := func(matchTarget string, metric string, limit string) func() (interface{}, Warnings, error) {
return func() (interface{}, api.Warnings, error) { return func() (interface{}, Warnings, error) {
v, err := promAPI.TargetsMetadata(context.Background(), matchTarget, metric, limit) v, err := promAPI.TargetsMetadata(context.Background(), matchTarget, metric, limit)
return v, nil, err return v, nil, err
} }
@ -855,7 +864,7 @@ func TestAPIs(t *testing.T) {
for i, test := range tests { for i, test := range tests {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
client.curTest = test tc.curTest = test
res, warnings, err := test.do() res, warnings, err := test.do()
@ -900,14 +909,14 @@ type apiClientTest struct {
response interface{} response interface{}
expectedBody string expectedBody string
expectedErr *Error expectedErr *Error
expectedWarnings api.Warnings expectedWarnings Warnings
} }
func (c *testClient) URL(ep string, args map[string]string) *url.URL { func (c *testClient) URL(ep string, args map[string]string) *url.URL {
return nil return nil
} }
func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, api.Warnings, error) { func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) {
if ctx == nil { if ctx == nil {
c.Fatalf("context was not passed down") c.Fatalf("context was not passed down")
} }
@ -934,7 +943,7 @@ func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response,
StatusCode: test.code, StatusCode: test.code,
} }
return resp, b, test.expectedWarnings, nil return resp, b, nil
} }
func TestAPIClientDo(t *testing.T) { func TestAPIClientDo(t *testing.T) {
@ -1065,7 +1074,9 @@ func TestAPIClientDo(t *testing.T) {
ch: make(chan apiClientTest, 1), ch: make(chan apiClientTest, 1),
req: &http.Request{}, req: &http.Request{},
} }
client := &apiClient{tc} client := &apiClientImpl{
client: tc,
}
for i, test := range tests { for i, test := range tests {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
@ -1209,3 +1220,113 @@ func TestSamplesJsonSerialization(t *testing.T) {
}) })
} }
} }
type httpTestClient struct {
client http.Client
}
func (c *httpTestClient) URL(ep string, args map[string]string) *url.URL {
return nil
}
func (c *httpTestClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) {
resp, err := c.client.Do(req)
if err != nil {
return nil, nil, err
}
var body []byte
done := make(chan struct{})
go func() {
body, err = ioutil.ReadAll(resp.Body)
close(done)
}()
select {
case <-ctx.Done():
<-done
err = resp.Body.Close()
if err == nil {
err = ctx.Err()
}
case <-done:
}
return resp, body, err
}
func TestDoGetFallback(t *testing.T) {
v := url.Values{"a": []string{"1", "2"}}
type testResponse struct {
Values string
Method string
}
// Start a local HTTP server.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.ParseForm()
testResp, _ := json.Marshal(&testResponse{
Values: req.Form.Encode(),
Method: req.Method,
})
apiResp := &apiResponse{
Data: testResp,
}
body, _ := json.Marshal(apiResp)
if req.Method == http.MethodPost {
if req.URL.Path == "/blockPost" {
http.Error(w, string(body), http.StatusMethodNotAllowed)
return
}
}
w.Write(body)
}))
// Close the server when test finishes.
defer server.Close()
u, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
client := &httpTestClient{client: *(server.Client())}
api := &apiClientImpl{
client: client,
}
// Do a post, and ensure that the post succeeds.
_, b, _, err := api.DoGetFallback(context.TODO(), u, v)
if err != nil {
t.Fatalf("Error doing local request: %v", err)
}
resp := &testResponse{}
if err := json.Unmarshal(b, resp); err != nil {
t.Fatal(err)
}
if resp.Method != http.MethodPost {
t.Fatalf("Mismatch method")
}
if resp.Values != v.Encode() {
t.Fatalf("Mismatch in values")
}
// Do a fallbcak to a get.
u.Path = "/blockPost"
_, b, _, err = api.DoGetFallback(context.TODO(), u, v)
if err != nil {
t.Fatalf("Error doing local request: %v", err)
}
if err := json.Unmarshal(b, resp); err != nil {
t.Fatal(err)
}
if resp.Method != http.MethodGet {
t.Fatalf("Mismatch method")
}
if resp.Values != v.Encode() {
t.Fatalf("Mismatch in values")
}
}