Re-enabled map claim tests. Added error return value to claim getter functions

This commit is contained in:
Christian Banse 2022-10-26 21:06:11 +02:00
parent 2281dd9079
commit 5d57c292ea
5 changed files with 134 additions and 82 deletions

View File

@ -7,9 +7,9 @@ package jwt
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`, // https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`,
// `iat`, `nbf`, `iss` and `aud`. // `iat`, `nbf`, `iss` and `aud`.
type Claims interface { type Claims interface {
GetExpirationTime() *NumericDate GetExpirationTime() (*NumericDate, error)
GetIssuedAt() *NumericDate GetIssuedAt() (*NumericDate, error)
GetNotBefore() *NumericDate GetNotBefore() (*NumericDate, error)
GetIssuer() string GetIssuer() (string, error)
GetAudience() ClaimStrings GetAudience() (ClaimStrings, error)
} }

View File

@ -2,65 +2,68 @@ package jwt
import ( import (
"encoding/json" "encoding/json"
"errors"
) )
// MapClaims is a claims type that uses the map[string]interface{} for JSON decoding. // MapClaims is a claims type that uses the map[string]interface{} for JSON decoding.
// This is the default claims type if you don't supply one // This is the default claims type if you don't supply one
type MapClaims map[string]interface{} type MapClaims map[string]interface{}
var ErrInvalidType = errors.New("invalid type for claim")
// GetExpirationTime implements the Claims interface. // GetExpirationTime implements the Claims interface.
func (m MapClaims) GetExpirationTime() *NumericDate { func (m MapClaims) GetExpirationTime() (*NumericDate, error) {
return m.ParseNumericDate("exp") return m.ParseNumericDate("exp")
} }
// GetNotBefore implements the Claims interface. // GetNotBefore implements the Claims interface.
func (m MapClaims) GetNotBefore() *NumericDate { func (m MapClaims) GetNotBefore() (*NumericDate, error) {
return m.ParseNumericDate("nbf") return m.ParseNumericDate("nbf")
} }
// GetIssuedAt implements the Claims interface. // GetIssuedAt implements the Claims interface.
func (m MapClaims) GetIssuedAt() *NumericDate { func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
return m.ParseNumericDate("iat") return m.ParseNumericDate("iat")
} }
// GetAudience implements the Claims interface. // GetAudience implements the Claims interface.
func (m MapClaims) GetAudience() ClaimStrings { func (m MapClaims) GetAudience() (ClaimStrings, error) {
return m.ParseClaimsString("aud") return m.ParseClaimsString("aud")
} }
// GetIssuer implements the Claims interface. // GetIssuer implements the Claims interface.
func (m MapClaims) GetIssuer() string { func (m MapClaims) GetIssuer() (string, error) {
return m.ParseString("iss") return m.ParseString("iss")
} }
// ParseNumericDate tries to parse a key in the map claims type as a number // ParseNumericDate tries to parse a key in the map claims type as a number
// date. This will succeed, if the underlying type is either a [float64] or a // date. This will succeed, if the underlying type is either a [float64] or a
// [json.Number]. Otherwise, nil will be returned. // [json.Number]. Otherwise, nil will be returned.
func (m MapClaims) ParseNumericDate(key string) *NumericDate { func (m MapClaims) ParseNumericDate(key string) (*NumericDate, error) {
v, ok := m[key] v, ok := m[key]
if !ok { if !ok {
return nil return nil, nil
} }
switch exp := v.(type) { switch exp := v.(type) {
case float64: case float64:
if exp == 0 { if exp == 0 {
return nil return nil, nil
} }
return newNumericDateFromSeconds(exp) return newNumericDateFromSeconds(exp), nil
case json.Number: case json.Number:
v, _ := exp.Float64() v, _ := exp.Float64()
return newNumericDateFromSeconds(v) return newNumericDateFromSeconds(v), nil
} }
return nil return nil, ErrInvalidType
} }
// ParseClaimsString tries to parse a key in the map claims type as a // ParseClaimsString tries to parse a key in the map claims type as a
// [ClaimsStrings] type, which can either be a string or an array of string. // [ClaimsStrings] type, which can either be a string or an array of string.
func (m MapClaims) ParseClaimsString(key string) ClaimStrings { func (m MapClaims) ParseClaimsString(key string) (ClaimStrings, error) {
var cs []string var cs []string
switch v := m[key].(type) { switch v := m[key].(type) {
case string: case string:
@ -71,19 +74,33 @@ func (m MapClaims) ParseClaimsString(key string) ClaimStrings {
for _, a := range v { for _, a := range v {
vs, ok := a.(string) vs, ok := a.(string)
if !ok { if !ok {
return nil return nil, ErrInvalidType
} }
cs = append(cs, vs) cs = append(cs, vs)
} }
} }
return cs return cs, nil
} }
// ParseString tries to parse a key in the map claims type as a // ParseString tries to parse a key in the map claims type as a [string] type.
// [string] type. Otherwise, an empty string is returned. // If the key does not exist, an empty string is returned. If the key has the
func (m MapClaims) ParseString(key string) string { // wrong type, an error is returned.
iss, _ := m[key].(string) func (m MapClaims) ParseString(key string) (string, error) {
var (
return iss ok bool
raw interface{}
iss string
)
raw, ok = m[key]
if !ok {
return "", nil
}
iss, ok = raw.(string)
if !ok {
return "", ErrInvalidType
}
return iss, nil
} }

View File

@ -1,7 +1,10 @@
package jwt package jwt
/* import (
TODO(oxisto): Re-enable tests with validation API "testing"
"time"
)
func TestVerifyAud(t *testing.T) { func TestVerifyAud(t *testing.T) {
var nilInterface interface{} var nilInterface interface{}
var nilListInterface []interface{} var nilListInterface []interface{}
@ -39,7 +42,7 @@ func TestVerifyAud(t *testing.T) {
{Name: "[]String Aud without match not required", MapClaims: MapClaims{"aud": []string{"not.example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: "example.com"}, {Name: "[]String Aud without match not required", MapClaims: MapClaims{"aud": []string{"not.example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: "example.com"},
// Required = false // Required = false
{Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: false, Required: true, Comparison: "example.com"}, {Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: true, Required: false, Comparison: "example.com"},
// []interface{} // []interface{}
{Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"}, {Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"},
@ -53,10 +56,17 @@ func TestVerifyAud(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.Name, func(t *testing.T) { t.Run(test.Name, func(t *testing.T) {
got := test.MapClaims.VerifyAudience(test.Comparison, test.Required) var opts []ValidatorOption
if got != test.Expected { if test.Required {
t.Errorf("Expected %v, got %v", test.Expected, got) opts = append(opts, WithAudience(test.Comparison))
}
validator := NewValidator(opts...)
got := validator.Validate(test.MapClaims)
if (got == nil) != test.Expected {
t.Errorf("Expected %v, got %v", test.Expected, (got == nil))
} }
}) })
} }
@ -67,9 +77,9 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) {
"iat": "foo", "iat": "foo",
} }
want := false want := false
got := mapClaims.VerifyIssuedAt(0, false) got := NewValidator(WithIssuedAt()).Validate(mapClaims)
if want != got { if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
} }
} }
@ -78,9 +88,9 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) {
"nbf": "foo", "nbf": "foo",
} }
want := false want := false
got := mapClaims.VerifyNotBefore(0, false) got := NewValidator().Validate(mapClaims)
if want != got { if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
} }
} }
@ -89,33 +99,38 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) {
"exp": "foo", "exp": "foo",
} }
want := false want := false
got := mapClaims.VerifyExpiresAt(0, false) got := NewValidator().Validate(mapClaims)
if want != got { if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
} }
} }
func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) { func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) {
exp := time.Now().Unix() exp := time.Now()
mapClaims := MapClaims{ mapClaims := MapClaims{
"exp": float64(exp), "exp": float64(exp.Unix()),
} }
want := false want := false
got := mapClaims.VerifyExpiresAt(exp, true) got := NewValidator(WithTimeFunc(func() time.Time {
if want != got { return exp
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) })).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
} }
got = mapClaims.VerifyExpiresAt(exp+1, true) got = NewValidator(WithTimeFunc(func() time.Time {
if want != got { return exp.Add(1 * time.Second)
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) })).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
} }
want = true want = true
got = mapClaims.VerifyExpiresAt(exp-1, true) got = NewValidator(WithTimeFunc(func() time.Time {
if want != got { return exp.Add(-1 * time.Second)
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) })).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
} }
} }
*/

View File

@ -33,26 +33,26 @@ type RegisteredClaims struct {
} }
// GetExpirationTime implements the Claims interface. // GetExpirationTime implements the Claims interface.
func (c RegisteredClaims) GetExpirationTime() *NumericDate { func (c RegisteredClaims) GetExpirationTime() (*NumericDate, error) {
return c.ExpiresAt return c.ExpiresAt, nil
} }
// GetNotBefore implements the Claims interface. // GetNotBefore implements the Claims interface.
func (c RegisteredClaims) GetNotBefore() *NumericDate { func (c RegisteredClaims) GetNotBefore() (*NumericDate, error) {
return c.NotBefore return c.NotBefore, nil
} }
// GetIssuedAt implements the Claims interface. // GetIssuedAt implements the Claims interface.
func (c RegisteredClaims) GetIssuedAt() *NumericDate { func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
return c.IssuedAt return c.IssuedAt, nil
} }
// GetAudience implements the Claims interface. // GetAudience implements the Claims interface.
func (c RegisteredClaims) GetAudience() ClaimStrings { func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
return c.Audience return c.Audience, nil
} }
// GetIssuer implements the Claims interface. // GetIssuer implements the Claims interface.
func (c RegisteredClaims) GetIssuer() string { func (c RegisteredClaims) GetIssuer() (string, error) {
return c.Issuer return c.Issuer, nil
} }

