diff --git a/errors.go b/errors.go index b055f3b..6a60e91 100644 --- a/errors.go +++ b/errors.go @@ -6,9 +6,8 @@ import ( // Error constants var ( - ErrInvalidKey = errors.New("key is invalid or of invalid type") - ErrHashUnavailable = errors.New("the requested hash function is unavailable") - ErrNoTokenInRequest = errors.New("no token present in request") + ErrInvalidKey = errors.New("key is invalid or of invalid type") + ErrHashUnavailable = errors.New("the requested hash function is unavailable") ) // The errors that might occur when parsing and validating a token diff --git a/parser_test.go b/parser_test.go index a4eb83f..622a423 100644 --- a/parser_test.go +++ b/parser_test.go @@ -4,13 +4,12 @@ import ( "crypto/rsa" "encoding/json" "fmt" - "io/ioutil" - "net/http" "reflect" "testing" "time" "github.com/dgrijalva/jwt-go" + "github.com/dgrijalva/jwt-go/test" ) var ( @@ -21,6 +20,10 @@ var ( nilKeyFunc jwt.Keyfunc = nil ) +func init() { + jwtTestDefaultKey = test.LoadRSAPublicKeyFromDisk("test/sample_key.pub") +} + var jwtTestData = []struct { name string tokenString string @@ -142,42 +145,14 @@ var jwtTestData = []struct { }, } -func init() { - if keyData, e := ioutil.ReadFile("test/sample_key.pub"); e == nil { - if jwtTestDefaultKey, e = jwt.ParseRSAPublicKeyFromPEM(keyData); e != nil { - panic(e) - } - } else { - panic(e) - } -} - -func makeSample(c jwt.Claims) string { - keyData, e := ioutil.ReadFile("test/sample_key") - if e != nil { - panic(e.Error()) - } - key, e := jwt.ParseRSAPrivateKeyFromPEM(keyData) - if e != nil { - panic(e.Error()) - } - - token := jwt.NewWithClaims(jwt.SigningMethodRS256, c) - s, e := token.SignedString(key) - - if e != nil { - panic(e.Error()) - } - - return s -} - func TestParser_Parse(t *testing.T) { + privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key") + // Iterate over test data set and run tests for _, data := range jwtTestData { // If the token string is blank, use helper function to generate string if data.tokenString == "" { - data.tokenString = makeSample(data.claims) + data.tokenString = test.MakeSampleToken(data.claims, privateKey) } // Parse the token @@ -224,39 +199,6 @@ func TestParser_Parse(t *testing.T) { } } -func TestParseRequest(t *testing.T) { - // Bearer token request - for _, data := range jwtTestData { - // FIXME: custom parsers are not supported by this helper. skip tests that require them - if data.parser != nil { - t.Logf("Skipping [%v]. Custom parsers are not supported by ParseRequest", data.name) - continue - } - - if data.tokenString == "" { - data.tokenString = makeSample(data.claims) - } - - r, _ := http.NewRequest("GET", "/", nil) - r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString)) - token, err := jwt.ParseFromRequestWithClaims(r, data.keyfunc, jwt.MapClaims{}) - - if token == nil { - t.Errorf("[%v] Token was not found: %v", data.name, err) - continue - } - if !reflect.DeepEqual(data.claims, token.Claims) { - t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) - } - if data.valid && err != nil { - t.Errorf("[%v] Error while verifying token: %v", data.name, err) - } - if !data.valid && err == nil { - t.Errorf("[%v] Invalid token passed validation", data.name) - } - } -} - // Helper method for benchmarking various methods func benchmarkSigning(b *testing.B, method jwt.SigningMethod, key interface{}) { t := jwt.New(method) diff --git a/request/request.go b/request/request.go new file mode 100644 index 0000000..c812488 --- /dev/null +++ b/request/request.go @@ -0,0 +1,39 @@ +package request + +import ( + "errors" + "github.com/dgrijalva/jwt-go" + "net/http" + "strings" +) + +// Errors +var ( + ErrNoTokenInRequest = errors.New("no token present in request") +) + +// Try to find the token in an http.Request. +// This method will call ParseMultipartForm if there's no token in the header. +// Currently, it looks in the Authorization header as well as +// looking for an 'access_token' request parameter in req.Form. +func ParseFromRequest(req *http.Request, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) { + return ParseFromRequestWithClaims(req, keyFunc, &jwt.MapClaims{}) +} + +func ParseFromRequestWithClaims(req *http.Request, keyFunc jwt.Keyfunc, claims jwt.Claims) (token *jwt.Token, err error) { + // Look for an Authorization header + if ah := req.Header.Get("Authorization"); ah != "" { + // Should be a bearer token + if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { + return jwt.ParseWithClaims(ah[7:], keyFunc, claims) + } + } + + // Look for "access_token" parameter + req.ParseMultipartForm(10e6) + if tokStr := req.Form.Get("access_token"); tokStr != "" { + return jwt.ParseWithClaims(tokStr, keyFunc, claims) + } + + return nil, ErrNoTokenInRequest +} diff --git a/request/request_test.go b/request/request_test.go new file mode 100644 index 0000000..0f2fb9b --- /dev/null +++ b/request/request_test.go @@ -0,0 +1,84 @@ +package request + +import ( + "fmt" + "github.com/dgrijalva/jwt-go" + "github.com/dgrijalva/jwt-go/test" + "net/http" + "net/url" + "reflect" + "strings" + "testing" +) + +var requestTestData = []struct { + name string + claims jwt.MapClaims + headers map[string]string + query url.Values + valid bool +}{ + { + "oauth bearer token - header", + jwt.MapClaims{"foo": "bar"}, + map[string]string{"Authorization": "Bearer %v"}, + url.Values{}, + true, + }, + { + "oauth bearer token - url", + jwt.MapClaims{"foo": "bar"}, + map[string]string{}, + url.Values{"access_token": {"%v"}}, + true, + }, +} + +func TestParseRequest(t *testing.T) { + // load keys from disk + privateKey := test.LoadRSAPrivateKeyFromDisk("../test/sample_key") + publicKey := test.LoadRSAPublicKeyFromDisk("../test/sample_key.pub") + keyfunc := func(*jwt.Token) (interface{}, error) { + return publicKey, nil + } + + // Bearer token request + for _, data := range requestTestData { + // Make token from claims + tokenString := test.MakeSampleToken(data.claims, privateKey) + + // Make query string + for k, vv := range data.query { + for i, v := range vv { + if strings.Contains(v, "%v") { + data.query[k][i] = fmt.Sprintf(v, tokenString) + } + } + } + + // Make request from test struct + r, _ := http.NewRequest("GET", fmt.Sprintf("/?%v", data.query.Encode()), nil) + for k, v := range data.headers { + if strings.Contains(v, "%v") { + r.Header.Set(k, fmt.Sprintf(v, tokenString)) + } else { + r.Header.Set(k, tokenString) + } + } + token, err := ParseFromRequestWithClaims(r, keyfunc, jwt.MapClaims{}) + + if token == nil { + t.Errorf("[%v] Token was not found: %v", data.name, err) + continue + } + if !reflect.DeepEqual(data.claims, token.Claims) { + t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) + } + if data.valid && err != nil { + t.Errorf("[%v] Error while verifying token: %v", data.name, err) + } + if !data.valid && err == nil { + t.Errorf("[%v] Invalid token passed validation", data.name) + } + } +} diff --git a/test/helpers.go b/test/helpers.go new file mode 100644 index 0000000..f84c3ef --- /dev/null +++ b/test/helpers.go @@ -0,0 +1,42 @@ +package test + +import ( + "crypto/rsa" + "github.com/dgrijalva/jwt-go" + "io/ioutil" +) + +func LoadRSAPrivateKeyFromDisk(location string) *rsa.PrivateKey { + keyData, e := ioutil.ReadFile(location) + if e != nil { + panic(e.Error()) + } + key, e := jwt.ParseRSAPrivateKeyFromPEM(keyData) + if e != nil { + panic(e.Error()) + } + return key +} + +func LoadRSAPublicKeyFromDisk(location string) *rsa.PublicKey { + keyData, e := ioutil.ReadFile(location) + if e != nil { + panic(e.Error()) + } + key, e := jwt.ParseRSAPublicKeyFromPEM(keyData) + if e != nil { + panic(e.Error()) + } + return key +} + +func MakeSampleToken(c jwt.Claims, key interface{}) string { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, c) + s, e := token.SignedString(key) + + if e != nil { + panic(e.Error()) + } + + return s +} diff --git a/token.go b/token.go index 8252d86..7ac24fa 100644 --- a/token.go +++ b/token.go @@ -3,7 +3,6 @@ package jwt import ( "encoding/base64" "encoding/json" - "net/http" "strings" "time" ) @@ -94,32 +93,6 @@ func ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claims) (*Token return new(Parser).ParseWithClaims(tokenString, keyFunc, claims) } -// Try to find the token in an http.Request. -// This method will call ParseMultipartForm if there's no token in the header. -// Currently, it looks in the Authorization header as well as -// looking for an 'access_token' request parameter in req.Form. -func ParseFromRequest(req *http.Request, keyFunc Keyfunc) (token *Token, err error) { - return ParseFromRequestWithClaims(req, keyFunc, &MapClaims{}) -} - -func ParseFromRequestWithClaims(req *http.Request, keyFunc Keyfunc, claims Claims) (token *Token, err error) { - // Look for an Authorization header - if ah := req.Header.Get("Authorization"); ah != "" { - // Should be a bearer token - if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { - return ParseWithClaims(ah[7:], keyFunc, claims) - } - } - - // Look for "access_token" parameter - req.ParseMultipartForm(10e6) - if tokStr := req.Form.Get("access_token"); tokStr != "" { - return ParseWithClaims(tokStr, keyFunc, claims) - } - - return nil, ErrNoTokenInRequest -} - // Encode JWT specific base64url encoding with padding stripped func EncodeSegment(seg []byte) string { return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=")