diff --git a/parser.go b/parser.go index cea5591..884fea9 100644 --- a/parser.go +++ b/parser.go @@ -16,7 +16,7 @@ type Parser struct { // keyFunc will receive the parsed token and should return the key for validating. // If everything is kosher, err will be nil func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { - return p.ParseWithClaims(tokenString, keyFunc, &MapClaims{}) + return p.ParseWithClaims(tokenString, keyFunc, MapClaims{}) } func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claims) (*Token, error) { @@ -42,6 +42,7 @@ func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Cla // parse Claims var claimBytes []byte + token.Claims = claims if claimBytes, err = DecodeSegment(parts[1]); err != nil { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} @@ -50,12 +51,17 @@ func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Cla if p.UseJSONNumber { dec.UseNumber() } - if err = dec.Decode(&claims); err != nil { + // JSON Decode. Special case for map type to avoid weird pointer behavior + if c, ok := token.Claims.(MapClaims); ok { + err = dec.Decode(&c) + } else { + err = dec.Decode(&claims) + } + // Handle decode error + if err != nil { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} } - token.Claims = claims - // Lookup signature method if method, ok := token.Header["alg"].(string); ok { if token.Method = GetSigningMethod(method); token.Method == nil { diff --git a/parser_test.go b/parser_test.go index 11cc82c..622a423 100644 --- a/parser_test.go +++ b/parser_test.go @@ -28,7 +28,7 @@ var jwtTestData = []struct { name string tokenString string keyfunc jwt.Keyfunc - claims jwt.MapClaims + claims jwt.Claims valid bool errors uint32 parser *jwt.Parser @@ -109,7 +109,7 @@ var jwtTestData = []struct { "invalid signing method", "", defaultKeyFunc, - map[string]interface{}{"foo": "bar"}, + jwt.MapClaims{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, &jwt.Parser{ValidMethods: []string{"HS256"}}, @@ -118,7 +118,7 @@ var jwtTestData = []struct { "valid signing method", "", defaultKeyFunc, - map[string]interface{}{"foo": "bar"}, + jwt.MapClaims{"foo": "bar"}, true, 0, &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, @@ -127,7 +127,18 @@ var jwtTestData = []struct { "JSON Number", "", defaultKeyFunc, - map[string]interface{}{"foo": json.Number("123.4")}, + jwt.MapClaims{"foo": json.Number("123.4")}, + true, + 0, + &jwt.Parser{UseJSONNumber: true}, + }, + { + "Standard Claims", + "", + defaultKeyFunc, + &jwt.StandardClaims{ + ExpiresAt: time.Now().Add(time.Second * 10).Unix(), + }, true, 0, &jwt.Parser{UseJSONNumber: true}, @@ -137,20 +148,30 @@ var jwtTestData = []struct { func TestParser_Parse(t *testing.T) { privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key") + // Iterate over test data set and run tests for _, data := range jwtTestData { + // If the token string is blank, use helper function to generate string if data.tokenString == "" { data.tokenString = test.MakeSampleToken(data.claims, privateKey) } + // Parse the token var token *jwt.Token var err error - if data.parser != nil { - token, err = data.parser.Parse(data.tokenString, data.keyfunc) - } else { - token, err = jwt.Parse(data.tokenString, data.keyfunc) + var parser = data.parser + if parser == nil { + parser = new(jwt.Parser) + } + // Figure out correct claims type + switch data.claims.(type) { + case jwt.MapClaims: + token, err = parser.ParseWithClaims(data.tokenString, data.keyfunc, jwt.MapClaims{}) + case *jwt.StandardClaims: + token, err = parser.ParseWithClaims(data.tokenString, data.keyfunc, &jwt.StandardClaims{}) } - if !reflect.DeepEqual(&data.claims, token.Claims) { + // Verify result matches expectation + if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) } diff --git a/request/request_test.go b/request/request_test.go index 8306912..0f2fb9b 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -65,13 +65,13 @@ func TestParseRequest(t *testing.T) { r.Header.Set(k, tokenString) } } - token, err := ParseFromRequestWithClaims(r, keyfunc, &jwt.MapClaims{}) + token, err := ParseFromRequestWithClaims(r, keyfunc, jwt.MapClaims{}) if token == nil { t.Errorf("[%v] Token was not found: %v", data.name, err) continue } - if !reflect.DeepEqual(&data.claims, token.Claims) { + if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) } if data.valid && err != nil { diff --git a/test/helpers.go b/test/helpers.go index 39b5208..f84c3ef 100644 --- a/test/helpers.go +++ b/test/helpers.go @@ -30,7 +30,7 @@ func LoadRSAPublicKeyFromDisk(location string) *rsa.PublicKey { return key } -func MakeSampleToken(c jwt.MapClaims, key interface{}) string { +func MakeSampleToken(c jwt.Claims, key interface{}) string { token := jwt.NewWithClaims(jwt.SigningMethodRS256, c) s, e := token.SignedString(key)