More consistent way of handling validation errors (#274)

This commit is contained in:
Christian Banse 2023-02-21 08:54:35 +01:00 committed by GitHub
parent 4e6e1ba2bb
commit 28dc52370e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 689 additions and 287 deletions

View File

@ -25,7 +25,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
go: [1.17, 1.18, 1.19] go: ["1.18", "1.19", "1.20"]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3

106
errors.go
View File

@ -2,18 +2,17 @@ package jwt
import ( import (
"errors" "errors"
"strings"
) )
// Error constants
var ( var (
ErrInvalidKey = errors.New("key is invalid") ErrInvalidKey = errors.New("key is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type") ErrInvalidKeyType = errors.New("key is of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable") ErrHashUnavailable = errors.New("the requested hash function is unavailable")
ErrTokenMalformed = errors.New("token is malformed") ErrTokenMalformed = errors.New("token is malformed")
ErrTokenUnverifiable = errors.New("token is unverifiable") ErrTokenUnverifiable = errors.New("token is unverifiable")
ErrTokenSignatureInvalid = errors.New("token signature is invalid") ErrTokenSignatureInvalid = errors.New("token signature is invalid")
ErrTokenRequiredClaimMissing = errors.New("token is missing required claim")
ErrTokenInvalidAudience = errors.New("token has invalid audience") ErrTokenInvalidAudience = errors.New("token has invalid audience")
ErrTokenExpired = errors.New("token is expired") ErrTokenExpired = errors.New("token is expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued") ErrTokenUsedBeforeIssued = errors.New("token used before issued")
@ -22,100 +21,29 @@ var (
ErrTokenNotValidYet = errors.New("token is not valid yet") ErrTokenNotValidYet = errors.New("token is not valid yet")
ErrTokenInvalidId = errors.New("token has invalid id") ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims") ErrTokenInvalidClaims = errors.New("token has invalid claims")
ErrInvalidType = errors.New("invalid type for claim") ErrInvalidType = errors.New("invalid type for claim")
) )
// The errors that might occur when parsing and validating a token // joinedError is an error type that works similar to what [errors.Join]
const ( // produces, with the exception that it has a nice error string; mainly its
ValidationErrorMalformed uint32 = 1 << iota // Token is malformed // error messages are concatenated using a comma, rather than a newline.
ValidationErrorUnverifiable // Token could not be verified because of signing problems type joinedError struct {
ValidationErrorSignatureInvalid // Signature validation failed errs []error
// Registered Claim validation errors
ValidationErrorAudience // AUD validation failed
ValidationErrorExpired // EXP validation failed
ValidationErrorIssuedAt // IAT validation failed
ValidationErrorIssuer // ISS validation failed
ValidationErrorSubject // SUB validation failed
ValidationErrorNotValidYet // NBF validation failed
ValidationErrorId // JTI validation failed
ValidationErrorClaimsInvalid // Generic claims validation error
)
// NewValidationError is a helper for constructing a ValidationError with a string error message
func NewValidationError(errorText string, errorFlags uint32) *ValidationError {
return &ValidationError{
text: errorText,
Errors: errorFlags,
}
} }
// ValidationError represents an error from Parse if token is not valid func (je joinedError) Error() string {
type ValidationError struct { msg := []string{}
// Inner stores the error returned by external dependencies, e.g.: KeyFunc for _, err := range je.errs {
Inner error msg = append(msg, err.Error())
// Errors is a bit-field. See ValidationError... constants
Errors uint32
// Text can be used for errors that do not have a valid error just have text
text string
} }
// Error is the implementation of the err interface. return strings.Join(msg, ", ")
func (e ValidationError) Error() string {
if e.Inner != nil {
return e.Inner.Error()
} else if e.text != "" {
return e.text
} else {
return "token is invalid"
}
} }
// Unwrap gives errors.Is and errors.As access to the inner error. // joinErrors joins together multiple errors. Useful for scenarios where
func (e *ValidationError) Unwrap() error { // multiple errors next to each other occur, e.g., in claims validation.
return e.Inner func joinErrors(errs ...error) error {
return &joinedError{
errs: errs,
} }
// No errors
func (e *ValidationError) valid() bool {
return e.Errors == 0
}
// Is checks if this ValidationError is of the supplied error. We are first
// checking for the exact error message by comparing the inner error message. If
// that fails, we compare using the error flags. This way we can use custom
// error messages (mainly for backwards compatibility) and still leverage
// errors.Is using the global error variables.
func (e *ValidationError) Is(err error) bool {
// Check, if our inner error is a direct match
if errors.Is(errors.Unwrap(e), err) {
return true
}
// Otherwise, we need to match using our error flags
switch err {
case ErrTokenMalformed:
return e.Errors&ValidationErrorMalformed != 0
case ErrTokenUnverifiable:
return e.Errors&ValidationErrorUnverifiable != 0
case ErrTokenSignatureInvalid:
return e.Errors&ValidationErrorSignatureInvalid != 0
case ErrTokenInvalidAudience:
return e.Errors&ValidationErrorAudience != 0
case ErrTokenExpired:
return e.Errors&ValidationErrorExpired != 0
case ErrTokenUsedBeforeIssued:
return e.Errors&ValidationErrorIssuedAt != 0
case ErrTokenInvalidIssuer:
return e.Errors&ValidationErrorIssuer != 0
case ErrTokenNotValidYet:
return e.Errors&ValidationErrorNotValidYet != 0
case ErrTokenInvalidId:
return e.Errors&ValidationErrorId != 0
case ErrTokenInvalidClaims:
return e.Errors&ValidationErrorClaimsInvalid != 0
}
return false
} }

