diff --git a/api/client.go b/api/client.go index 1aa1591..a6ef8f7 100644 --- a/api/client.go +++ b/api/client.go @@ -23,19 +23,10 @@ import ( "path" "strings" "time" - - "golang.org/x/net/context/ctxhttp" ) -// CancelableTransport is like net.Transport but provides -// per-request cancelation functionality. -type CancelableTransport interface { - http.RoundTripper - CancelRequest(req *http.Request) -} - -// DefaultTransport is used if no Transport is set in Config. -var DefaultTransport CancelableTransport = &http.Transport{ +// DefaultRoundTripper is used if no RoundTripper is set in Config. +var DefaultRoundTripper http.RoundTripper = &http.Transport{ Proxy: http.ProxyFromEnvironment, Dial: (&net.Dialer{ Timeout: 30 * time.Second, @@ -49,16 +40,16 @@ type Config struct { // The address of the Prometheus to connect to. Address string - // Transport is used by the Client to drive HTTP requests. If not - // provided, DefaultTransport will be used. - Transport CancelableTransport + // RoundTripper is used by the Client to drive HTTP requests. If not + // provided, DefaultRoundTripper will be used. + RoundTripper http.RoundTripper } -func (cfg *Config) transport() CancelableTransport { - if cfg.Transport == nil { - return DefaultTransport +func (cfg *Config) roundTripper() http.RoundTripper { + if cfg.RoundTripper == nil { + return DefaultRoundTripper } - return cfg.Transport + return cfg.RoundTripper } // Client is the interface for an API client. @@ -78,14 +69,14 @@ func New(cfg Config) (Client, error) { u.Path = strings.TrimRight(u.Path, "/") return &httpClient{ - endpoint: u, - transport: cfg.transport(), + endpoint: u, + client: http.Client{Transport: cfg.roundTripper()}, }, nil } type httpClient struct { - endpoint *url.URL - transport CancelableTransport + endpoint *url.URL + client http.Client } func (c *httpClient) URL(ep string, args map[string]string) *url.URL { @@ -103,8 +94,10 @@ func (c *httpClient) URL(ep string, args map[string]string) *url.URL { } func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { - resp, err := ctxhttp.Do(ctx, &http.Client{Transport: c.transport}, req) - + if ctx != nil { + req = req.WithContext(ctx) + } + resp, err := c.client.Do(req) defer func() { if resp != nil { resp.Body.Close() diff --git a/api/client_test.go b/api/client_test.go index a068e2f..8db4f76 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -14,14 +14,15 @@ package api import ( + "net/http" "net/url" "testing" ) func TestConfig(t *testing.T) { c := Config{} - if c.transport() != DefaultTransport { - t.Fatalf("expected default transport for nil Transport field") + if c.roundTripper() != DefaultRoundTripper { + t.Fatalf("expected default roundtripper for nil RoundTripper field") } } @@ -99,8 +100,8 @@ func TestClientURL(t *testing.T) { } hclient := &httpClient{ - endpoint: ep, - transport: DefaultTransport, + endpoint: ep, + client: http.Client{Transport: DefaultRoundTripper}, } u := hclient.URL(test.endpoint, test.args)