From bb45bfcdecac35b6afa09673bd17a780b258a541 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Mon, 6 Jun 2016 15:27:44 -0700 Subject: [PATCH 1/4] add new interface Extractor for making token extraction pluggable --- request/extractor.go | 75 +++++++++++++++++++++++++++++++++++++++++ request/oauth2.go | 25 ++++++++++++++ request/request.go | 31 +++++------------ request/request_test.go | 2 +- 4 files changed, 109 insertions(+), 24 deletions(-) create mode 100644 request/extractor.go create mode 100644 request/oauth2.go diff --git a/request/extractor.go b/request/extractor.go new file mode 100644 index 0000000..57ea2dd --- /dev/null +++ b/request/extractor.go @@ -0,0 +1,75 @@ +package request + +import ( + "errors" + "net/http" +) + +// Errors +var ( + ErrNoTokenInRequest = errors.New("no token present in request") +) + +// Interface for extracting a token from an HTTP request +type Extractor interface { + ExtractToken(*http.Request) (string, error) +} + +// Extract token from headers +type HeaderExtractor []string + +func (e HeaderExtractor) ExtractToken(req *http.Request) (string, error) { + // loop over header names and return the first one that contains data + for _, header := range e { + if ah := req.Header.Get(header); ah != "" { + return ah, nil + } + } + return "", ErrNoTokenInRequest +} + +// Extract token from request args +type ArgumentExtractor []string + +func (e ArgumentExtractor) ExtractToken(req *http.Request) (string, error) { + // Make sure form is parsed + req.ParseMultipartForm(10e6) + + // loop over arg names and return the first one that contains data + for _, arg := range e { + if ah := req.Form.Get(arg); ah != "" { + return ah, nil + } + } + + return "", ErrNoTokenInRequest +} + +// Tries extractors in order until one works or an error occurs +type MultiExtractor []Extractor + +func (e MultiExtractor) ExtractToken(req *http.Request) (string, error) { + // loop over header names and return the first one that contains data + for _, extractor := range e { + if tok, err := extractor.ExtractToken(req); tok != "" { + return tok, nil + } else if err != ErrNoTokenInRequest { + return "", err + } + } + return "", ErrNoTokenInRequest +} + +// Wrap an Extractor in this to post-process the value before it's handed off +type PostExtractionFilter struct { + Extractor + Filter func(string) (string, error) +} + +func (e *PostExtractionFilter) ExtractToken(req *http.Request) (string, error) { + if tok, err := e.Extractor.ExtractToken(req); tok != "" { + return e.Filter(tok) + } else { + return "", err + } +} diff --git a/request/oauth2.go b/request/oauth2.go new file mode 100644 index 0000000..bbb6ac2 --- /dev/null +++ b/request/oauth2.go @@ -0,0 +1,25 @@ +package request + +import ( + "strings" +) + +// Extract Authorization header and strip 'Bearer ' from it +var AuthorizationHeaderExtractor = &PostExtractionFilter{ + HeaderExtractor{"Authorization"}, + func(tok string) (string, error) { + // Should be a bearer token + if len(tok) > 6 && strings.ToUpper(tok[0:7]) == "BEARER " { + return tok[7:], nil + } + return tok, nil + }, +} + +// Extractor for OAuth2 access tokens +var OAuth2Extractor = &MultiExtractor{ + // Look for authorization token first + AuthorizationHeaderExtractor, + // Extract access_token from form or GET argument + &ArgumentExtractor{"access_token"}, +} diff --git a/request/request.go b/request/request.go index 98e1be2..9daaa48 100644 --- a/request/request.go +++ b/request/request.go @@ -1,39 +1,24 @@ package request import ( - "errors" "github.com/dgrijalva/jwt-go" "net/http" - "strings" -) - -// Errors -var ( - ErrNoTokenInRequest = errors.New("no token present in request") ) // Try to find the token in an http.Request. // This method will call ParseMultipartForm if there's no token in the header. // Currently, it looks in the Authorization header as well as // looking for an 'access_token' request parameter in req.Form. -func ParseFromRequest(req *http.Request, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { - return ParseFromRequestWithClaims(req, jwt.MapClaims{}, keyFunc) +func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { + return ParseFromRequestWithClaims(req, extractor, jwt.MapClaims{}, keyFunc) } -func ParseFromRequestWithClaims(req *http.Request, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { - // Look for an Authorization header - if ah := req.Header.Get("Authorization"); ah != "" { - // Should be a bearer token - if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { - return jwt.ParseWithClaims(ah[7:], claims, keyFunc) - } +func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { + // Extract token from request + tokStr, err := extractor.ExtractToken(req) + if err != nil { + return nil, err } - // Look for "access_token" parameter - req.ParseMultipartForm(10e6) - if tokStr := req.Form.Get("access_token"); tokStr != "" { - return jwt.ParseWithClaims(tokStr, claims, keyFunc) - } - - return nil, ErrNoTokenInRequest + return jwt.ParseWithClaims(tokStr, claims, keyFunc) } diff --git a/request/request_test.go b/request/request_test.go index 6b09521..504e1bc 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -65,7 +65,7 @@ func TestParseRequest(t *testing.T) { r.Header.Set(k, tokenString) } } - token, err := ParseFromRequestWithClaims(r, jwt.MapClaims{}, keyfunc) + token, err := ParseFromRequestWithClaims(r, OAuth2Extractor, jwt.MapClaims{}, keyfunc) if token == nil { t.Errorf("[%v] Token was not found: %v", data.name, err) From de0a819d8dadfba7a8f666d3c80b72c7a9a2bd45 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Mon, 6 Jun 2016 16:55:41 -0700 Subject: [PATCH 2/4] added documentation --- request/doc.go | 7 +++++++ request/extractor.go | 15 ++++++++++----- request/oauth2.go | 6 ++++-- request/request.go | 16 ++++++++-------- 4 files changed, 29 insertions(+), 15 deletions(-) create mode 100644 request/doc.go diff --git a/request/doc.go b/request/doc.go new file mode 100644 index 0000000..c01069c --- /dev/null +++ b/request/doc.go @@ -0,0 +1,7 @@ +// Utility package for extracting JWT tokens from +// HTTP requests. +// +// The main function is ParseFromRequest and it's WithClaims variant. +// See examples for how to use the various Extractor implementations +// or roll your own. +package request diff --git a/request/extractor.go b/request/extractor.go index 57ea2dd..61c330a 100644 --- a/request/extractor.go +++ b/request/extractor.go @@ -10,12 +10,15 @@ var ( ErrNoTokenInRequest = errors.New("no token present in request") ) -// Interface for extracting a token from an HTTP request +// Interface for extracting a token from an HTTP request. +// The ExtractToken method should return a token string or an error. +// If no token is present, you must return ErrNoTokenInRequest. type Extractor interface { ExtractToken(*http.Request) (string, error) } -// Extract token from headers +// Extractor for finding a token in a header. Looks at each specified +// header in order until there's a match type HeaderExtractor []string func (e HeaderExtractor) ExtractToken(req *http.Request) (string, error) { @@ -28,7 +31,8 @@ func (e HeaderExtractor) ExtractToken(req *http.Request) (string, error) { return "", ErrNoTokenInRequest } -// Extract token from request args +// Extract token from request arguments. This includes a POSTed form or +// GET URL arguments. Argument names are tried in order until there's a match. type ArgumentExtractor []string func (e ArgumentExtractor) ExtractToken(req *http.Request) (string, error) { @@ -45,7 +49,7 @@ func (e ArgumentExtractor) ExtractToken(req *http.Request) (string, error) { return "", ErrNoTokenInRequest } -// Tries extractors in order until one works or an error occurs +// Tries Extractors in order until one returns a token string or an error occurs type MultiExtractor []Extractor func (e MultiExtractor) ExtractToken(req *http.Request) (string, error) { @@ -60,7 +64,8 @@ func (e MultiExtractor) ExtractToken(req *http.Request) (string, error) { return "", ErrNoTokenInRequest } -// Wrap an Extractor in this to post-process the value before it's handed off +// Wrap an Extractor in this to post-process the value before it's handed off. +// See AuthorizationHeaderExtractor for an example type PostExtractionFilter struct { Extractor Filter func(string) (string, error) diff --git a/request/oauth2.go b/request/oauth2.go index bbb6ac2..0f7fa02 100644 --- a/request/oauth2.go +++ b/request/oauth2.go @@ -4,7 +4,8 @@ import ( "strings" ) -// Extract Authorization header and strip 'Bearer ' from it +// Extract bearer token from Authorization header +// Uses PostExtractionFilter to strip "Bearer " prefix from header var AuthorizationHeaderExtractor = &PostExtractionFilter{ HeaderExtractor{"Authorization"}, func(tok string) (string, error) { @@ -16,7 +17,8 @@ var AuthorizationHeaderExtractor = &PostExtractionFilter{ }, } -// Extractor for OAuth2 access tokens +// Extractor for OAuth2 access tokens. Looks in 'Authorization' +// header then 'access_token' argument for a token. var OAuth2Extractor = &MultiExtractor{ // Look for authorization token first AuthorizationHeaderExtractor, diff --git a/request/request.go b/request/request.go index 9daaa48..1807b39 100644 --- a/request/request.go +++ b/request/request.go @@ -5,20 +5,20 @@ import ( "net/http" ) -// Try to find the token in an http.Request. -// This method will call ParseMultipartForm if there's no token in the header. -// Currently, it looks in the Authorization header as well as -// looking for an 'access_token' request parameter in req.Form. +// Extract and parse a JWT token from an HTTP request. +// This behaves the same as Parse, but accepts a request and an extractor +// instead of a token string. The Extractor interface allows you to define +// the logic for extracting a token. Several useful implementations are provided. func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { return ParseFromRequestWithClaims(req, extractor, jwt.MapClaims{}, keyFunc) } +// ParseFromRequest but with custom Claims type func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { // Extract token from request - tokStr, err := extractor.ExtractToken(req) - if err != nil { + if tokStr, err := extractor.ExtractToken(req); err == nil { + return jwt.ParseWithClaims(tokStr, claims, keyFunc) + } else { return nil, err } - - return jwt.ParseWithClaims(tokStr, claims, keyFunc) } From f93fcfd3f91f046d0bb0d2f537083864f59725ac Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Mon, 6 Jun 2016 17:18:24 -0700 Subject: [PATCH 3/4] unit tests for extractor --- request/extractor_test.go | 86 +++++++++++++++++++++++++++++++++++++++ request/oauth2.go | 2 +- request/request_test.go | 31 +++++++++++--- 3 files changed, 112 insertions(+), 7 deletions(-) create mode 100644 request/extractor_test.go diff --git a/request/extractor_test.go b/request/extractor_test.go new file mode 100644 index 0000000..ecfd122 --- /dev/null +++ b/request/extractor_test.go @@ -0,0 +1,86 @@ +package request + +import ( + "fmt" + "net/http" + "net/url" + "testing" +) + +var extractorTestTokenA = "A" +var extractorTestTokenB = "B" + +var extractorTestData = []struct { + name string + extractor Extractor + headers map[string]string + query url.Values + token string + err error +}{ + { + name: "simple header", + extractor: HeaderExtractor{"Foo"}, + headers: map[string]string{"Foo": extractorTestTokenA}, + query: nil, + token: extractorTestTokenA, + err: nil, + }, + { + name: "simple argument", + extractor: ArgumentExtractor{"token"}, + headers: map[string]string{}, + query: url.Values{"token": {extractorTestTokenA}}, + token: extractorTestTokenA, + err: nil, + }, + { + name: "multiple extractors", + extractor: MultiExtractor{ + HeaderExtractor{"Foo"}, + ArgumentExtractor{"token"}, + }, + headers: map[string]string{"Foo": extractorTestTokenA}, + query: url.Values{"token": {extractorTestTokenB}}, + token: extractorTestTokenA, + err: nil, + }, + { + name: "simple miss", + extractor: HeaderExtractor{"This-Header-Is-Not-Set"}, + headers: map[string]string{"Foo": extractorTestTokenA}, + query: nil, + token: "", + err: ErrNoTokenInRequest, + }, + { + name: "filter", + extractor: AuthorizationHeaderExtractor, + headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenA}, + query: nil, + token: extractorTestTokenA, + err: nil, + }, +} + +func TestExtractor(t *testing.T) { + // Bearer token request + for _, data := range extractorTestData { + // Make request from test struct + r, _ := http.NewRequest("GET", fmt.Sprintf("/?%v", data.query.Encode()), nil) + for k, v := range data.headers { + r.Header.Set(k, v) + } + + // Test extractor + token, err := data.extractor.ExtractToken(r) + if token != data.token { + t.Errorf("[%v] Expected token '%v'. Got '%v'", data.name, data.token, token) + continue + } + if err != data.err { + t.Errorf("[%v] Expected error '%v'. Got '%v'", data.name, data.err, err) + continue + } + } +} diff --git a/request/oauth2.go b/request/oauth2.go index 0f7fa02..c84cbd7 100644 --- a/request/oauth2.go +++ b/request/oauth2.go @@ -23,5 +23,5 @@ var OAuth2Extractor = &MultiExtractor{ // Look for authorization token first AuthorizationHeaderExtractor, // Extract access_token from form or GET argument - &ArgumentExtractor{"access_token"}, + ArgumentExtractor{"access_token"}, } diff --git a/request/request_test.go b/request/request_test.go index 504e1bc..b4365cd 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -12,15 +12,25 @@ import ( ) var requestTestData = []struct { - name string - claims jwt.MapClaims - headers map[string]string - query url.Values - valid bool + name string + claims jwt.MapClaims + extractor Extractor + headers map[string]string + query url.Values + valid bool }{ + { + "authorization bearer token", + jwt.MapClaims{"foo": "bar"}, + AuthorizationHeaderExtractor, + map[string]string{"Authorization": "Bearer %v"}, + url.Values{}, + true, + }, { "oauth bearer token - header", jwt.MapClaims{"foo": "bar"}, + OAuth2Extractor, map[string]string{"Authorization": "Bearer %v"}, url.Values{}, true, @@ -28,10 +38,19 @@ var requestTestData = []struct { { "oauth bearer token - url", jwt.MapClaims{"foo": "bar"}, + OAuth2Extractor, map[string]string{}, url.Values{"access_token": {"%v"}}, true, }, + { + "url token", + jwt.MapClaims{"foo": "bar"}, + ArgumentExtractor{"token"}, + map[string]string{}, + url.Values{"token": {"%v"}}, + true, + }, } func TestParseRequest(t *testing.T) { @@ -65,7 +84,7 @@ func TestParseRequest(t *testing.T) { r.Header.Set(k, tokenString) } } - token, err := ParseFromRequestWithClaims(r, OAuth2Extractor, jwt.MapClaims{}, keyfunc) + token, err := ParseFromRequestWithClaims(r, data.extractor, jwt.MapClaims{}, keyfunc) if token == nil { t.Errorf("[%v] Token was not found: %v", data.name, err) From 5eef21b7edec11f3819794484eee852323a7e204 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Mon, 6 Jun 2016 17:45:30 -0700 Subject: [PATCH 4/4] a few examples and some documentation cleanup --- request/extractor_example_test.go | 32 +++++++++++++++++++++++++++++++ request/extractor_test.go | 13 +++++++++---- request/oauth2.go | 19 +++++++++--------- 3 files changed, 51 insertions(+), 13 deletions(-) create mode 100644 request/extractor_example_test.go diff --git a/request/extractor_example_test.go b/request/extractor_example_test.go new file mode 100644 index 0000000..a994ffe --- /dev/null +++ b/request/extractor_example_test.go @@ -0,0 +1,32 @@ +package request + +import ( + "fmt" + "net/url" +) + +const ( + exampleTokenA = "A" +) + +func ExampleHeaderExtractor() { + req := makeExampleRequest("GET", "/", map[string]string{"Token": exampleTokenA}, nil) + tokenString, err := HeaderExtractor{"Token"}.ExtractToken(req) + if err == nil { + fmt.Println(tokenString) + } else { + fmt.Println(err) + } + //Output: A +} + +func ExampleArgumentExtractor() { + req := makeExampleRequest("GET", "/", nil, url.Values{"token": {extractorTestTokenA}}) + tokenString, err := ArgumentExtractor{"token"}.ExtractToken(req) + if err == nil { + fmt.Println(tokenString) + } else { + fmt.Println(err) + } + //Output: A +} diff --git a/request/extractor_test.go b/request/extractor_test.go index ecfd122..e3bbb0a 100644 --- a/request/extractor_test.go +++ b/request/extractor_test.go @@ -67,10 +67,7 @@ func TestExtractor(t *testing.T) { // Bearer token request for _, data := range extractorTestData { // Make request from test struct - r, _ := http.NewRequest("GET", fmt.Sprintf("/?%v", data.query.Encode()), nil) - for k, v := range data.headers { - r.Header.Set(k, v) - } + r := makeExampleRequest("GET", "/", data.headers, data.query) // Test extractor token, err := data.extractor.ExtractToken(r) @@ -84,3 +81,11 @@ func TestExtractor(t *testing.T) { } } } + +func makeExampleRequest(method, path string, headers map[string]string, urlArgs url.Values) *http.Request { + r, _ := http.NewRequest(method, fmt.Sprintf("%v?%v", path, urlArgs.Encode()), nil) + for k, v := range headers { + r.Header.Set(k, v) + } + return r +} diff --git a/request/oauth2.go b/request/oauth2.go index c84cbd7..5948694 100644 --- a/request/oauth2.go +++ b/request/oauth2.go @@ -4,24 +4,25 @@ import ( "strings" ) +// Strips 'Bearer ' prefix from bearer token string +func stripBearerPrefixFromTokenString(tok string) (string, error) { + // Should be a bearer token + if len(tok) > 6 && strings.ToUpper(tok[0:7]) == "BEARER " { + return tok[7:], nil + } + return tok, nil +} + // Extract bearer token from Authorization header // Uses PostExtractionFilter to strip "Bearer " prefix from header var AuthorizationHeaderExtractor = &PostExtractionFilter{ HeaderExtractor{"Authorization"}, - func(tok string) (string, error) { - // Should be a bearer token - if len(tok) > 6 && strings.ToUpper(tok[0:7]) == "BEARER " { - return tok[7:], nil - } - return tok, nil - }, + stripBearerPrefixFromTokenString, } // Extractor for OAuth2 access tokens. Looks in 'Authorization' // header then 'access_token' argument for a token. var OAuth2Extractor = &MultiExtractor{ - // Look for authorization token first AuthorizationHeaderExtractor, - // Extract access_token from form or GET argument ArgumentExtractor{"access_token"}, }