diff --git a/parser.go b/parser.go index 7c0a891..40b02bc 100644 --- a/parser.go +++ b/parser.go @@ -32,21 +32,6 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} } - // Verify signing method is in the required set - if p.ValidMethods != nil { - var signingMethodValid = false - var alg = token.Method.Alg() - for _, m := range p.ValidMethods { - if m == alg { - signingMethodValid = true - break - } - } - if !signingMethodValid { - return nil, &ValidationError{err: fmt.Sprintf("signing method %v is invalid", alg), Errors: ValidationErrorSignatureInvalid} - } - } - // parse Claims var claimBytes []byte if claimBytes, err = DecodeSegment(parts[1]); err != nil { @@ -69,6 +54,22 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { return token, &ValidationError{err: "signing method (alg) is unspecified.", Errors: ValidationErrorUnverifiable} } + // Verify signing method is in the required set + if p.ValidMethods != nil { + var signingMethodValid = false + var alg = token.Method.Alg() + for _, m := range p.ValidMethods { + if m == alg { + signingMethodValid = true + break + } + } + if !signingMethodValid { + // signing method is not in the listed set + return token, &ValidationError{err: fmt.Sprintf("signing method %v is invalid", alg), Errors: ValidationErrorSignatureInvalid} + } + } + // Lookup key var key interface{} if keyFunc == nil { diff --git a/parser_test.go b/parser_test.go index 9108ded..6592aff 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1,6 +1,7 @@ package jwt_test import ( + "encoding/json" "fmt" "github.com/dgrijalva/jwt-go" "io/ioutil" @@ -25,6 +26,7 @@ var jwtTestData = []struct { claims map[string]interface{} valid bool errors uint32 + parser *jwt.Parser }{ { "basic", @@ -33,6 +35,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, true, 0, + nil, }, { "basic expired", @@ -41,6 +44,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, false, jwt.ValidationErrorExpired, + nil, }, { "basic nbf", @@ -49,6 +53,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, false, jwt.ValidationErrorNotValidYet, + nil, }, { "expired and nbf", @@ -57,6 +62,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, false, jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, + nil, }, { "basic invalid", @@ -65,6 +71,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, + nil, }, { "basic nokeyfunc", @@ -73,6 +80,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, false, jwt.ValidationErrorUnverifiable, + nil, }, { "basic nokey", @@ -81,6 +89,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, + nil, }, { "basic errorkey", @@ -89,6 +98,25 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, false, jwt.ValidationErrorUnverifiable, + nil, + }, + { + "invalid signing method", + "", + defaultKeyFunc, + map[string]interface{}{"foo": "bar"}, + false, + jwt.ValidationErrorSignatureInvalid, + &jwt.Parser{ValidMethods: []string{"HS256"}}, + }, + { + "JSON Number", + "", + defaultKeyFunc, + map[string]interface{}{"foo": json.Number("123.4")}, + true, + 0, + &jwt.Parser{UseJSONNumber: true}, }, } @@ -116,12 +144,19 @@ func makeSample(c map[string]interface{}) string { return s } -func TestJWT(t *testing.T) { +func TestParser_Parse(t *testing.T) { for _, data := range jwtTestData { if data.tokenString == "" { data.tokenString = makeSample(data.claims) } - token, err := jwt.Parse(data.tokenString, data.keyfunc) + + var token *jwt.Token + var err error + if data.parser != nil { + token, err = data.parser.Parse(data.tokenString, data.keyfunc) + } else { + token, err = jwt.Parse(data.tokenString, data.keyfunc) + } if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) @@ -137,8 +172,8 @@ func TestJWT(t *testing.T) { t.Errorf("[%v] Expecting error. Didn't get one.", data.name) } else { // compare the bitfield part of the error - if err.(*jwt.ValidationError).Errors != data.errors { - t.Errorf("[%v] Errors don't match expectation", data.name) + if e := err.(*jwt.ValidationError).Errors; e != data.errors { + t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) } } @@ -149,6 +184,12 @@ func TestJWT(t *testing.T) { func TestParseRequest(t *testing.T) { // Bearer token request for _, data := range jwtTestData { + // FIXME: custom parsers are not supported by this helper. skip tests that require them + if data.parser != nil { + t.Logf("Skipping [%v]. Custom parsers are not supported by ParseRequest", data.name) + continue + } + if data.tokenString == "" { data.tokenString = makeSample(data.claims) }