forked from mirror/jwt
Re-enabled map claim tests. Added error return value to claim getter functions
This commit is contained in:
parent
2281dd9079
commit
5d57c292ea
10
claims.go
10
claims.go
|
@ -7,9 +7,9 @@ package jwt
|
||||||
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`,
|
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`,
|
||||||
// `iat`, `nbf`, `iss` and `aud`.
|
// `iat`, `nbf`, `iss` and `aud`.
|
||||||
type Claims interface {
|
type Claims interface {
|
||||||
GetExpirationTime() *NumericDate
|
GetExpirationTime() (*NumericDate, error)
|
||||||
GetIssuedAt() *NumericDate
|
GetIssuedAt() (*NumericDate, error)
|
||||||
GetNotBefore() *NumericDate
|
GetNotBefore() (*NumericDate, error)
|
||||||
GetIssuer() string
|
GetIssuer() (string, error)
|
||||||
GetAudience() ClaimStrings
|
GetAudience() (ClaimStrings, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,65 +2,68 @@ package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MapClaims is a claims type that uses the map[string]interface{} for JSON decoding.
|
// MapClaims is a claims type that uses the map[string]interface{} for JSON decoding.
|
||||||
// This is the default claims type if you don't supply one
|
// This is the default claims type if you don't supply one
|
||||||
type MapClaims map[string]interface{}
|
type MapClaims map[string]interface{}
|
||||||
|
|
||||||
|
var ErrInvalidType = errors.New("invalid type for claim")
|
||||||
|
|
||||||
// GetExpirationTime implements the Claims interface.
|
// GetExpirationTime implements the Claims interface.
|
||||||
func (m MapClaims) GetExpirationTime() *NumericDate {
|
func (m MapClaims) GetExpirationTime() (*NumericDate, error) {
|
||||||
return m.ParseNumericDate("exp")
|
return m.ParseNumericDate("exp")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNotBefore implements the Claims interface.
|
// GetNotBefore implements the Claims interface.
|
||||||
func (m MapClaims) GetNotBefore() *NumericDate {
|
func (m MapClaims) GetNotBefore() (*NumericDate, error) {
|
||||||
return m.ParseNumericDate("nbf")
|
return m.ParseNumericDate("nbf")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIssuedAt implements the Claims interface.
|
// GetIssuedAt implements the Claims interface.
|
||||||
func (m MapClaims) GetIssuedAt() *NumericDate {
|
func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
|
||||||
return m.ParseNumericDate("iat")
|
return m.ParseNumericDate("iat")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAudience implements the Claims interface.
|
// GetAudience implements the Claims interface.
|
||||||
func (m MapClaims) GetAudience() ClaimStrings {
|
func (m MapClaims) GetAudience() (ClaimStrings, error) {
|
||||||
return m.ParseClaimsString("aud")
|
return m.ParseClaimsString("aud")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIssuer implements the Claims interface.
|
// GetIssuer implements the Claims interface.
|
||||||
func (m MapClaims) GetIssuer() string {
|
func (m MapClaims) GetIssuer() (string, error) {
|
||||||
return m.ParseString("iss")
|
return m.ParseString("iss")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseNumericDate tries to parse a key in the map claims type as a number
|
// ParseNumericDate tries to parse a key in the map claims type as a number
|
||||||
// date. This will succeed, if the underlying type is either a [float64] or a
|
// date. This will succeed, if the underlying type is either a [float64] or a
|
||||||
// [json.Number]. Otherwise, nil will be returned.
|
// [json.Number]. Otherwise, nil will be returned.
|
||||||
func (m MapClaims) ParseNumericDate(key string) *NumericDate {
|
func (m MapClaims) ParseNumericDate(key string) (*NumericDate, error) {
|
||||||
v, ok := m[key]
|
v, ok := m[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch exp := v.(type) {
|
switch exp := v.(type) {
|
||||||
case float64:
|
case float64:
|
||||||
if exp == 0 {
|
if exp == 0 {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return newNumericDateFromSeconds(exp)
|
return newNumericDateFromSeconds(exp), nil
|
||||||
case json.Number:
|
case json.Number:
|
||||||
v, _ := exp.Float64()
|
v, _ := exp.Float64()
|
||||||
|
|
||||||
return newNumericDateFromSeconds(v)
|
return newNumericDateFromSeconds(v), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, ErrInvalidType
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseClaimsString tries to parse a key in the map claims type as a
|
// ParseClaimsString tries to parse a key in the map claims type as a
|
||||||
// [ClaimsStrings] type, which can either be a string or an array of string.
|
// [ClaimsStrings] type, which can either be a string or an array of string.
|
||||||
func (m MapClaims) ParseClaimsString(key string) ClaimStrings {
|
func (m MapClaims) ParseClaimsString(key string) (ClaimStrings, error) {
|
||||||
var cs []string
|
var cs []string
|
||||||
switch v := m[key].(type) {
|
switch v := m[key].(type) {
|
||||||
case string:
|
case string:
|
||||||
|
@ -71,19 +74,33 @@ func (m MapClaims) ParseClaimsString(key string) ClaimStrings {
|
||||||
for _, a := range v {
|
for _, a := range v {
|
||||||
vs, ok := a.(string)
|
vs, ok := a.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil, ErrInvalidType
|
||||||
}
|
}
|
||||||
cs = append(cs, vs)
|
cs = append(cs, vs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return cs
|
return cs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseString tries to parse a key in the map claims type as a
|
// ParseString tries to parse a key in the map claims type as a [string] type.
|
||||||
// [string] type. Otherwise, an empty string is returned.
|
// If the key does not exist, an empty string is returned. If the key has the
|
||||||
func (m MapClaims) ParseString(key string) string {
|
// wrong type, an error is returned.
|
||||||
iss, _ := m[key].(string)
|
func (m MapClaims) ParseString(key string) (string, error) {
|
||||||
|
var (
|
||||||
return iss
|
ok bool
|
||||||
|
raw interface{}
|
||||||
|
iss string
|
||||||
|
)
|
||||||
|
raw, ok = m[key]
|
||||||
|
if !ok {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
iss, ok = raw.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", ErrInvalidType
|
||||||
|
}
|
||||||
|
|
||||||
|
return iss, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
/*
|
import (
|
||||||
TODO(oxisto): Re-enable tests with validation API
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
func TestVerifyAud(t *testing.T) {
|
func TestVerifyAud(t *testing.T) {
|
||||||
var nilInterface interface{}
|
var nilInterface interface{}
|
||||||
var nilListInterface []interface{}
|
var nilListInterface []interface{}
|
||||||
|
@ -39,7 +42,7 @@ func TestVerifyAud(t *testing.T) {
|
||||||
{Name: "[]String Aud without match not required", MapClaims: MapClaims{"aud": []string{"not.example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: "example.com"},
|
{Name: "[]String Aud without match not required", MapClaims: MapClaims{"aud": []string{"not.example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: "example.com"},
|
||||||
|
|
||||||
// Required = false
|
// Required = false
|
||||||
{Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: false, Required: true, Comparison: "example.com"},
|
{Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: true, Required: false, Comparison: "example.com"},
|
||||||
|
|
||||||
// []interface{}
|
// []interface{}
|
||||||
{Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"},
|
{Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"},
|
||||||
|
@ -53,10 +56,17 @@ func TestVerifyAud(t *testing.T) {
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.Name, func(t *testing.T) {
|
t.Run(test.Name, func(t *testing.T) {
|
||||||
got := test.MapClaims.VerifyAudience(test.Comparison, test.Required)
|
var opts []ValidatorOption
|
||||||
|
|
||||||
if got != test.Expected {
|
if test.Required {
|
||||||
t.Errorf("Expected %v, got %v", test.Expected, got)
|
opts = append(opts, WithAudience(test.Comparison))
|
||||||
|
}
|
||||||
|
|
||||||
|
validator := NewValidator(opts...)
|
||||||
|
got := validator.Validate(test.MapClaims)
|
||||||
|
|
||||||
|
if (got == nil) != test.Expected {
|
||||||
|
t.Errorf("Expected %v, got %v", test.Expected, (got == nil))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -67,9 +77,9 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) {
|
||||||
"iat": "foo",
|
"iat": "foo",
|
||||||
}
|
}
|
||||||
want := false
|
want := false
|
||||||
got := mapClaims.VerifyIssuedAt(0, false)
|
got := NewValidator(WithIssuedAt()).Validate(mapClaims)
|
||||||
if want != got {
|
if want != (got == nil) {
|
||||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,9 +88,9 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) {
|
||||||
"nbf": "foo",
|
"nbf": "foo",
|
||||||
}
|
}
|
||||||
want := false
|
want := false
|
||||||
got := mapClaims.VerifyNotBefore(0, false)
|
got := NewValidator().Validate(mapClaims)
|
||||||
if want != got {
|
if want != (got == nil) {
|
||||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,33 +99,38 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) {
|
||||||
"exp": "foo",
|
"exp": "foo",
|
||||||
}
|
}
|
||||||
want := false
|
want := false
|
||||||
got := mapClaims.VerifyExpiresAt(0, false)
|
got := NewValidator().Validate(mapClaims)
|
||||||
|
|
||||||
if want != got {
|
if want != (got == nil) {
|
||||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) {
|
func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) {
|
||||||
exp := time.Now().Unix()
|
exp := time.Now()
|
||||||
mapClaims := MapClaims{
|
mapClaims := MapClaims{
|
||||||
"exp": float64(exp),
|
"exp": float64(exp.Unix()),
|
||||||
}
|
}
|
||||||
want := false
|
want := false
|
||||||
got := mapClaims.VerifyExpiresAt(exp, true)
|
got := NewValidator(WithTimeFunc(func() time.Time {
|
||||||
if want != got {
|
return exp
|
||||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
})).Validate(mapClaims)
|
||||||
|
if want != (got == nil) {
|
||||||
|
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
got = mapClaims.VerifyExpiresAt(exp+1, true)
|
got = NewValidator(WithTimeFunc(func() time.Time {
|
||||||
if want != got {
|
return exp.Add(1 * time.Second)
|
||||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
})).Validate(mapClaims)
|
||||||
|
if want != (got == nil) {
|
||||||
|
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
want = true
|
want = true
|
||||||
got = mapClaims.VerifyExpiresAt(exp-1, true)
|
got = NewValidator(WithTimeFunc(func() time.Time {
|
||||||
if want != got {
|
return exp.Add(-1 * time.Second)
|
||||||
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got)
|
})).Validate(mapClaims)
|
||||||
|
if want != (got == nil) {
|
||||||
|
t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
|
@ -33,26 +33,26 @@ type RegisteredClaims struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetExpirationTime implements the Claims interface.
|
// GetExpirationTime implements the Claims interface.
|
||||||
func (c RegisteredClaims) GetExpirationTime() *NumericDate {
|
func (c RegisteredClaims) GetExpirationTime() (*NumericDate, error) {
|
||||||
return c.ExpiresAt
|
return c.ExpiresAt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNotBefore implements the Claims interface.
|
// GetNotBefore implements the Claims interface.
|
||||||
func (c RegisteredClaims) GetNotBefore() *NumericDate {
|
func (c RegisteredClaims) GetNotBefore() (*NumericDate, error) {
|
||||||
return c.NotBefore
|
return c.NotBefore, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIssuedAt implements the Claims interface.
|
// GetIssuedAt implements the Claims interface.
|
||||||
func (c RegisteredClaims) GetIssuedAt() *NumericDate {
|
func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
|
||||||
return c.IssuedAt
|
return c.IssuedAt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAudience implements the Claims interface.
|
// GetAudience implements the Claims interface.
|
||||||
func (c RegisteredClaims) GetAudience() ClaimStrings {
|
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
|
||||||
return c.Audience
|
return c.Audience, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIssuer implements the Claims interface.
|
// GetIssuer implements the Claims interface.
|
||||||
func (c RegisteredClaims) GetIssuer() string {
|
func (c RegisteredClaims) GetIssuer() (string, error) {
|
||||||
return c.Issuer
|
return c.Issuer, nil
|
||||||
}
|
}
|
||||||
|
|
62
validator.go
62
validator.go
|
@ -2,7 +2,6 @@ package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"fmt"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -62,9 +61,7 @@ func (v *Validator) Validate(claims Claims) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !v.VerifyExpiresAt(claims, now, false) {
|
if !v.VerifyExpiresAt(claims, now, false) {
|
||||||
exp := claims.GetExpirationTime()
|
vErr.Inner = ErrTokenExpired
|
||||||
delta := now.Sub(exp.Time)
|
|
||||||
vErr.Inner = fmt.Errorf("%s by %s", ErrTokenExpired, delta)
|
|
||||||
vErr.Errors |= ValidationErrorExpired
|
vErr.Errors |= ValidationErrorExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,9 +76,10 @@ func (v *Validator) Validate(claims Claims) error {
|
||||||
vErr.Errors |= ValidationErrorNotValidYet
|
vErr.Errors |= ValidationErrorNotValidYet
|
||||||
}
|
}
|
||||||
|
|
||||||
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, false) {
|
// If we have an expected audience, we also require the audience claim
|
||||||
vErr.Inner = ErrTokenNotValidYet
|
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) {
|
||||||
vErr.Errors |= ValidationErrorNotValidYet
|
vErr.Inner = ErrTokenInvalidAudience
|
||||||
|
vErr.Errors |= ValidationErrorAudience
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, we want to give the claim itself some possibility to do some
|
// Finally, we want to give the claim itself some possibility to do some
|
||||||
|
@ -104,46 +102,68 @@ func (v *Validator) Validate(claims Claims) error {
|
||||||
// VerifyAudience compares the aud claim against cmp.
|
// VerifyAudience compares the aud claim against cmp.
|
||||||
// If required is false, this method will return true if the value matches or is unset
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
func (v *Validator) VerifyAudience(claims Claims, cmp string, req bool) bool {
|
func (v *Validator) VerifyAudience(claims Claims, cmp string, req bool) bool {
|
||||||
return verifyAud(claims.GetAudience(), cmp, req)
|
aud, err := claims.GetAudience()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return verifyAud(aud, cmp, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyExpiresAt compares the exp claim against cmp (cmp < exp).
|
// VerifyExpiresAt compares the exp claim against cmp (cmp < exp).
|
||||||
// If req is false, it will return true, if exp is unset.
|
// If req is false, it will return true, if exp is unset.
|
||||||
func (v *Validator) VerifyExpiresAt(claims Claims, cmp time.Time, req bool) bool {
|
func (v *Validator) VerifyExpiresAt(claims Claims, cmp time.Time, req bool) bool {
|
||||||
exp := claims.GetExpirationTime()
|
var time *time.Time = nil
|
||||||
if exp == nil {
|
|
||||||
return verifyExp(nil, cmp, req, v.leeway)
|
exp, err := claims.GetExpirationTime()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
} else if exp != nil {
|
||||||
|
time = &exp.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
return verifyExp(&exp.Time, cmp, req, v.leeway)
|
return verifyExp(time, cmp, req, v.leeway)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
|
// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
|
||||||
// If req is false, it will return true, if iat is unset.
|
// If req is false, it will return true, if iat is unset.
|
||||||
func (v *Validator) VerifyIssuedAt(claims Claims, cmp time.Time, req bool) bool {
|
func (v *Validator) VerifyIssuedAt(claims Claims, cmp time.Time, req bool) bool {
|
||||||
iat := claims.GetIssuedAt()
|
var time *time.Time = nil
|
||||||
if iat == nil {
|
|
||||||
return verifyIat(nil, cmp, req, v.leeway)
|
iat, err := claims.GetIssuedAt()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
} else if iat != nil {
|
||||||
|
time = &iat.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
return verifyIat(&iat.Time, cmp, req, v.leeway)
|
return verifyIat(time, cmp, req, v.leeway)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
|
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
|
||||||
// If req is false, it will return true, if nbf is unset.
|
// If req is false, it will return true, if nbf is unset.
|
||||||
func (v *Validator) VerifyNotBefore(claims Claims, cmp time.Time, req bool) bool {
|
func (v *Validator) VerifyNotBefore(claims Claims, cmp time.Time, req bool) bool {
|
||||||
nbf := claims.GetNotBefore()
|
var time *time.Time = nil
|
||||||
if nbf == nil {
|
|
||||||
return verifyNbf(nil, cmp, req, v.leeway)
|
nbf, err := claims.GetNotBefore()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
} else if nbf != nil {
|
||||||
|
time = &nbf.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
return verifyNbf(&nbf.Time, cmp, req, v.leeway)
|
return verifyNbf(time, cmp, req, v.leeway)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyIssuer compares the iss claim against cmp.
|
// VerifyIssuer compares the iss claim against cmp.
|
||||||
// If required is false, this method will return true if the value matches or is unset
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool {
|
func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool {
|
||||||
return verifyIss(claims.GetIssuer(), cmp, req)
|
iss, err := claims.GetIssuer()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return verifyIss(iss, cmp, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----- helpers
|
// ----- helpers
|
||||||
|
|
Loading…
Reference in New Issue