47
errors_go1_20.go Normal file
View File

@ -0,0 +1,47 @@
//go:build go1.20
// +build go1.20
package jwt
import (
"fmt"
)
// Unwrap implements the multiple error unwrapping for this error type, which is
// possible in Go 1.20.
func (je joinedError) Unwrap() []error {
return je.errs
}
// newError creates a new error message with a detailed error message. The
// message will be prefixed with the contents of the supplied error type.
// Additionally, more errors, that provide more context can be supplied which
// will be appended to the message. This makes use of Go 1.20's possibility to
// include more than one %w formatting directive in [fmt.Errorf].
//
// For example,
//
// newError("no keyfunc was provided", ErrTokenUnverifiable)
//
// will produce the error string
//
// "token is unverifiable: no keyfunc was provided"
func newError(message string, err error, more ...error) error {
var format string
var args []any
if message != "" {
format = "%w: %s"
args = []any{err, message}
} else {
format = "%w"
args = []any{err}
}
for _, e := range more {
format += ": %w"
args = append(args, e)
}
err = fmt.Errorf(format, args...)
return err
}

78
errors_go_other.go Normal file
View File

@ -0,0 +1,78 @@
//go:build !go1.20
// +build !go1.20
package jwt
import (
"errors"
"fmt"
)
// Is implements checking for multiple errors using [errors.Is], since multiple
// error unwrapping is not possible in versions less than Go 1.20.
func (je joinedError) Is(err error) bool {
for _, e := range je.errs {
if errors.Is(e, err) {
return true
}
}
return false
}
// wrappedErrors is a workaround for wrapping multiple errors in environments
// where Go 1.20 is not available. It basically uses the already implemented
// functionatlity of joinedError to handle multiple errors with supplies a
// custom error message that is identical to the one we produce in Go 1.20 using
// multiple %w directives.
type wrappedErrors struct {
msg string
joinedError
}
// Error returns the stored error string
func (we wrappedErrors) Error() string {
return we.msg
}
// newError creates a new error message with a detailed error message. The
// message will be prefixed with the contents of the supplied error type.
// Additionally, more errors, that provide more context can be supplied which
// will be appended to the message. Since we cannot use of Go 1.20's possibility
// to include more than one %w formatting directive in [fmt.Errorf], we have to
// emulate that.
//
// For example,
//
// newError("no keyfunc was provided", ErrTokenUnverifiable)
//
// will produce the error string
//
// "token is unverifiable: no keyfunc was provided"
func newError(message string, err error, more ...error) error {
// We cannot wrap multiple errors here with %w, so we have to be a little
// bit creative. Basically, we are using %s instead of %w to produce the
// same error message and then throw the result into a custom error struct.
var format string
var args []any
if message != "" {
format = "%s: %s"
args = []any{err, message}
} else {
format = "%s"
args = []any{err}
}
errs := []error{err}
for _, e := range more {
format += ": %s"
args = append(args, e)
errs = append(errs, e)
}
err = &wrappedErrors{
msg: fmt.Sprintf(format, args...),
joinedError: joinedError{errs: errs},
}
return err
}

95
errors_test.go Normal file
View File

@ -0,0 +1,95 @@
package jwt
import (
"errors"
"io"
"testing"
)
func Test_joinErrors(t *testing.T) {
type args struct {
errs []error
}
tests := []struct {
name string
args args
wantErrors []error
wantMessage string
}{
{
name: "multiple errors",
args: args{
errs: []error{ErrTokenNotValidYet, ErrTokenExpired},
},
wantErrors: []error{ErrTokenNotValidYet, ErrTokenExpired},
wantMessage: "token is not valid yet, token is expired",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := joinErrors(tt.args.errs...)
for _, wantErr := range tt.wantErrors {
if !errors.Is(err, wantErr) {
t.Errorf("joinErrors() error = %v, does not contain %v", err, wantErr)
}
}
if err.Error() != tt.wantMessage {
t.Errorf("joinErrors() error.Error() = %v, wantMessage %v", err, tt.wantMessage)
}
})
}
}
func Test_newError(t *testing.T) {
type args struct {
message string
err error
more []error
}
tests := []struct {
name string
args args
wantErrors []error
wantMessage string
}{
{
name: "single error",
args: args{message: "something is wrong", err: ErrTokenMalformed},
wantMessage: "token is malformed: something is wrong",
wantErrors: []error{ErrTokenMalformed},
},
{
name: "two errors",
args: args{message: "something is wrong", err: ErrTokenMalformed, more: []error{io.ErrUnexpectedEOF}},
wantMessage: "token is malformed: something is wrong: unexpected EOF",
wantErrors: []error{ErrTokenMalformed},
},
{
name: "two errors, no detail",
args: args{message: "", err: ErrTokenInvalidClaims, more: []error{ErrTokenExpired}},
wantMessage: "token has invalid claims: token is expired",
wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired},
},
{
name: "two errors, no detail and join error",
args: args{message: "", err: ErrTokenInvalidClaims, more: []error{joinErrors(ErrTokenExpired, ErrTokenNotValidYet)}},
wantMessage: "token has invalid claims: token is expired, token is not valid yet",
wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired, ErrTokenNotValidYet},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := newError(tt.args.message, tt.args.err, tt.args.more...)
for _, wantErr := range tt.wantErrors {
if !errors.Is(err, wantErr) {
t.Errorf("newError() error = %v, does not contain %v", err, wantErr)
}
}
if err.Error() != tt.wantMessage {
t.Errorf("newError() error.Error() = %v, wantMessage %v", err, tt.wantMessage)
}
})
}
}

