Merge branch 'release_3_0_0' into dg/request

This commit is contained in:
Dave Grijalva 2016-04-12 14:34:54 -07:00
commit 4400800062
4 changed files with 43 additions and 16 deletions

View File

@ -16,7 +16,7 @@ type Parser struct {
// keyFunc will receive the parsed token and should return the key for validating. // keyFunc will receive the parsed token and should return the key for validating.
// If everything is kosher, err will be nil // If everything is kosher, err will be nil
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
return p.ParseWithClaims(tokenString, keyFunc, &MapClaims{}) return p.ParseWithClaims(tokenString, keyFunc, MapClaims{})
} }
func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claims) (*Token, error) { func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claims) (*Token, error) {
@ -42,6 +42,7 @@ func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Cla
// parse Claims // parse Claims
var claimBytes []byte var claimBytes []byte
token.Claims = claims
if claimBytes, err = DecodeSegment(parts[1]); err != nil { if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed}
@ -50,12 +51,17 @@ func (p *Parser) ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Cla
if p.UseJSONNumber { if p.UseJSONNumber {
dec.UseNumber() dec.UseNumber()
} }
if err = dec.Decode(&claims); err != nil { // JSON Decode. Special case for map type to avoid weird pointer behavior
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
// Handle decode error
if err != nil {
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed} return token, &ValidationError{err: err.Error(), Errors: ValidationErrorMalformed}
} }
token.Claims = claims
// Lookup signature method // Lookup signature method
if method, ok := token.Header["alg"].(string); ok { if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil { if token.Method = GetSigningMethod(method); token.Method == nil {

View File

@ -28,7 +28,7 @@ var jwtTestData = []struct {
name string name string
tokenString string tokenString string
keyfunc jwt.Keyfunc keyfunc jwt.Keyfunc
claims jwt.MapClaims claims jwt.Claims
valid bool valid bool
errors uint32 errors uint32
parser *jwt.Parser parser *jwt.Parser
@ -109,7 +109,7 @@ var jwtTestData = []struct {
"invalid signing method", "invalid signing method",
"", "",
defaultKeyFunc, defaultKeyFunc,
map[string]interface{}{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid, jwt.ValidationErrorSignatureInvalid,
&jwt.Parser{ValidMethods: []string{"HS256"}}, &jwt.Parser{ValidMethods: []string{"HS256"}},
@ -118,7 +118,7 @@ var jwtTestData = []struct {
"valid signing method", "valid signing method",
"", "",
defaultKeyFunc, defaultKeyFunc,
map[string]interface{}{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
true, true,
0, 0,
&jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}},
@ -127,7 +127,18 @@ var jwtTestData = []struct {
"JSON Number", "JSON Number",
"", "",
defaultKeyFunc, defaultKeyFunc,
map[string]interface{}{"foo": json.Number("123.4")}, jwt.MapClaims{"foo": json.Number("123.4")},
true,
0,
&jwt.Parser{UseJSONNumber: true},
},
{
"Standard Claims",
"",
defaultKeyFunc,
&jwt.StandardClaims{
ExpiresAt: time.Now().Add(time.Second * 10).Unix(),
},
true, true,
0, 0,
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
@ -137,20 +148,30 @@ var jwtTestData = []struct {
func TestParser_Parse(t *testing.T) { func TestParser_Parse(t *testing.T) {
privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key") privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key")
// Iterate over test data set and run tests
for _, data := range jwtTestData { for _, data := range jwtTestData {
// If the token string is blank, use helper function to generate string
if data.tokenString == "" { if data.tokenString == "" {
data.tokenString = test.MakeSampleToken(data.claims, privateKey) data.tokenString = test.MakeSampleToken(data.claims, privateKey)
} }
// Parse the token
var token *jwt.Token var token *jwt.Token
var err error var err error
if data.parser != nil { var parser = data.parser
token, err = data.parser.Parse(data.tokenString, data.keyfunc) if parser == nil {
} else { parser = new(jwt.Parser)
token, err = jwt.Parse(data.tokenString, data.keyfunc) }
// Figure out correct claims type
switch data.claims.(type) {
case jwt.MapClaims:
token, err = parser.ParseWithClaims(data.tokenString, data.keyfunc, jwt.MapClaims{})
case *jwt.StandardClaims:
token, err = parser.ParseWithClaims(data.tokenString, data.keyfunc, &jwt.StandardClaims{})
} }
if !reflect.DeepEqual(&data.claims, token.Claims) { // Verify result matches expectation
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
} }

View File

@ -65,13 +65,13 @@ func TestParseRequest(t *testing.T) {
r.Header.Set(k, tokenString) r.Header.Set(k, tokenString)
} }
} }
token, err := ParseFromRequestWithClaims(r, keyfunc, &jwt.MapClaims{}) token, err := ParseFromRequestWithClaims(r, keyfunc, jwt.MapClaims{})
if token == nil { if token == nil {
t.Errorf("[%v] Token was not found: %v", data.name, err) t.Errorf("[%v] Token was not found: %v", data.name, err)
continue continue
} }
if !reflect.DeepEqual(&data.claims, token.Claims) { if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
} }
if data.valid && err != nil { if data.valid && err != nil {

View File

@ -30,7 +30,7 @@ func LoadRSAPublicKeyFromDisk(location string) *rsa.PublicKey {
return key return key
} }
func MakeSampleToken(c jwt.MapClaims, key interface{}) string { func MakeSampleToken(c jwt.Claims, key interface{}) string {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, c) token := jwt.NewWithClaims(jwt.SigningMethodRS256, c)
s, e := token.SignedString(key) s, e := token.SignedString(key)