From fb4ca74c9f326415a43ecfbb8ff504b06858a02a Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Tue, 12 Apr 2016 14:32:24 -0700 Subject: [PATCH] added special case behavior for MapClaims so they aren't all weird --- parser.go | 14 ++++++++++---- parser_test.go | 45 +++++++++++++++++++++++++++++++++------------ 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/parser.go b/parser.go index cea5591..884fea9 100644 --- a/parser.go +++ b/parser.go @@ -16,7 +16,7 @@ type Parser struct { // keyFunc will receive the parsed token and should return the key for validating. // If everything is kosher, err will be nil func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { - return p.ParseWithClaims(tokenString, keyFunc, &MapClaims{}) + return p.ParseWithClaims(tokenString, keyFunc, MapClaims{}) } func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claims) (*Token, error) { @@ -42,6 +42,7 @@ func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Cla // parse Claims var claimBytes []byte + token.Claims = claims if claimBytes, err = DecodeSegment(parts[1]); err != nil { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} @@ -50,12 +51,17 @@ func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Cla if p.UseJSONNumber { dec.UseNumber() } - if err = dec.Decode(&claims); err != nil { + // JSON Decode. Special case for map type to avoid weird pointer behavior + if c, ok := token.Claims.(MapClaims); ok { + err = dec.Decode(&c) + } else { + err = dec.Decode(&claims) + } + // Handle decode error + if err != nil { return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} } - token.Claims = claims - // Lookup signature method if method, ok := token.Header["alg"].(string); ok { if token.Method = GetSigningMethod(method); token.Method == nil { diff --git a/parser_test.go b/parser_test.go index 0c09204..a4eb83f 100644 --- a/parser_test.go +++ b/parser_test.go @@ -25,7 +25,7 @@ var jwtTestData = []struct { name string tokenString string keyfunc jwt.Keyfunc - claims jwt.MapClaims + claims jwt.Claims valid bool errors uint32 parser *jwt.Parser @@ -106,7 +106,7 @@ var jwtTestData = []struct { "invalid signing method", "", defaultKeyFunc, - map[string]interface{}{"foo": "bar"}, + jwt.MapClaims{"foo": "bar"}, false, jwt.ValidationErrorSignatureInvalid, &jwt.Parser{ValidMethods: []string{"HS256"}}, @@ -115,7 +115,7 @@ var jwtTestData = []struct { "valid signing method", "", defaultKeyFunc, - map[string]interface{}{"foo": "bar"}, + jwt.MapClaims{"foo": "bar"}, true, 0, &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, @@ -124,7 +124,18 @@ var jwtTestData = []struct { "JSON Number", "", defaultKeyFunc, - map[string]interface{}{"foo": json.Number("123.4")}, + jwt.MapClaims{"foo": json.Number("123.4")}, + true, + 0, + &jwt.Parser{UseJSONNumber: true}, + }, + { + "Standard Claims", + "", + defaultKeyFunc, + &jwt.StandardClaims{ + ExpiresAt: time.Now().Add(time.Second * 10).Unix(), + }, true, 0, &jwt.Parser{UseJSONNumber: true}, @@ -141,7 +152,7 @@ func init() { } } -func makeSample(c jwt.MapClaims) string { +func makeSample(c jwt.Claims) string { keyData, e := ioutil.ReadFile("test/sample_key") if e != nil { panic(e.Error()) @@ -162,20 +173,30 @@ func makeSample(c jwt.MapClaims) string { } func TestParser_Parse(t *testing.T) { + // Iterate over test data set and run tests for _, data := range jwtTestData { + // If the token string is blank, use helper function to generate string if data.tokenString == "" { data.tokenString = makeSample(data.claims) } + // Parse the token var token *jwt.Token var err error - if data.parser != nil { - token, err = data.parser.Parse(data.tokenString, data.keyfunc) - } else { - token, err = jwt.Parse(data.tokenString, data.keyfunc) + var parser = data.parser + if parser == nil { + parser = new(jwt.Parser) + } + // Figure out correct claims type + switch data.claims.(type) { + case jwt.MapClaims: + token, err = parser.ParseWithClaims(data.tokenString, data.keyfunc, jwt.MapClaims{}) + case *jwt.StandardClaims: + token, err = parser.ParseWithClaims(data.tokenString, data.keyfunc, &jwt.StandardClaims{}) } - if !reflect.DeepEqual(&data.claims, token.Claims) { + // Verify result matches expectation + if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) } @@ -218,13 +239,13 @@ func TestParseRequest(t *testing.T) { r, _ := http.NewRequest("GET", "/", nil) r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString)) - token, err := jwt.ParseFromRequestWithClaims(r, data.keyfunc, &jwt.MapClaims{}) + token, err := jwt.ParseFromRequestWithClaims(r, data.keyfunc, jwt.MapClaims{}) if token == nil { t.Errorf("[%v] Token was not found: %v", data.name, err) continue } - if !reflect.DeepEqual(&data.claims, token.Claims) { + if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) } if data.valid && err != nil {