2
go.mod
View File

@ -1,3 +1,3 @@
module github.com/golang-jwt/jwt/v5 module github.com/golang-jwt/jwt/v5
go 1.16 go 1.18

View File

@ -2,6 +2,7 @@ package jwt
import ( import (
"encoding/json" "encoding/json"
"fmt"
) )
// MapClaims is a claims type that uses the map[string]interface{} for JSON // MapClaims is a claims type that uses the map[string]interface{} for JSON
@ -60,7 +61,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
return newNumericDateFromSeconds(v), nil return newNumericDateFromSeconds(v), nil
} }
return nil, ErrInvalidType return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
} }
// parseClaimsString tries to parse a key in the map claims type as a // parseClaimsString tries to parse a key in the map claims type as a
@ -76,7 +77,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
for _, a := range v { for _, a := range v {
vs, ok := a.(string) vs, ok := a.(string)
if !ok { if !ok {
return nil, ErrInvalidType return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
} }
cs = append(cs, vs) cs = append(cs, vs)
} }
@ -101,7 +102,7 @@ func (m MapClaims) parseString(key string) (string, error) {
iss, ok = raw.(string) iss, ok = raw.(string)
if !ok { if !ok {
return "", ErrInvalidType return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
} }
return iss, nil return iss, nil

View File

@ -13,7 +13,7 @@ type unsafeNoneMagicConstant string
func init() { func init() {
SigningMethodNone = &signingMethodNone{} SigningMethodNone = &signingMethodNone{}
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid) NoneSignatureTypeDisallowedError = newError("'none' signature type is not allowed", ErrTokenUnverifiable)
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod { RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
return SigningMethodNone return SigningMethodNone
@ -33,10 +33,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac
} }
// If signing method is none, signature must be an empty string // If signing method is none, signature must be an empty string
if signature != "" { if signature != "" {
return NewValidationError( return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
"'none' signing method with non-empty signature",
ValidationErrorSignatureInvalid,
)
} }
// Accept 'none' signing method. // Accept 'none' signing method.

View File

@ -65,7 +65,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
} }
if !signingMethodValid { if !signingMethodValid {
// signing method is not in the listed set // signing method is not in the listed set
return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid) return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
} }
} }
@ -73,17 +73,17 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
var key interface{} var key interface{}
if keyFunc == nil { if keyFunc == nil {
// keyFunc was not provided. short circuiting validation // keyFunc was not provided. short circuiting validation
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable) return token, newError("no keyfunc was provided", ErrTokenUnverifiable)
} }
if key, err = keyFunc(token); err != nil { if key, err = keyFunc(token); err != nil {
// keyFunc returned an error return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
if ve, ok := err.(*ValidationError); ok {
return token, ve
}
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
} }
vErr := &ValidationError{} // Perform signature validation
token.Signature = parts[2]
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
return token, newError("", ErrTokenSignatureInvalid, err)
}
// Validate Claims // Validate Claims
if !p.skipClaimsValidation { if !p.skipClaimsValidation {
@ -93,29 +93,14 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
} }
if err := p.validator.Validate(claims); err != nil { if err := p.validator.Validate(claims); err != nil {
// If the Claims Valid returned an error, check if it is a validation error, return token, newError("", ErrTokenInvalidClaims, err)
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
if e, ok := err.(*ValidationError); !ok {
vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid}
} else {
vErr = e
}
} }
} }
// Perform validation // No errors so far, token is valid.
token.Signature = parts[2]
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
vErr.Inner = err
vErr.Errors |= ValidationErrorSignatureInvalid
}
if vErr.valid() {
token.Valid = true token.Valid = true
return token, nil
}
return token, vErr return token, nil
} }
// ParseUnverified parses the token but doesn't validate the signature. // ParseUnverified parses the token but doesn't validate the signature.
@ -127,7 +112,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
parts = strings.Split(tokenString, ".") parts = strings.Split(tokenString, ".")
if len(parts) != 3 { if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed)
} }
token = &Token{Raw: tokenString} token = &Token{Raw: tokenString}
@ -136,12 +121,12 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
var headerBytes []byte var headerBytes []byte
if headerBytes, err = DecodeSegment(parts[0]); err != nil { if headerBytes, err = DecodeSegment(parts[0]); err != nil {
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed) return token, parts, newError("tokenstring should not contain 'bearer '", ErrTokenMalformed)
} }
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err)
} }
if err = json.Unmarshal(headerBytes, &token.Header); err != nil { if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err)
} }
// parse Claims // parse Claims
@ -149,7 +134,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
token.Claims = claims token.Claims = claims
if claimBytes, err = DecodeSegment(parts[1]); err != nil { if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
} }
dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
if p.useJSONNumber { if p.useJSONNumber {
@ -163,16 +148,16 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
} }
// Handle decode error // Handle decode error
if err != nil { if err != nil {
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
} }
// Lookup signature method // Lookup signature method
if method, ok := token.Header["alg"].(string); ok { if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil { if token.Method = GetSigningMethod(method); token.Method == nil {
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable) return token, parts, newError("signing method (alg) is unavailable", ErrTokenUnverifiable)
} }
} else { } else {
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable) return token, parts, newError("signing method (alg) is unspecified", ErrTokenUnverifiable)
} }
return token, parts, nil return token, parts, nil

