diff --git a/request/extractor_example_test.go b/request/extractor_example_test.go new file mode 100644 index 0000000..a994ffe --- /dev/null +++ b/request/extractor_example_test.go @@ -0,0 +1,32 @@ +package request + +import ( + "fmt" + "net/url" +) + +const ( + exampleTokenA = "A" +) + +func ExampleHeaderExtractor() { + req := makeExampleRequest("GET", "/", map[string]string{"Token": exampleTokenA}, nil) + tokenString, err := HeaderExtractor{"Token"}.ExtractToken(req) + if err == nil { + fmt.Println(tokenString) + } else { + fmt.Println(err) + } + //Output: A +} + +func ExampleArgumentExtractor() { + req := makeExampleRequest("GET", "/", nil, url.Values{"token": {extractorTestTokenA}}) + tokenString, err := ArgumentExtractor{"token"}.ExtractToken(req) + if err == nil { + fmt.Println(tokenString) + } else { + fmt.Println(err) + } + //Output: A +} diff --git a/request/extractor_test.go b/request/extractor_test.go index ecfd122..e3bbb0a 100644 --- a/request/extractor_test.go +++ b/request/extractor_test.go @@ -67,10 +67,7 @@ func TestExtractor(t *testing.T) { // Bearer token request for _, data := range extractorTestData { // Make request from test struct - r, _ := http.NewRequest("GET", fmt.Sprintf("/?%v", data.query.Encode()), nil) - for k, v := range data.headers { - r.Header.Set(k, v) - } + r := makeExampleRequest("GET", "/", data.headers, data.query) // Test extractor token, err := data.extractor.ExtractToken(r) @@ -84,3 +81,11 @@ func TestExtractor(t *testing.T) { } } } + +func makeExampleRequest(method, path string, headers map[string]string, urlArgs url.Values) *http.Request { + r, _ := http.NewRequest(method, fmt.Sprintf("%v?%v", path, urlArgs.Encode()), nil) + for k, v := range headers { + r.Header.Set(k, v) + } + return r +} diff --git a/request/oauth2.go b/request/oauth2.go index c84cbd7..5948694 100644 --- a/request/oauth2.go +++ b/request/oauth2.go @@ -4,24 +4,25 @@ import ( "strings" ) +// Strips 'Bearer ' prefix from bearer token string +func stripBearerPrefixFromTokenString(tok string) (string, error) { + // Should be a bearer token + if len(tok) > 6 && strings.ToUpper(tok[0:7]) == "BEARER " { + return tok[7:], nil + } + return tok, nil +} + // Extract bearer token from Authorization header // Uses PostExtractionFilter to strip "Bearer " prefix from header var AuthorizationHeaderExtractor = &PostExtractionFilter{ HeaderExtractor{"Authorization"}, - func(tok string) (string, error) { - // Should be a bearer token - if len(tok) > 6 && strings.ToUpper(tok[0:7]) == "BEARER " { - return tok[7:], nil - } - return tok, nil - }, + stripBearerPrefixFromTokenString, } // Extractor for OAuth2 access tokens. Looks in 'Authorization' // header then 'access_token' argument for a token. var OAuth2Extractor = &MultiExtractor{ - // Look for authorization token first AuthorizationHeaderExtractor, - // Extract access_token from form or GET argument ArgumentExtractor{"access_token"}, }