Re-enabled map claim tests. Added error return value to claim getter functions

This commit is contained in:
Christian Banse 2022-10-26 21:06:11 +02:00
parent 2281dd9079
commit 5d57c292ea
5 changed files with 134 additions and 82 deletions

View File

@ -7,9 +7,9 @@ package jwt
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`,
// `iat`, `nbf`, `iss` and `aud`.
type Claims interface {
GetExpirationTime() *NumericDate
GetIssuedAt() *NumericDate
GetNotBefore() *NumericDate
GetIssuer() string
GetAudience() ClaimStrings
GetExpirationTime() (*NumericDate, error)
GetIssuedAt() (*NumericDate, error)
GetNotBefore() (*NumericDate, error)
GetIssuer() (string, error)
GetAudience() (ClaimStrings, error)
}

View File

@ -2,65 +2,68 @@ package jwt
import (
"encoding/json"
"errors"
)
// MapClaims is a claims type that uses the map[string]interface{} for JSON decoding.
// This is the default claims type if you don't supply one
type MapClaims map[string]interface{}
var ErrInvalidType = errors.New("invalid type for claim")
// GetExpirationTime implements the Claims interface.
func (m MapClaims) GetExpirationTime() *NumericDate {
func (m MapClaims) GetExpirationTime() (*NumericDate, error) {
return m.ParseNumericDate("exp")
}
// GetNotBefore implements the Claims interface.
func (m MapClaims) GetNotBefore() *NumericDate {
func (m MapClaims) GetNotBefore() (*NumericDate, error) {
return m.ParseNumericDate("nbf")
}
// GetIssuedAt implements the Claims interface.
func (m MapClaims) GetIssuedAt() *NumericDate {
func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
return m.ParseNumericDate("iat")
}
// GetAudience implements the Claims interface.
func (m MapClaims) GetAudience() ClaimStrings {
func (m MapClaims) GetAudience() (ClaimStrings, error) {
return m.ParseClaimsString("aud")
}
// GetIssuer implements the Claims interface.
func (m MapClaims) GetIssuer() string {
func (m MapClaims) GetIssuer() (string, error) {
return m.ParseString("iss")
}
// ParseNumericDate tries to parse a key in the map claims type as a number
// date. This will succeed, if the underlying type is either a [float64] or a
// [json.Number]. Otherwise, nil will be returned.
func (m MapClaims) ParseNumericDate(key string) *NumericDate {
func (m MapClaims) ParseNumericDate(key string) (*NumericDate, error) {
v, ok := m[key]
if !ok {
return nil
return nil, nil
}
switch exp := v.(type) {
case float64:
if exp == 0 {
return nil
return nil, nil
}
return newNumericDateFromSeconds(exp)
return newNumericDateFromSeconds(exp), nil
case json.Number:
v, _ := exp.Float64()
return newNumericDateFromSeconds(v)
return newNumericDateFromSeconds(v), nil
}
return nil
return nil, ErrInvalidType
}
// ParseClaimsString tries to parse a key in the map claims type as a
// [ClaimsStrings] type, which can either be a string or an array of string.
func (m MapClaims) ParseClaimsString(key string) ClaimStrings {
func (m MapClaims) ParseClaimsString(key string) (ClaimStrings, error) {
var cs []string
switch v := m[key].(type) {
case string:
@ -71,19 +74,33 @@ func (m MapClaims) ParseClaimsString(key string) ClaimStrings {
for _, a := range v {
vs, ok := a.(string)
if !ok {
return nil
return nil, ErrInvalidType
}
cs = append(cs, vs)
}
}
return cs
return cs, nil
}
// ParseString tries to parse a key in the map claims type as a
// [string] type. Otherwise, an empty string is returned.
func (m MapClaims) ParseString(key string) string {
iss, _ := m[key].(string)
return iss
// ParseString tries to parse a key in the map claims type as a [string] type.
// If the key does not exist, an empty string is returned. If the key has the
// wrong type, an error is returned.
func (m MapClaims) ParseString(key string) (string, error) {
var (
ok bool
raw interface{}
iss string
)
raw, ok = m[key]
if !ok {
return "", nil
}
iss, ok = raw.(string)
if !ok {
return "", ErrInvalidType
}
return iss, nil
}

View File

@ -1,7 +1,10 @@
package jwt
/*
TODO(oxisto): Re-enable tests with validation API
import (
"testing"
"time"
)
func TestVerifyAud(t *testing.T) {
var nilInterface interface{}
var nilListInterface []interface{}
@ -39,7 +42,7 @@ func TestVerifyAud(t *testing.T) {
{Name: "[]String Aud without match not required", MapClaims: MapClaims{"aud": []string{"not.example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: "example.com"},
// Required = false
{Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: false, Required: true, Comparison: "example.com"},
{Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: true, Required: false, Comparison: "example.com"},
// []interface{}
{Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"},
@ -53,10 +56,17 @@ func TestVerifyAud(t *testing.T) {
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
got := test.MapClaims.VerifyAudience(test.Comparison, test.Required)
var opts []ValidatorOption
if got != test.Expected {
t.Errorf("Expected %v, got %v", test.Expected, got)
if test.Required {
opts = append(opts, WithAudience(test.Comparison))
}
validator := NewValidator(opts...)
got := validator.Validate(test.MapClaims)
if (got == nil) != test.Expected {
t.Errorf("Expected %v, got %v", test.Expected, (got == nil))
}
})
}
@ -67,9 +77,9 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) {
"iat": "foo",
}
want := false
got := mapClaims.VerifyIssuedAt(0, false)
if want != got {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
got := NewValidator(WithIssuedAt()).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
}
@ -78,9 +88,9 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) {
"nbf": "foo",
}
want := false
got := mapClaims.VerifyNotBefore(0, false)
if want != got {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
got := NewValidator().Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
}
@ -89,33 +99,38 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) {
"exp": "foo",
}
want := false
got := mapClaims.VerifyExpiresAt(0, false)
got := NewValidator().Validate(mapClaims)
if want != got {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
}
func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) {
exp := time.Now().Unix()
exp := time.Now()
mapClaims := MapClaims{
"exp": float64(exp),
"exp": float64(exp.Unix()),
}
want := false
got := mapClaims.VerifyExpiresAt(exp, true)
if want != got {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
got := NewValidator(WithTimeFunc(func() time.Time {
return exp
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
got = mapClaims.VerifyExpiresAt(exp+1, true)
if want != got {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
got = NewValidator(WithTimeFunc(func() time.Time {
return exp.Add(1 * time.Second)
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
want = true
got = mapClaims.VerifyExpiresAt(exp-1, true)
if want != got {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
got = NewValidator(WithTimeFunc(func() time.Time {
return exp.Add(-1 * time.Second)
})).Validate(mapClaims)
if want != (got == nil) {
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
}
}
*/

View File

@ -33,26 +33,26 @@ type RegisteredClaims struct {
}
// GetExpirationTime implements the Claims interface.
func (c RegisteredClaims) GetExpirationTime() *NumericDate {
return c.ExpiresAt
func (c RegisteredClaims) GetExpirationTime() (*NumericDate, error) {
return c.ExpiresAt, nil
}
// GetNotBefore implements the Claims interface.
func (c RegisteredClaims) GetNotBefore() *NumericDate {
return c.NotBefore
func (c RegisteredClaims) GetNotBefore() (*NumericDate, error) {
return c.NotBefore, nil
}
// GetIssuedAt implements the Claims interface.
func (c RegisteredClaims) GetIssuedAt() *NumericDate {
return c.IssuedAt
func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
return c.IssuedAt, nil
}
// GetAudience implements the Claims interface.
func (c RegisteredClaims) GetAudience() ClaimStrings {
return c.Audience
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
return c.Audience, nil
}
// GetIssuer implements the Claims interface.
func (c RegisteredClaims) GetIssuer() string {
return c.Issuer
func (c RegisteredClaims) GetIssuer() (string, error) {
return c.Issuer, nil
}

View File

@ -2,7 +2,6 @@ package jwt
import (
"crypto/subtle"
"fmt"
"time"
)
@ -62,9 +61,7 @@ func (v *Validator) Validate(claims Claims) error {
}
if !v.VerifyExpiresAt(claims, now, false) {
exp := claims.GetExpirationTime()
delta := now.Sub(exp.Time)
vErr.Inner = fmt.Errorf("%s by %s", ErrTokenExpired, delta)
vErr.Inner = ErrTokenExpired
vErr.Errors |= ValidationErrorExpired
}
@ -79,9 +76,10 @@ func (v *Validator) Validate(claims Claims) error {
vErr.Errors |= ValidationErrorNotValidYet
}
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, false) {
vErr.Inner = ErrTokenNotValidYet
vErr.Errors |= ValidationErrorNotValidYet
// If we have an expected audience, we also require the audience claim
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) {
vErr.Inner = ErrTokenInvalidAudience
vErr.Errors |= ValidationErrorAudience
}
// Finally, we want to give the claim itself some possibility to do some
@ -104,46 +102,68 @@ func (v *Validator) Validate(claims Claims) error {
// VerifyAudience compares the aud claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (v *Validator) VerifyAudience(claims Claims, cmp string, req bool) bool {
return verifyAud(claims.GetAudience(), cmp, req)
aud, err := claims.GetAudience()
if err != nil {
return false
}
return verifyAud(aud, cmp, req)
}
// VerifyExpiresAt compares the exp claim against cmp (cmp < exp).
// If req is false, it will return true, if exp is unset.
func (v *Validator) VerifyExpiresAt(claims Claims, cmp time.Time, req bool) bool {
exp := claims.GetExpirationTime()
if exp == nil {
return verifyExp(nil, cmp, req, v.leeway)
var time *time.Time = nil
exp, err := claims.GetExpirationTime()
if err != nil {
return false
} else if exp != nil {
time = &exp.Time
}
return verifyExp(&exp.Time, cmp, req, v.leeway)
return verifyExp(time, cmp, req, v.leeway)
}
// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
// If req is false, it will return true, if iat is unset.
func (v *Validator) VerifyIssuedAt(claims Claims, cmp time.Time, req bool) bool {
iat := claims.GetIssuedAt()
if iat == nil {
return verifyIat(nil, cmp, req, v.leeway)
var time *time.Time = nil
iat, err := claims.GetIssuedAt()
if err != nil {
return false
} else if iat != nil {
time = &iat.Time
}
return verifyIat(&iat.Time, cmp, req, v.leeway)
return verifyIat(time, cmp, req, v.leeway)
}
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (v *Validator) VerifyNotBefore(claims Claims, cmp time.Time, req bool) bool {
nbf := claims.GetNotBefore()
if nbf == nil {
return verifyNbf(nil, cmp, req, v.leeway)
var time *time.Time = nil
nbf, err := claims.GetNotBefore()
if err != nil {
return false
} else if nbf != nil {
time = &nbf.Time
}
return verifyNbf(&nbf.Time, cmp, req, v.leeway)
return verifyNbf(time, cmp, req, v.leeway)
}
// VerifyIssuer compares the iss claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool {
return verifyIss(claims.GetIssuer(), cmp, req)
iss, err := claims.GetIssuer()
if err != nil {
return false
}
return verifyIss(iss, cmp, req)
}
// ----- helpers