diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..40b02bc --- /dev/null +++ b/parser.go @@ -0,0 +1,112 @@ +package jwt + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" +) + +type Parser struct { + ValidMethods []string // If populated, only these methods will be considered valid + UseJSONNumber bool // Use JSON Number format in JSON decoder +} + +// Parse, validate, and return a token. +// keyFunc will receive the parsed token and should return the key for validating. +// If everything is kosher, err will be nil +func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, &ValidationError{err: "token contains an invalid number of segments", Errors: ValidationErrorMalformed} + } + + var err error + token := &Token{Raw: tokenString} + // parse Header + var headerBytes []byte + if headerBytes, err = DecodeSegment(parts[0]); err != nil { + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} + } + if err = json.Unmarshal(headerBytes, &token.Header); err != nil { + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} + } + + // parse Claims + var claimBytes []byte + if claimBytes, err = DecodeSegment(parts[1]); err != nil { + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} + } + dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) + if p.UseJSONNumber { + dec.UseNumber() + } + if err = dec.Decode(&token.Claims); err != nil { + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} + } + + // Lookup signature method + if method, ok := token.Header["alg"].(string); ok { + if token.Method = GetSigningMethod(method); token.Method == nil { + return token, &ValidationError{err: "signing method (alg) is unavailable.", Errors: ValidationErrorUnverifiable} + } + } else { + return token, &ValidationError{err: "signing method (alg) is unspecified.", Errors: ValidationErrorUnverifiable} + } + + // Verify signing method is in the required set + if p.ValidMethods != nil { + var signingMethodValid = false + var alg = token.Method.Alg() + for _, m := range p.ValidMethods { + if m == alg { + signingMethodValid = true + break + } + } + if !signingMethodValid { + // signing method is not in the listed set + return token, &ValidationError{err: fmt.Sprintf("signing method %v is invalid", alg), Errors: ValidationErrorSignatureInvalid} + } + } + + // Lookup key + var key interface{} + if keyFunc == nil { + // keyFunc was not provided. short circuiting validation + return token, &ValidationError{err: "no Keyfunc was provided.", Errors: ValidationErrorUnverifiable} + } + if key, err = keyFunc(token); err != nil { + // keyFunc returned an error + return token, &ValidationError{err: err.Error(), Errors: ValidationErrorUnverifiable} + } + + // Check expiration times + vErr := &ValidationError{} + now := TimeFunc().Unix() + if exp, ok := token.Claims["exp"].(float64); ok { + if now > int64(exp) { + vErr.err = "token is expired" + vErr.Errors |= ValidationErrorExpired + } + } + if nbf, ok := token.Claims["nbf"].(float64); ok { + if now < int64(nbf) { + vErr.err = "token is not valid yet" + vErr.Errors |= ValidationErrorNotValidYet + } + } + + // Perform validation + if err = token.Method.Verify(strings.Join(parts[0:2], "."), parts[2], key); err != nil { + vErr.err = err.Error() + vErr.Errors |= ValidationErrorSignatureInvalid + } + + if vErr.valid() { + token.Valid = true + return token, nil + } + + return token, vErr +} diff --git a/jwt_test.go b/parser_test.go similarity index 83% rename from jwt_test.go rename to parser_test.go index 9108ded..97d9eee 100644 --- a/jwt_test.go +++ b/parser_test.go @@ -1,6 +1,7 @@ package jwt_test import ( + "encoding/json" "fmt" "github.com/dgrijalva/jwt-go" "io/ioutil" @@ -25,6 +26,7 @@ var jwtTestData = []struct { claims map[string]interface{} valid bool errors uint32 + parser *jwt.Parser }{ { "basic", @@ -33,6 +35,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, true, 0, + nil, }, { "basic expired", @@ -41,6 +44,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, false, jwt.ValidationErrorExpired, + nil, }, { "basic nbf", @@ -49,6 +53,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, false, jwt.ValidationErrorNotValidYet, + nil, }, { "expired and nbf", @@ -57,6 +62,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, false, jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, + nil, }, { "basic invalid", @@ -65,6 +71,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, + nil, }, { "basic nokeyfunc", @@ -73,6 +80,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, false, jwt.ValidationErrorUnverifiable, + nil, }, { "basic nokey", @@ -81,6 +89,7 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, + nil, }, { "basic errorkey", @@ -89,6 +98,34 @@ var jwtTestData = []struct { map[string]interface{}{"foo": "bar"}, false, jwt.ValidationErrorUnverifiable, + nil, + }, + { + "invalid signing method", + "", + defaultKeyFunc, + map[string]interface{}{"foo": "bar"}, + false, + jwt.ValidationErrorSignatureInvalid, + &jwt.Parser{ValidMethods: []string{"HS256"}}, + }, + { + "valid signing method", + "", + defaultKeyFunc, + map[string]interface{}{"foo": "bar"}, + true, + 0, + &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, + }, + { + "JSON Number", + "", + defaultKeyFunc, + map[string]interface{}{"foo": json.Number("123.4")}, + true, + 0, + &jwt.Parser{UseJSONNumber: true}, }, } @@ -116,12 +153,19 @@ func makeSample(c map[string]interface{}) string { return s } -func TestJWT(t *testing.T) { +func TestParser_Parse(t *testing.T) { for _, data := range jwtTestData { if data.tokenString == "" { data.tokenString = makeSample(data.claims) } - token, err := jwt.Parse(data.tokenString, data.keyfunc) + + var token *jwt.Token + var err error + if data.parser != nil { + token, err = data.parser.Parse(data.tokenString, data.keyfunc) + } else { + token, err = jwt.Parse(data.tokenString, data.keyfunc) + } if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) @@ -137,8 +181,8 @@ func TestJWT(t *testing.T) { t.Errorf("[%v] Expecting error. Didn't get one.", data.name) } else { // compare the bitfield part of the error - if err.(*jwt.ValidationError).Errors != data.errors { - t.Errorf("[%v] Errors don't match expectation", data.name) + if e := err.(*jwt.ValidationError).Errors; e != data.errors { + t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) } } @@ -149,6 +193,12 @@ func TestJWT(t *testing.T) { func TestParseRequest(t *testing.T) { // Bearer token request for _, data := range jwtTestData { + // FIXME: custom parsers are not supported by this helper. skip tests that require them + if data.parser != nil { + t.Logf("Skipping [%v]. Custom parsers are not supported by ParseRequest", data.name) + continue + } + if data.tokenString == "" { data.tokenString = makeSample(data.claims) } diff --git a/jwt.go b/token.go similarity index 61% rename from jwt.go rename to token.go index 06995aa..d35aaa4 100644 --- a/jwt.go +++ b/token.go @@ -84,79 +84,7 @@ func (t *Token) SigningString() (string, error) { // keyFunc will receive the parsed token and should return the key for validating. // If everything is kosher, err will be nil func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { - parts := strings.Split(tokenString, ".") - if len(parts) != 3 { - return nil, &ValidationError{err: "token contains an invalid number of segments", Errors: ValidationErrorMalformed} - } - - var err error - token := &Token{Raw: tokenString} - // parse Header - var headerBytes []byte - if headerBytes, err = DecodeSegment(parts[0]); err != nil { - return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} - } - if err = json.Unmarshal(headerBytes, &token.Header); err != nil { - return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} - } - - // parse Claims - var claimBytes []byte - if claimBytes, err = DecodeSegment(parts[1]); err != nil { - return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} - } - if err = json.Unmarshal(claimBytes, &token.Claims); err != nil { - return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} - } - - // Lookup signature method - if method, ok := token.Header["alg"].(string); ok { - if token.Method = GetSigningMethod(method); token.Method == nil { - return token, &ValidationError{err: "signing method (alg) is unavailable.", Errors: ValidationErrorUnverifiable} - } - } else { - return token, &ValidationError{err: "signing method (alg) is unspecified.", Errors: ValidationErrorUnverifiable} - } - - // Lookup key - var key interface{} - if keyFunc == nil { - // keyFunc was not provided. short circuiting validation - return token, &ValidationError{err: "no Keyfunc was provided.", Errors: ValidationErrorUnverifiable} - } - if key, err = keyFunc(token); err != nil { - // keyFunc returned an error - return token, &ValidationError{err: err.Error(), Errors: ValidationErrorUnverifiable} - } - - // Check expiration times - vErr := &ValidationError{} - now := TimeFunc().Unix() - if exp, ok := token.Claims["exp"].(float64); ok { - if now > int64(exp) { - vErr.err = "token is expired" - vErr.Errors |= ValidationErrorExpired - } - } - if nbf, ok := token.Claims["nbf"].(float64); ok { - if now < int64(nbf) { - vErr.err = "token is not valid yet" - vErr.Errors |= ValidationErrorNotValidYet - } - } - - // Perform validation - if err = token.Method.Verify(strings.Join(parts[0:2], "."), parts[2], key); err != nil { - vErr.err = err.Error() - vErr.Errors |= ValidationErrorSignatureInvalid - } - - if vErr.valid() { - token.Valid = true - return token, nil - } - - return token, vErr + return new(Parser).Parse(tokenString, keyFunc) } // Try to find the token in an http.Request.