forked from mirror/jwt
Re-enabled map claim tests. Added error return value to claim getter functions
This commit is contained in:
parent
2281dd9079
commit
5d57c292ea
10
claims.go
10
claims.go
|
@ -7,9 +7,9 @@ package jwt
|
|||
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`,
|
||||
// `iat`, `nbf`, `iss` and `aud`.
|
||||
type Claims interface {
|
||||
GetExpirationTime() *NumericDate
|
||||
GetIssuedAt() *NumericDate
|
||||
GetNotBefore() *NumericDate
|
||||
GetIssuer() string
|
||||
GetAudience() ClaimStrings
|
||||
GetExpirationTime() (*NumericDate, error)
|
||||
GetIssuedAt() (*NumericDate, error)
|
||||
GetNotBefore() (*NumericDate, error)
|
||||
GetIssuer() (string, error)
|
||||
GetAudience() (ClaimStrings, error)
|
||||
}
|
||||
|
|
|
@ -2,65 +2,68 @@ package jwt
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// MapClaims is a claims type that uses the map[string]interface{} for JSON decoding.
|
||||
// This is the default claims type if you don't supply one
|
||||
type MapClaims map[string]interface{}
|
||||
|
||||
var ErrInvalidType = errors.New("invalid type for claim")
|
||||
|
||||
// GetExpirationTime implements the Claims interface.
|
||||
func (m MapClaims) GetExpirationTime() *NumericDate {
|
||||
func (m MapClaims) GetExpirationTime() (*NumericDate, error) {
|
||||
return m.ParseNumericDate("exp")
|
||||
}
|
||||
|
||||
// GetNotBefore implements the Claims interface.
|
||||
func (m MapClaims) GetNotBefore() *NumericDate {
|
||||
func (m MapClaims) GetNotBefore() (*NumericDate, error) {
|
||||
return m.ParseNumericDate("nbf")
|
||||
}
|
||||
|
||||
// GetIssuedAt implements the Claims interface.
|
||||
func (m MapClaims) GetIssuedAt() *NumericDate {
|
||||
func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
|
||||
return m.ParseNumericDate("iat")
|
||||
}
|
||||
|
||||
// GetAudience implements the Claims interface.
|
||||
func (m MapClaims) GetAudience() ClaimStrings {
|
||||
func (m MapClaims) GetAudience() (ClaimStrings, error) {
|
||||
return m.ParseClaimsString("aud")
|
||||
}
|
||||
|
||||
// GetIssuer implements the Claims interface.
|
||||
func (m MapClaims) GetIssuer() string {
|
||||
func (m MapClaims) GetIssuer() (string, error) {
|
||||
return m.ParseString("iss")
|
||||
}
|
||||
|
||||
// ParseNumericDate tries to parse a key in the map claims type as a number
|
||||
// date. This will succeed, if the underlying type is either a [float64] or a
|
||||
// [json.Number]. Otherwise, nil will be returned.
|
||||
func (m MapClaims) ParseNumericDate(key string) *NumericDate {
|
||||
func (m MapClaims) ParseNumericDate(key string) (*NumericDate, error) {
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch exp := v.(type) {
|
||||
case float64:
|
||||
if exp == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return newNumericDateFromSeconds(exp)
|
||||
return newNumericDateFromSeconds(exp), nil
|
||||
case json.Number:
|
||||
v, _ := exp.Float64()
|
||||
|
||||
return newNumericDateFromSeconds(v)
|
||||
return newNumericDateFromSeconds(v), nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, ErrInvalidType
|
||||
}
|
||||
|
||||
// ParseClaimsString tries to parse a key in the map claims type as a
|
||||
// [ClaimsStrings] type, which can either be a string or an array of string.
|
||||
func (m MapClaims) ParseClaimsString(key string) ClaimStrings {
|
||||
func (m MapClaims) ParseClaimsString(key string) (ClaimStrings, error) {
|
||||
var cs []string
|
||||
switch v := m[key].(type) {
|
||||
case string:
|
||||
|
@ -71,19 +74,33 @@ func (m MapClaims) ParseClaimsString(key string) ClaimStrings {
|
|||
for _, a := range v {
|
||||
vs, ok := a.(string)
|
||||
if !ok {
|
||||
return nil
|
||||
return nil, ErrInvalidType
|
||||
}
|
||||
cs = append(cs, vs)
|
||||
}
|
||||
}
|
||||
|
||||
return cs
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
// ParseString tries to parse a key in the map claims type as a
|
||||
// [string] type. Otherwise, an empty string is returned.
|
||||
func (m MapClaims) ParseString(key string) string {
|
||||
iss, _ := m[key].(string)
|
||||
// ParseString tries to parse a key in the map claims type as a [string] type.
|
||||
// If the key does not exist, an empty string is returned. If the key has the
|
||||
// wrong type, an error is returned.
|
||||
func (m MapClaims) ParseString(key string) (string, error) {
|
||||
var (
|
||||
ok bool
|
||||
raw interface{}
|
||||
iss string
|
||||
)
|
||||
raw, ok = m[key]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return iss
|
||||
iss, ok = raw.(string)
|
||||
if !ok {
|
||||
return "", ErrInvalidType
|
||||
}
|
||||
|
||||
return iss, nil
|
||||
}
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
package jwt
|
||||
|
||||
/*
|
||||
TODO(oxisto): Re-enable tests with validation API
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestVerifyAud(t *testing.T) {
|
||||
var nilInterface interface{}
|
||||
var nilListInterface []interface{}
|
||||
|
@ -39,7 +42,7 @@ func TestVerifyAud(t *testing.T) {
|
|||
{Name: "[]String Aud without match not required", MapClaims: MapClaims{"aud": []string{"not.example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: "example.com"},
|
||||
|
||||
// Required = false
|
||||
{Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: false, Required: true, Comparison: "example.com"},
|
||||
{Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: true, Required: false, Comparison: "example.com"},
|
||||
|
||||
// []interface{}
|
||||
{Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"},
|
||||
|
@ -53,10 +56,17 @@ func TestVerifyAud(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
got := test.MapClaims.VerifyAudience(test.Comparison, test.Required)
|
||||
var opts []ValidatorOption
|
||||
|
||||
if got != test.Expected {
|
||||
t.Errorf("Expected %v, got %v", test.Expected, got)
|
||||
if test.Required {
|
||||
opts = append(opts, WithAudience(test.Comparison))
|
||||
}
|
||||
|
||||
validator := NewValidator(opts...)
|
||||
got := validator.Validate(test.MapClaims)
|
||||
|
||||
if (got == nil) != test.Expected {
|
||||
t.Errorf("Expected %v, got %v", test.Expected, (got == nil))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -67,9 +77,9 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) {
|
|||
"iat": "foo",
|
||||
}
|
||||
want := false
|
||||
got := mapClaims.VerifyIssuedAt(0, false)
|
||||
if want != got {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
||||
got := NewValidator(WithIssuedAt()).Validate(mapClaims)
|
||||
if want != (got == nil) {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,9 +88,9 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) {
|
|||
"nbf": "foo",
|
||||
}
|
||||
want := false
|
||||
got := mapClaims.VerifyNotBefore(0, false)
|
||||
if want != got {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
||||
got := NewValidator().Validate(mapClaims)
|
||||
if want != (got == nil) {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,33 +99,38 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) {
|
|||
"exp": "foo",
|
||||
}
|
||||
want := false
|
||||
got := mapClaims.VerifyExpiresAt(0, false)
|
||||
got := NewValidator().Validate(mapClaims)
|
||||
|
||||
if want != got {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
||||
if want != (got == nil) {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) {
|
||||
exp := time.Now().Unix()
|
||||
exp := time.Now()
|
||||
mapClaims := MapClaims{
|
||||
"exp": float64(exp),
|
||||
"exp": float64(exp.Unix()),
|
||||
}
|
||||
want := false
|
||||
got := mapClaims.VerifyExpiresAt(exp, true)
|
||||
if want != got {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
||||
got := NewValidator(WithTimeFunc(func() time.Time {
|
||||
return exp
|
||||
})).Validate(mapClaims)
|
||||
if want != (got == nil) {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||
}
|
||||
|
||||
got = mapClaims.VerifyExpiresAt(exp+1, true)
|
||||
if want != got {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
||||
got = NewValidator(WithTimeFunc(func() time.Time {
|
||||
return exp.Add(1 * time.Second)
|
||||
})).Validate(mapClaims)
|
||||
if want != (got == nil) {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||
}
|
||||
|
||||
want = true
|
||||
got = mapClaims.VerifyExpiresAt(exp-1, true)
|
||||
if want != got {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
||||
got = NewValidator(WithTimeFunc(func() time.Time {
|
||||
return exp.Add(-1 * time.Second)
|
||||
})).Validate(mapClaims)
|
||||
if want != (got == nil) {
|
||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -33,26 +33,26 @@ type RegisteredClaims struct {
|
|||
}
|
||||
|
||||
// GetExpirationTime implements the Claims interface.
|
||||
func (c RegisteredClaims) GetExpirationTime() *NumericDate {
|
||||
return c.ExpiresAt
|
||||
func (c RegisteredClaims) GetExpirationTime() (*NumericDate, error) {
|
||||
return c.ExpiresAt, nil
|
||||
}
|
||||
|
||||
// GetNotBefore implements the Claims interface.
|
||||
func (c RegisteredClaims) GetNotBefore() *NumericDate {
|
||||
return c.NotBefore
|
||||
func (c RegisteredClaims) GetNotBefore() (*NumericDate, error) {
|
||||
return c.NotBefore, nil
|
||||
}
|
||||
|
||||
// GetIssuedAt implements the Claims interface.
|
||||
func (c RegisteredClaims) GetIssuedAt() *NumericDate {
|
||||
return c.IssuedAt
|
||||
func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
|
||||
return c.IssuedAt, nil
|
||||
}
|
||||
|
||||
// GetAudience implements the Claims interface.
|
||||
func (c RegisteredClaims) GetAudience() ClaimStrings {
|
||||
return c.Audience
|
||||
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
|
||||
return c.Audience, nil
|
||||
}
|
||||
|
||||
// GetIssuer implements the Claims interface.
|
||||
func (c RegisteredClaims) GetIssuer() string {
|
||||
return c.Issuer
|
||||
func (c RegisteredClaims) GetIssuer() (string, error) {
|
||||
return c.Issuer, nil
|
||||
}
|
||||
|
|
62
validator.go
62
validator.go
|
@ -2,7 +2,6 @@ package jwt
|
|||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -62,9 +61,7 @@ func (v *Validator) Validate(claims Claims) error {
|
|||
}
|
||||
|
||||
if !v.VerifyExpiresAt(claims, now, false) {
|
||||
exp := claims.GetExpirationTime()
|
||||
delta := now.Sub(exp.Time)
|
||||
vErr.Inner = fmt.Errorf("%s by %s", ErrTokenExpired, delta)
|
||||
vErr.Inner = ErrTokenExpired
|
||||
vErr.Errors |= ValidationErrorExpired
|
||||
}
|
||||
|
||||
|
@ -79,9 +76,10 @@ func (v *Validator) Validate(claims Claims) error {
|
|||
vErr.Errors |= ValidationErrorNotValidYet
|
||||
}
|
||||
|
||||
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, false) {
|
||||
vErr.Inner = ErrTokenNotValidYet
|
||||
vErr.Errors |= ValidationErrorNotValidYet
|
||||
// If we have an expected audience, we also require the audience claim
|
||||
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) {
|
||||
vErr.Inner = ErrTokenInvalidAudience
|
||||
vErr.Errors |= ValidationErrorAudience
|
||||
}
|
||||
|
||||
// Finally, we want to give the claim itself some possibility to do some
|
||||
|
@ -104,46 +102,68 @@ func (v *Validator) Validate(claims Claims) error {
|
|||
// VerifyAudience compares the aud claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (v *Validator) VerifyAudience(claims Claims, cmp string, req bool) bool {
|
||||
return verifyAud(claims.GetAudience(), cmp, req)
|
||||
aud, err := claims.GetAudience()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return verifyAud(aud, cmp, req)
|
||||
}
|
||||
|
||||
// VerifyExpiresAt compares the exp claim against cmp (cmp < exp).
|
||||
// If req is false, it will return true, if exp is unset.
|
||||
func (v *Validator) VerifyExpiresAt(claims Claims, cmp time.Time, req bool) bool {
|
||||
exp := claims.GetExpirationTime()
|
||||
if exp == nil {
|
||||
return verifyExp(nil, cmp, req, v.leeway)
|
||||
var time *time.Time = nil
|
||||
|
||||
exp, err := claims.GetExpirationTime()
|
||||
if err != nil {
|
||||
return false
|
||||
} else if exp != nil {
|
||||
time = &exp.Time
|
||||
}
|
||||
|
||||
return verifyExp(&exp.Time, cmp, req, v.leeway)
|
||||
return verifyExp(time, cmp, req, v.leeway)
|
||||
}
|
||||
|
||||
// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
|
||||
// If req is false, it will return true, if iat is unset.
|
||||
func (v *Validator) VerifyIssuedAt(claims Claims, cmp time.Time, req bool) bool {
|
||||
iat := claims.GetIssuedAt()
|
||||
if iat == nil {
|
||||
return verifyIat(nil, cmp, req, v.leeway)
|
||||
var time *time.Time = nil
|
||||
|
||||
iat, err := claims.GetIssuedAt()
|
||||
if err != nil {
|
||||
return false
|
||||
} else if iat != nil {
|
||||
time = &iat.Time
|
||||
}
|
||||
|
||||
return verifyIat(&iat.Time, cmp, req, v.leeway)
|
||||
return verifyIat(time, cmp, req, v.leeway)
|
||||
}
|
||||
|
||||
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
|
||||
// If req is false, it will return true, if nbf is unset.
|
||||
func (v *Validator) VerifyNotBefore(claims Claims, cmp time.Time, req bool) bool {
|
||||
nbf := claims.GetNotBefore()
|
||||
if nbf == nil {
|
||||
return verifyNbf(nil, cmp, req, v.leeway)
|
||||
var time *time.Time = nil
|
||||
|
||||
nbf, err := claims.GetNotBefore()
|
||||
if err != nil {
|
||||
return false
|
||||
} else if nbf != nil {
|
||||
time = &nbf.Time
|
||||
}
|
||||
|
||||
return verifyNbf(&nbf.Time, cmp, req, v.leeway)
|
||||
return verifyNbf(time, cmp, req, v.leeway)
|
||||
}
|
||||
|
||||
// VerifyIssuer compares the iss claim against cmp.
|
||||
// If required is false, this method will return true if the value matches or is unset
|
||||
func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool {
|
||||
return verifyIss(claims.GetIssuer(), cmp, req)
|
||||
iss, err := claims.GetIssuer()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return verifyIss(iss, cmp, req)
|
||||
}
|
||||
|
||||
// ----- helpers
|
||||
|
|
Loading…
Reference in New Issue