diff --git a/request/extractor_test.go b/request/extractor_test.go new file mode 100644 index 0000000..ecfd122 --- /dev/null +++ b/request/extractor_test.go @@ -0,0 +1,86 @@ +package request + +import ( + "fmt" + "net/http" + "net/url" + "testing" +) + +var extractorTestTokenA = "A" +var extractorTestTokenB = "B" + +var extractorTestData = []struct { + name string + extractor Extractor + headers map[string]string + query url.Values + token string + err error +}{ + { + name: "simple header", + extractor: HeaderExtractor{"Foo"}, + headers: map[string]string{"Foo": extractorTestTokenA}, + query: nil, + token: extractorTestTokenA, + err: nil, + }, + { + name: "simple argument", + extractor: ArgumentExtractor{"token"}, + headers: map[string]string{}, + query: url.Values{"token": {extractorTestTokenA}}, + token: extractorTestTokenA, + err: nil, + }, + { + name: "multiple extractors", + extractor: MultiExtractor{ + HeaderExtractor{"Foo"}, + ArgumentExtractor{"token"}, + }, + headers: map[string]string{"Foo": extractorTestTokenA}, + query: url.Values{"token": {extractorTestTokenB}}, + token: extractorTestTokenA, + err: nil, + }, + { + name: "simple miss", + extractor: HeaderExtractor{"This-Header-Is-Not-Set"}, + headers: map[string]string{"Foo": extractorTestTokenA}, + query: nil, + token: "", + err: ErrNoTokenInRequest, + }, + { + name: "filter", + extractor: AuthorizationHeaderExtractor, + headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenA}, + query: nil, + token: extractorTestTokenA, + err: nil, + }, +} + +func TestExtractor(t *testing.T) { + // Bearer token request + for _, data := range extractorTestData { + // Make request from test struct + r, _ := http.NewRequest("GET", fmt.Sprintf("/?%v", data.query.Encode()), nil) + for k, v := range data.headers { + r.Header.Set(k, v) + } + + // Test extractor + token, err := data.extractor.ExtractToken(r) + if token != data.token { + t.Errorf("[%v] Expected token '%v'. Got '%v'", data.name, data.token, token) + continue + } + if err != data.err { + t.Errorf("[%v] Expected error '%v'. Got '%v'", data.name, data.err, err) + continue + } + } +} diff --git a/request/oauth2.go b/request/oauth2.go index 0f7fa02..c84cbd7 100644 --- a/request/oauth2.go +++ b/request/oauth2.go @@ -23,5 +23,5 @@ var OAuth2Extractor = &MultiExtractor{ // Look for authorization token first AuthorizationHeaderExtractor, // Extract access_token from form or GET argument - &ArgumentExtractor{"access_token"}, + ArgumentExtractor{"access_token"}, } diff --git a/request/request_test.go b/request/request_test.go index 504e1bc..b4365cd 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -12,15 +12,25 @@ import ( ) var requestTestData = []struct { - name string - claims jwt.MapClaims - headers map[string]string - query url.Values - valid bool + name string + claims jwt.MapClaims + extractor Extractor + headers map[string]string + query url.Values + valid bool }{ + { + "authorization bearer token", + jwt.MapClaims{"foo": "bar"}, + AuthorizationHeaderExtractor, + map[string]string{"Authorization": "Bearer %v"}, + url.Values{}, + true, + }, { "oauth bearer token - header", jwt.MapClaims{"foo": "bar"}, + OAuth2Extractor, map[string]string{"Authorization": "Bearer %v"}, url.Values{}, true, @@ -28,10 +38,19 @@ var requestTestData = []struct { { "oauth bearer token - url", jwt.MapClaims{"foo": "bar"}, + OAuth2Extractor, map[string]string{}, url.Values{"access_token": {"%v"}}, true, }, + { + "url token", + jwt.MapClaims{"foo": "bar"}, + ArgumentExtractor{"token"}, + map[string]string{}, + url.Values{"token": {"%v"}}, + true, + }, } func TestParseRequest(t *testing.T) { @@ -65,7 +84,7 @@ func TestParseRequest(t *testing.T) { r.Header.Set(k, tokenString) } } - token, err := ParseFromRequestWithClaims(r, OAuth2Extractor, jwt.MapClaims{}, keyfunc) + token, err := ParseFromRequestWithClaims(r, data.extractor, jwt.MapClaims{}, keyfunc) if token == nil { t.Errorf("[%v] Token was not found: %v", data.name, err)