View File

@ -51,7 +51,6 @@ var jwtTestData = []struct {
keyfunc jwt.Keyfunc keyfunc jwt.Keyfunc
claims jwt.Claims claims jwt.Claims
valid bool valid bool
errors uint32
err []error err []error
parser *jwt.Parser parser *jwt.Parser
signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose
@ -62,7 +61,16 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
nil, nil,
false, false,
jwt.ValidationErrorMalformed, []error{jwt.ErrTokenMalformed},
nil,
jwt.SigningMethodRS256,
},
{
"invalid JSON claim",
"eyJhbGciOiJSUzI1NiIsInppcCI6IkRFRiJ9.eNqqVkqtKFCyMjQ1s7Q0sbA0MtFRyk3NTUot8kxRslIKLbZQggn4JeamAoUcfRz99HxcXRWeze172tr4bFq7Ui0AAAD__w.jBXD4LT4aq4oXTgDoPkiV6n4QdSZPZI1Z4J8MWQC42aHK0oXwcovEU06dVbtB81TF-2byuu0-qi8J0GUttODT67k6gCl6DV_iuCOV7gczwTcvKslotUvXzoJ2wa0QuujnjxLEE50r0p6k0tsv_9OIFSUZzDksJFYNPlJH2eFG55DROx4TsOz98az37SujZi9GGbTc9SLgzFHPrHMrovRZ5qLC_w4JrdtsLzBBI11OQJgRYwV8fQf4O8IsMkHtetjkN7dKgUkJtRarNWOk76rpTPppLypiLU4_J0-wrElLMh1TzUVZW6Fz2cDHDDBACJgMmKQ2pOFEDK_vYZN74dLCF5GiTZV6DbXhNxO7lqT7JUN4a3p2z96G7WNRjblf2qZeuYdQvkIsiK-rCbSIE836XeY5gaBgkOzuEvzl_tMrpRmb5Oox1ibOfVT2KBh9Lvqsb1XbQjCio2CLE2ViCLqoe0AaRqlUyrk3n8BIG-r0IW4dcw96CEryEMIjsjVp9mtPXamJzf391kt8Rf3iRBqwv3zP7Plg1ResXbmsFUgOflAUPcYmfLug4W3W52ntcUlTHAKXrNfaJL9QQiYAaDukG-ZHDytsOWTuuXw7lVxjt-XYi1VbRAIjh1aIYSELEmEpE4Ny74htQtywYXMQNfJpB0nNn8IiWakgcYYMJ0TmKM",
defaultKeyFunc,
nil,
false,
[]error{jwt.ErrTokenMalformed}, []error{jwt.ErrTokenMalformed},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -73,7 +81,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
nil, nil,
false, false,
jwt.ValidationErrorMalformed,
[]error{jwt.ErrTokenMalformed}, []error{jwt.ErrTokenMalformed},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -84,7 +91,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
true, true,
0,
nil, nil,
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -95,7 +101,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
false, false,
jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenExpired}, []error{jwt.ErrTokenExpired},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -106,7 +111,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
false, false,
jwt.ValidationErrorNotValidYet,
[]error{jwt.ErrTokenNotValidYet}, []error{jwt.ErrTokenNotValidYet},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -117,8 +121,7 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
false, false,
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired},
[]error{jwt.ErrTokenNotValidYet},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -128,7 +131,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification}, []error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -139,7 +141,6 @@ var jwtTestData = []struct {
nilKeyFunc, nilKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorUnverifiable,
[]error{jwt.ErrTokenUnverifiable}, []error{jwt.ErrTokenUnverifiable},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -150,7 +151,6 @@ var jwtTestData = []struct {
emptyKeyFunc, emptyKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid}, []error{jwt.ErrTokenSignatureInvalid},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -161,7 +161,6 @@ var jwtTestData = []struct {
errorKeyFunc, errorKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorUnverifiable,
[]error{jwt.ErrTokenUnverifiable, errKeyFuncError}, []error{jwt.ErrTokenUnverifiable, errKeyFuncError},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -172,7 +171,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid}, []error{jwt.ErrTokenSignatureInvalid},
jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -183,7 +181,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
true, true,
0,
nil, nil,
jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -194,7 +191,6 @@ var jwtTestData = []struct {
ecdsaKeyFunc, ecdsaKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid}, []error{jwt.ErrTokenSignatureInvalid},
jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})),
jwt.SigningMethodES256, jwt.SigningMethodES256,
@ -205,7 +201,6 @@ var jwtTestData = []struct {
ecdsaKeyFunc, ecdsaKeyFunc,
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
true, true,
0,
nil, nil,
jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})), jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})),
jwt.SigningMethodES256, jwt.SigningMethodES256,
@ -216,7 +211,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": json.Number("123.4")}, jwt.MapClaims{"foo": json.Number("123.4")},
true, true,
0,
nil, nil,
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -227,7 +221,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
false, false,
jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenExpired}, []error{jwt.ErrTokenExpired},
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -238,7 +231,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
false, false,
jwt.ValidationErrorNotValidYet,
[]error{jwt.ErrTokenNotValidYet}, []error{jwt.ErrTokenNotValidYet},
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -249,8 +241,7 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
false, false,
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired},
[]error{jwt.ErrTokenNotValidYet},
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -260,7 +251,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
true, true,
0,
nil, nil,
jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -273,7 +263,6 @@ var jwtTestData = []struct {
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)),
}, },
true, true,
0,
nil, nil,
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -286,7 +275,6 @@ var jwtTestData = []struct {
Audience: jwt.ClaimStrings{"test"}, Audience: jwt.ClaimStrings{"test"},
}, },
true, true,
0,
nil, nil,
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -299,7 +287,6 @@ var jwtTestData = []struct {
Audience: jwt.ClaimStrings{"test", "test"}, Audience: jwt.ClaimStrings{"test", "test"},
}, },
true, true,
0,
nil, nil,
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -312,7 +299,6 @@ var jwtTestData = []struct {
Audience: nil, // because of the unmarshal error, this will be empty Audience: nil, // because of the unmarshal error, this will be empty
}, },
false, false,
jwt.ValidationErrorMalformed,
[]error{jwt.ErrTokenMalformed}, []error{jwt.ErrTokenMalformed},
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -325,7 +311,6 @@ var jwtTestData = []struct {
Audience: nil, // because of the unmarshal error, this will be empty Audience: nil, // because of the unmarshal error, this will be empty
}, },
false, false,
jwt.ValidationErrorMalformed,
[]error{jwt.ErrTokenMalformed}, []error{jwt.ErrTokenMalformed},
jwt.NewParser(jwt.WithJSONNumber()), jwt.NewParser(jwt.WithJSONNumber()),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -336,7 +321,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
&jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))},
false, false,
jwt.ValidationErrorNotValidYet,
[]error{jwt.ErrTokenNotValidYet}, []error{jwt.ErrTokenNotValidYet},
jwt.NewParser(jwt.WithLeeway(time.Minute)), jwt.NewParser(jwt.WithLeeway(time.Minute)),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -347,7 +331,6 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
&jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))},
true, true,
0,
nil, nil,
jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.NewParser(jwt.WithLeeway(2 * time.Minute)),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
@ -381,7 +364,6 @@ func TestParser_Parse(t *testing.T) {
// Parse the token // Parse the token
var token *jwt.Token var token *jwt.Token
var ve *jwt.ValidationError
var err error var err error
var parser = data.parser var parser = data.parser
if parser == nil { if parser == nil {
@ -417,23 +399,6 @@ func TestParser_Parse(t *testing.T) {
t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name)
} }
if data.errors != 0 {
if err == nil {
t.Errorf("[%v] Expecting error. Didn't get one.", data.name)
} else {
if errors.As(err, &ve) {
// compare the bitfield part of the error
if e := ve.Errors; e != data.errors {
t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors)
}
if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError {
t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError)
}
}
}
}
if data.err != nil { if data.err != nil {
if err == nil { if err == nil {
t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name) t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name)
@ -467,7 +432,7 @@ func TestParser_ParseUnverified(t *testing.T) {
// Iterate over test data set and run tests // Iterate over test data set and run tests
for _, data := range jwtTestData { for _, data := range jwtTestData {
// Skip test data, that intentionally contains malformed tokens, as they would lead to an error // Skip test data, that intentionally contains malformed tokens, as they would lead to an error
if data.errors&jwt.ValidationErrorMalformed != 0 { if len(data.err) == 1 && errors.Is(data.err[0], jwt.ErrTokenMalformed) {
continue continue
} }

View File

@ -2,9 +2,32 @@ package jwt
import ( import (
"crypto/subtle" "crypto/subtle"
"fmt"
"time" "time"
) )
// ClaimsValidator is an interface that can be implemented by custom claims who
// wish to execute any additional claims validation based on
// application-specific logic. The Validate function is then executed in
// addition to the regular claims validation and any error returned is appended
// to the final validation result.
//
// type MyCustomClaims struct {
// Foo string `json:"foo"`
// jwt.RegisteredClaims
// }
//
// func (m MyCustomClaims) Validate() error {
// if m.Foo != "bar" {
// return errors.New("must be foobar")
// }
// return nil
// }
type ClaimsValidator interface {
Claims
Validate() error
}
// validator is the core of the new Validation API. It is automatically used by // validator is the core of the new Validation API. It is automatically used by
// a [Parser] during parsing and can be modified with various parser options. // a [Parser] during parsing and can be modified with various parser options.
// //
@ -47,10 +70,13 @@ func newValidator(opts ...ParserOption) *validator {
} }
// Validate validates the given claims. It will also perform any custom // Validate validates the given claims. It will also perform any custom
// validation if claims implements the CustomValidator interface. // validation if claims implements the [ClaimsValidator] interface.
func (v *validator) Validate(claims Claims) error { func (v *validator) Validate(claims Claims) error {
var now time.Time var (
vErr := new(ValidationError) now time.Time
errs []error = make([]error, 0, 6)
err error
)
// Check, if we have a time func // Check, if we have a time func
if v.timeFunc != nil { if v.timeFunc != nil {
@ -60,140 +86,139 @@ func (v *validator) Validate(claims Claims) error {
} }
// We always need to check the expiration time, but usage of the claim // We always need to check the expiration time, but usage of the claim
// itself is OPTIONAL // itself is OPTIONAL.
if !v.VerifyExpiresAt(claims, now, false) { if err = v.verifyExpiresAt(claims, now, false); err != nil {
vErr.Inner = ErrTokenExpired errs = append(errs, err)
vErr.Errors |= ValidationErrorExpired
} }
// We always need to check not-before, but usage of the claim itself is // We always need to check not-before, but usage of the claim itself is
// OPTIONAL // OPTIONAL.
if !v.VerifyNotBefore(claims, now, false) { if err = v.verifyNotBefore(claims, now, false); err != nil {
vErr.Inner = ErrTokenNotValidYet errs = append(errs, err)
vErr.Errors |= ValidationErrorNotValidYet
} }
// Check issued-at if the option is enabled // Check issued-at if the option is enabled
if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) { if v.verifyIat {
vErr.Inner = ErrTokenUsedBeforeIssued if err = v.verifyIssuedAt(claims, now, false); err != nil {
vErr.Errors |= ValidationErrorIssuedAt errs = append(errs, err)
}
} }
// If we have an expected audience, we also require the audience claim // If we have an expected audience, we also require the audience claim
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) { if v.expectedAud != "" {
vErr.Inner = ErrTokenInvalidAudience if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
vErr.Errors |= ValidationErrorAudience errs = append(errs, err)
}
} }
// If we have an expected issuer, we also require the issuer claim // If we have an expected issuer, we also require the issuer claim
if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) { if v.expectedIss != "" {
vErr.Inner = ErrTokenInvalidIssuer if err = v.verifyIssuer(claims, v.expectedIss, true); err != nil {
vErr.Errors |= ValidationErrorIssuer errs = append(errs, err)
}
} }
// If we have an expected subject, we also require the subject claim // If we have an expected subject, we also require the subject claim
if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) { if v.expectedSub != "" {
vErr.Inner = ErrTokenInvalidSubject if err = v.verifySubject(claims, v.expectedSub, true); err != nil {
vErr.Errors |= ValidationErrorSubject errs = append(errs, err)
}
} }
// Finally, we want to give the claim itself some possibility to do some // Finally, we want to give the claim itself some possibility to do some
// additional custom validation based on a custom Validate function. // additional custom validation based on a custom Validate function.
cvt, ok := claims.(interface { cvt, ok := claims.(ClaimsValidator)
Validate() error
})
if ok { if ok {
if err := cvt.Validate(); err != nil { if err := cvt.Validate(); err != nil {
vErr.Inner = err errs = append(errs, err)
vErr.Errors |= ValidationErrorClaimsInvalid
} }
} }
if vErr.valid() { if len(errs) == 0 {
return nil return nil
} }
return vErr return joinErrors(errs...)
} }
// VerifyExpiresAt compares the exp claim in claims against cmp. This function // verifyExpiresAt compares the exp claim in claims against cmp. This function
// will return true if cmp < exp. Additional leeway is taken into account. // will succeed if cmp < exp. Additional leeway is taken into account.
// //
// If exp is not set, it will return true if the claim is not required, // If exp is not set, it will succeed if the claim is not required,
// otherwise false will be returned. // otherwise ErrTokenRequiredClaimMissing will be returned.
// //
// Additionally, if any error occurs while retrieving the claim, e.g., when its // Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned. // the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) VerifyExpiresAt(claims Claims, cmp time.Time, required bool) bool { func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error {
exp, err := claims.GetExpirationTime() exp, err := claims.GetExpirationTime()
if err != nil { if err != nil {
return false return err
} }
if exp != nil { if exp == nil {
return cmp.Before((exp.Time).Add(+v.leeway)) return errorIfRequired(required, "exp")
} else {
return !required
}
} }
// VerifyIssuedAt compares the iat claim in claims against cmp. This function return errorIfFalse(cmp.Before((exp.Time).Add(+v.leeway)), ErrTokenExpired)
// will return true if cmp >= iat. Additional leeway is taken into account. }
// verifyIssuedAt compares the iat claim in claims against cmp. This function
// will succeed if cmp >= iat. Additional leeway is taken into account.
// //
// If iat is not set, it will return true if the claim is not required, // If iat is not set, it will succeed if the claim is not required,
// otherwise false will be returned. // otherwise ErrTokenRequiredClaimMissing will be returned.
// //
// Additionally, if any error occurs while retrieving the claim, e.g., when its // Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned. // the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) VerifyIssuedAt(claims Claims, cmp time.Time, required bool) bool { func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error {
iat, err := claims.GetIssuedAt() iat, err := claims.GetIssuedAt()
if err != nil { if err != nil {
return false return err
} }
if iat != nil { if iat == nil {
return !cmp.Before(iat.Add(-v.leeway)) return errorIfRequired(required, "iat")
} else {
return !required
}
} }
// VerifyNotBefore compares the nbf claim in claims against cmp. This function return errorIfFalse(!cmp.Before(iat.Add(-v.leeway)), ErrTokenUsedBeforeIssued)
}
// verifyNotBefore compares the nbf claim in claims against cmp. This function
// will return true if cmp >= nbf. Additional leeway is taken into account. // will return true if cmp >= nbf. Additional leeway is taken into account.
// //
// If nbf is not set, it will return true if the claim is not required, // If nbf is not set, it will succeed if the claim is not required,
// otherwise false will be returned. // otherwise ErrTokenRequiredClaimMissing will be returned.
// //
// Additionally, if any error occurs while retrieving the claim, e.g., when its // Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned. // the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) VerifyNotBefore(claims Claims, cmp time.Time, required bool) bool { func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error {
nbf, err := claims.GetNotBefore() nbf, err := claims.GetNotBefore()
if err != nil { if err != nil {
return false return err
} }
if nbf != nil { if nbf == nil {
return !cmp.Before(nbf.Add(-v.leeway)) return errorIfRequired(required, "nbf")
} else {
return !required
}
} }
// VerifyAudience compares the aud claim against cmp. return errorIfFalse(!cmp.Before(nbf.Add(-v.leeway)), ErrTokenNotValidYet)
}
// verifyAudience compares the aud claim against cmp.
// //
// If aud is not set or an empty list, it will return true if the claim is not // If aud is not set or an empty list, it will succeed if the claim is not required,
// required, otherwise false will be returned. // otherwise ErrTokenRequiredClaimMissing will be returned.
// //
// Additionally, if any error occurs while retrieving the claim, e.g., when its // Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned. // the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) bool { func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error {
aud, err := claims.GetAudience() aud, err := claims.GetAudience()
if err != nil { if err != nil {
return false return err
} }
if len(aud) == 0 { if len(aud) == 0 {
return !required return errorIfRequired(required, "aud")
} }
// use a var here to keep constant time compare when looping over a number of claims // use a var here to keep constant time compare when looping over a number of claims
@ -209,48 +234,68 @@ func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) boo
// case where "" is sent in one or many aud claims // case where "" is sent in one or many aud claims
if stringClaims == "" { if stringClaims == "" {
return !required return errorIfRequired(required, "aud")
} }
return result return errorIfFalse(result, ErrTokenInvalidAudience)
} }
// VerifyIssuer compares the iss claim in claims against cmp. // verifyIssuer compares the iss claim in claims against cmp.
// //
// If iss is not set, it will return true if the claim is not required, // If iss is not set, it will succeed if the claim is not required,
// otherwise false will be returned. // otherwise ErrTokenRequiredClaimMissing will be returned.
// //
// Additionally, if any error occurs while retrieving the claim, e.g., when its // Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned. // the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) VerifyIssuer(claims Claims, cmp string, required bool) bool { func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error {
iss, err := claims.GetIssuer() iss, err := claims.GetIssuer()
if err != nil { if err != nil {
return false return err
} }
if iss == "" { if iss == "" {
return !required return errorIfRequired(required, "iss")
} }
return iss == cmp return errorIfFalse(iss == cmp, ErrTokenInvalidIssuer)
} }
// VerifySubject compares the sub claim against cmp. // verifySubject compares the sub claim against cmp.
// //
// If sub is not set, it will return true if the claim is not required, // If sub is not set, it will succeed if the claim is not required,
// otherwise false will be returned. // otherwise ErrTokenRequiredClaimMissing will be returned.
// //
// Additionally, if any error occurs while retrieving the claim, e.g., when its // Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned. // the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) VerifySubject(claims Claims, cmp string, required bool) bool { func (v *validator) verifySubject(claims Claims, cmp string, required bool) error {
sub, err := claims.GetSubject() sub, err := claims.GetSubject()
if err != nil { if err != nil {
return false return err
} }
if sub == "" { if sub == "" {
return !required return errorIfRequired(required, "sub")
} }
return sub == cmp return errorIfFalse(sub == cmp, ErrTokenInvalidSubject)
}
// errorIfFalse returns the error specified in err, if the value is true.
// Otherwise, nil is returned.
func errorIfFalse(value bool, err error) error {
if value {
return nil
} else {
return err
}
}
// errorIfRequired returns an ErrTokenRequiredClaimMissing error if required is
// true. Otherwise, nil is returned.
func errorIfRequired(required bool, claim string) error {
if required {
return newError(fmt.Sprintf("%s claim is required", claim), ErrTokenRequiredClaimMissing)
} else {
return nil
}
} }