View File

@ -2,7 +2,6 @@ package jwt
import ( import (
"crypto/subtle" "crypto/subtle"
"fmt"
"time" "time"
) )
@ -62,9 +61,7 @@ func (v *Validator) Validate(claims Claims) error {
} }
if !v.VerifyExpiresAt(claims, now, false) { if !v.VerifyExpiresAt(claims, now, false) {
exp := claims.GetExpirationTime() vErr.Inner = ErrTokenExpired
delta := now.Sub(exp.Time)
vErr.Inner = fmt.Errorf("%s by %s", ErrTokenExpired, delta)
vErr.Errors |= ValidationErrorExpired vErr.Errors |= ValidationErrorExpired
} }
@ -79,9 +76,10 @@ func (v *Validator) Validate(claims Claims) error {
vErr.Errors |= ValidationErrorNotValidYet vErr.Errors |= ValidationErrorNotValidYet
} }
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, false) { // If we have an expected audience, we also require the audience claim
vErr.Inner = ErrTokenNotValidYet if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) {
vErr.Errors |= ValidationErrorNotValidYet vErr.Inner = ErrTokenInvalidAudience
vErr.Errors |= ValidationErrorAudience
} }
// Finally, we want to give the claim itself some possibility to do some // Finally, we want to give the claim itself some possibility to do some
@ -104,46 +102,68 @@ func (v *Validator) Validate(claims Claims) error {
// VerifyAudience compares the aud claim against cmp. // VerifyAudience compares the aud claim against cmp.
// If required is false, this method will return true if the value matches or is unset // If required is false, this method will return true if the value matches or is unset
func (v *Validator) VerifyAudience(claims Claims, cmp string, req bool) bool { func (v *Validator) VerifyAudience(claims Claims, cmp string, req bool) bool {
return verifyAud(claims.GetAudience(), cmp, req) aud, err := claims.GetAudience()
if err != nil {
return false
}
return verifyAud(aud, cmp, req)
} }
// VerifyExpiresAt compares the exp claim against cmp (cmp < exp). // VerifyExpiresAt compares the exp claim against cmp (cmp < exp).
// If req is false, it will return true, if exp is unset. // If req is false, it will return true, if exp is unset.
func (v *Validator) VerifyExpiresAt(claims Claims, cmp time.Time, req bool) bool { func (v *Validator) VerifyExpiresAt(claims Claims, cmp time.Time, req bool) bool {
exp := claims.GetExpirationTime() var time *time.Time = nil
if exp == nil {
return verifyExp(nil, cmp, req, v.leeway) exp, err := claims.GetExpirationTime()
if err != nil {
return false
} else if exp != nil {
time = &exp.Time
} }
return verifyExp(&exp.Time, cmp, req, v.leeway) return verifyExp(time, cmp, req, v.leeway)
} }
// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). // VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
// If req is false, it will return true, if iat is unset. // If req is false, it will return true, if iat is unset.
func (v *Validator) VerifyIssuedAt(claims Claims, cmp time.Time, req bool) bool { func (v *Validator) VerifyIssuedAt(claims Claims, cmp time.Time, req bool) bool {
iat := claims.GetIssuedAt() var time *time.Time = nil
if iat == nil {
return verifyIat(nil, cmp, req, v.leeway) iat, err := claims.GetIssuedAt()
if err != nil {
return false
} else if iat != nil {
time = &iat.Time
} }
return verifyIat(&iat.Time, cmp, req, v.leeway) return verifyIat(time, cmp, req, v.leeway)
} }
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). // VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset. // If req is false, it will return true, if nbf is unset.
func (v *Validator) VerifyNotBefore(claims Claims, cmp time.Time, req bool) bool { func (v *Validator) VerifyNotBefore(claims Claims, cmp time.Time, req bool) bool {
nbf := claims.GetNotBefore() var time *time.Time = nil
if nbf == nil {
return verifyNbf(nil, cmp, req, v.leeway) nbf, err := claims.GetNotBefore()
if err != nil {
return false
} else if nbf != nil {
time = &nbf.Time
} }
return verifyNbf(&nbf.Time, cmp, req, v.leeway) return verifyNbf(time, cmp, req, v.leeway)
} }
// VerifyIssuer compares the iss claim against cmp. // VerifyIssuer compares the iss claim against cmp.
// If required is false, this method will return true if the value matches or is unset // If required is false, this method will return true if the value matches or is unset
func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool { func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool {
return verifyIss(claims.GetIssuer(), cmp, req) iss, err := claims.GetIssuer()
if err != nil {
return false
}
return verifyIss(iss, cmp, req)
} }
// ----- helpers // ----- helpers