261
validator_test.go Normal file
View File

@ -0,0 +1,261 @@
package jwt
import (
"errors"
"testing"
"time"
)
var ErrFooBar = errors.New("must be foobar")
type MyCustomClaims struct {
Foo string `json:"foo"`
RegisteredClaims
}
func (m MyCustomClaims) Validate() error {
if m.Foo != "bar" {
return ErrFooBar
}
return nil
}
func Test_validator_Validate(t *testing.T) {
type fields struct {
leeway time.Duration
timeFunc func() time.Time
verifyIat bool
expectedAud string
expectedIss string
expectedSub string
}
type args struct {
claims Claims
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "expected iss mismatch",
fields: fields{expectedIss: "me"},
args: args{RegisteredClaims{Issuer: "not_me"}},
wantErr: ErrTokenInvalidIssuer,
},
{
name: "expected iss is missing",
fields: fields{expectedIss: "me"},
args: args{RegisteredClaims{}},
wantErr: ErrTokenRequiredClaimMissing,
},
{
name: "expected sub mismatch",
fields: fields{expectedSub: "me"},
args: args{RegisteredClaims{Subject: "not-me"}},
wantErr: ErrTokenInvalidSubject,
},
{
name: "expected sub is missing",
fields: fields{expectedSub: "me"},
args: args{RegisteredClaims{}},
wantErr: ErrTokenRequiredClaimMissing,
},
{
name: "custom validator",
fields: fields{},
args: args{MyCustomClaims{Foo: "not-bar"}},
wantErr: ErrFooBar,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
verifyIat: tt.fields.verifyIat,
expectedAud: tt.fields.expectedAud,
expectedIss: tt.fields.expectedIss,
expectedSub: tt.fields.expectedSub,
}
if err := v.Validate(tt.args.claims); (err != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_validator_verifyExpiresAt(t *testing.T) {
type fields struct {
leeway time.Duration
timeFunc func() time.Time
}
type args struct {
claims Claims
cmp time.Time
required bool
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "good claim",
fields: fields{timeFunc: time.Now},
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(time.Now().Add(10 * time.Minute))}},
wantErr: nil,
},
{
name: "claims with invalid type",
fields: fields{},
args: args{claims: MapClaims{"exp": "string"}},
wantErr: ErrInvalidType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
}
err := v.verifyExpiresAt(tt.args.claims, tt.args.cmp, tt.args.required)
if (err != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifyExpiresAt() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_validator_verifyIssuer(t *testing.T) {
type fields struct {
expectedIss string
}
type args struct {
claims Claims
cmp string
required bool
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "good claim",
fields: fields{expectedIss: "me"},
args: args{claims: MapClaims{"iss": "me"}, cmp: "me"},
wantErr: nil,
},
{
name: "claims with invalid type",
fields: fields{expectedIss: "me"},
args: args{claims: MapClaims{"iss": 1}, cmp: "me"},
wantErr: ErrInvalidType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
expectedIss: tt.fields.expectedIss,
}
err := v.verifyIssuer(tt.args.claims, tt.args.cmp, tt.args.required)
if (err != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifyIssuer() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_validator_verifySubject(t *testing.T) {
type fields struct {
expectedSub string
}
type args struct {
claims Claims
cmp string
required bool
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "good claim",
fields: fields{expectedSub: "me"},
args: args{claims: MapClaims{"sub": "me"}, cmp: "me"},
wantErr: nil,
},
{
name: "claims with invalid type",
fields: fields{expectedSub: "me"},
args: args{claims: MapClaims{"sub": 1}, cmp: "me"},
wantErr: ErrInvalidType,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
expectedSub: tt.fields.expectedSub,
}
err := v.verifySubject(tt.args.claims, tt.args.cmp, tt.args.required)
if (err != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifySubject() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_validator_verifyIssuedAt(t *testing.T) {
type fields struct {
leeway time.Duration
timeFunc func() time.Time
verifyIat bool
}
type args struct {
claims Claims
cmp time.Time
required bool
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "good claim without iat",
fields: fields{verifyIat: true},
args: args{claims: MapClaims{}, required: false},
wantErr: nil,
},
{
name: "good claim with iat",
fields: fields{verifyIat: true},
args: args{
claims: RegisteredClaims{IssuedAt: NewNumericDate(time.Now())},
cmp: time.Now().Add(10 * time.Minute),
required: false,
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
verifyIat: tt.fields.verifyIat,
}
if err := v.verifyIssuedAt(tt.args.claims, tt.args.cmp, tt.args.required); (err != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifyIssuedAt() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}