mirror of https://github.com/golang-jwt/jwt.git
More consistent way of handling validation errors (#274)
This commit is contained in:
parent
4e6e1ba2bb
commit
28dc52370e
|
@ -25,7 +25,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go: [1.17, 1.18, 1.19]
|
||||
go: ["1.18", "1.19", "1.20"]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
|
106
errors.go
106
errors.go
|
@ -2,18 +2,17 @@ package jwt
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Error constants
|
||||
var (
|
||||
ErrInvalidKey = errors.New("key is invalid")
|
||||
ErrInvalidKeyType = errors.New("key is of invalid type")
|
||||
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
|
||||
|
||||
ErrTokenMalformed = errors.New("token is malformed")
|
||||
ErrTokenUnverifiable = errors.New("token is unverifiable")
|
||||
ErrTokenSignatureInvalid = errors.New("token signature is invalid")
|
||||
|
||||
ErrTokenRequiredClaimMissing = errors.New("token is missing required claim")
|
||||
ErrTokenInvalidAudience = errors.New("token has invalid audience")
|
||||
ErrTokenExpired = errors.New("token is expired")
|
||||
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
|
||||
|
@ -22,100 +21,29 @@ var (
|
|||
ErrTokenNotValidYet = errors.New("token is not valid yet")
|
||||
ErrTokenInvalidId = errors.New("token has invalid id")
|
||||
ErrTokenInvalidClaims = errors.New("token has invalid claims")
|
||||
|
||||
ErrInvalidType = errors.New("invalid type for claim")
|
||||
)
|
||||
|
||||
// The errors that might occur when parsing and validating a token
|
||||
const (
|
||||
ValidationErrorMalformed uint32 = 1 << iota // Token is malformed
|
||||
ValidationErrorUnverifiable // Token could not be verified because of signing problems
|
||||
ValidationErrorSignatureInvalid // Signature validation failed
|
||||
|
||||
// Registered Claim validation errors
|
||||
ValidationErrorAudience // AUD validation failed
|
||||
ValidationErrorExpired // EXP validation failed
|
||||
ValidationErrorIssuedAt // IAT validation failed
|
||||
ValidationErrorIssuer // ISS validation failed
|
||||
ValidationErrorSubject // SUB validation failed
|
||||
ValidationErrorNotValidYet // NBF validation failed
|
||||
ValidationErrorId // JTI validation failed
|
||||
ValidationErrorClaimsInvalid // Generic claims validation error
|
||||
)
|
||||
|
||||
// NewValidationError is a helper for constructing a ValidationError with a string error message
|
||||
func NewValidationError(errorText string, errorFlags uint32) *ValidationError {
|
||||
return &ValidationError{
|
||||
text: errorText,
|
||||
Errors: errorFlags,
|
||||
}
|
||||
// joinedError is an error type that works similar to what [errors.Join]
|
||||
// produces, with the exception that it has a nice error string; mainly its
|
||||
// error messages are concatenated using a comma, rather than a newline.
|
||||
type joinedError struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
// ValidationError represents an error from Parse if token is not valid
|
||||
type ValidationError struct {
|
||||
// Inner stores the error returned by external dependencies, e.g.: KeyFunc
|
||||
Inner error
|
||||
// Errors is a bit-field. See ValidationError... constants
|
||||
Errors uint32
|
||||
// Text can be used for errors that do not have a valid error just have text
|
||||
text string
|
||||
func (je joinedError) Error() string {
|
||||
msg := []string{}
|
||||
for _, err := range je.errs {
|
||||
msg = append(msg, err.Error())
|
||||
}
|
||||
|
||||
// Error is the implementation of the err interface.
|
||||
func (e ValidationError) Error() string {
|
||||
if e.Inner != nil {
|
||||
return e.Inner.Error()
|
||||
} else if e.text != "" {
|
||||
return e.text
|
||||
} else {
|
||||
return "token is invalid"
|
||||
}
|
||||
return strings.Join(msg, ", ")
|
||||
}
|
||||
|
||||
// Unwrap gives errors.Is and errors.As access to the inner error.
|
||||
func (e *ValidationError) Unwrap() error {
|
||||
return e.Inner
|
||||
// joinErrors joins together multiple errors. Useful for scenarios where
|
||||
// multiple errors next to each other occur, e.g., in claims validation.
|
||||
func joinErrors(errs ...error) error {
|
||||
return &joinedError{
|
||||
errs: errs,
|
||||
}
|
||||
|
||||
// No errors
|
||||
func (e *ValidationError) valid() bool {
|
||||
return e.Errors == 0
|
||||
}
|
||||
|
||||
// Is checks if this ValidationError is of the supplied error. We are first
|
||||
// checking for the exact error message by comparing the inner error message. If
|
||||
// that fails, we compare using the error flags. This way we can use custom
|
||||
// error messages (mainly for backwards compatibility) and still leverage
|
||||
// errors.Is using the global error variables.
|
||||
func (e *ValidationError) Is(err error) bool {
|
||||
// Check, if our inner error is a direct match
|
||||
if errors.Is(errors.Unwrap(e), err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Otherwise, we need to match using our error flags
|
||||
switch err {
|
||||
case ErrTokenMalformed:
|
||||
return e.Errors&ValidationErrorMalformed != 0
|
||||
case ErrTokenUnverifiable:
|
||||
return e.Errors&ValidationErrorUnverifiable != 0
|
||||
case ErrTokenSignatureInvalid:
|
||||
return e.Errors&ValidationErrorSignatureInvalid != 0
|
||||
case ErrTokenInvalidAudience:
|
||||
return e.Errors&ValidationErrorAudience != 0
|
||||
case ErrTokenExpired:
|
||||
return e.Errors&ValidationErrorExpired != 0
|
||||
case ErrTokenUsedBeforeIssued:
|
||||
return e.Errors&ValidationErrorIssuedAt != 0
|
||||
case ErrTokenInvalidIssuer:
|
||||
return e.Errors&ValidationErrorIssuer != 0
|
||||
case ErrTokenNotValidYet:
|
||||
return e.Errors&ValidationErrorNotValidYet != 0
|
||||
case ErrTokenInvalidId:
|
||||
return e.Errors&ValidationErrorId != 0
|
||||
case ErrTokenInvalidClaims:
|
||||
return e.Errors&ValidationErrorClaimsInvalid != 0
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
//go:build go1.20
|
||||
// +build go1.20
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Unwrap implements the multiple error unwrapping for this error type, which is
|
||||
// possible in Go 1.20.
|
||||
func (je joinedError) Unwrap() []error {
|
||||
return je.errs
|
||||
}
|
||||
|
||||
// newError creates a new error message with a detailed error message. The
|
||||
// message will be prefixed with the contents of the supplied error type.
|
||||
// Additionally, more errors, that provide more context can be supplied which
|
||||
// will be appended to the message. This makes use of Go 1.20's possibility to
|
||||
// include more than one %w formatting directive in [fmt.Errorf].
|
||||
//
|
||||
// For example,
|
||||
//
|
||||
// newError("no keyfunc was provided", ErrTokenUnverifiable)
|
||||
//
|
||||
// will produce the error string
|
||||
//
|
||||
// "token is unverifiable: no keyfunc was provided"
|
||||
func newError(message string, err error, more ...error) error {
|
||||
var format string
|
||||
var args []any
|
||||
if message != "" {
|
||||
format = "%w: %s"
|
||||
args = []any{err, message}
|
||||
} else {
|
||||
format = "%w"
|
||||
args = []any{err}
|
||||
}
|
||||
|
||||
for _, e := range more {
|
||||
format += ": %w"
|
||||
args = append(args, e)
|
||||
}
|
||||
|
||||
err = fmt.Errorf(format, args...)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
//go:build !go1.20
|
||||
// +build !go1.20
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Is implements checking for multiple errors using [errors.Is], since multiple
|
||||
// error unwrapping is not possible in versions less than Go 1.20.
|
||||
func (je joinedError) Is(err error) bool {
|
||||
for _, e := range je.errs {
|
||||
if errors.Is(e, err) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// wrappedErrors is a workaround for wrapping multiple errors in environments
|
||||
// where Go 1.20 is not available. It basically uses the already implemented
|
||||
// functionatlity of joinedError to handle multiple errors with supplies a
|
||||
// custom error message that is identical to the one we produce in Go 1.20 using
|
||||
// multiple %w directives.
|
||||
type wrappedErrors struct {
|
||||
msg string
|
||||
joinedError
|
||||
}
|
||||
|
||||
// Error returns the stored error string
|
||||
func (we wrappedErrors) Error() string {
|
||||
return we.msg
|
||||
}
|
||||
|
||||
// newError creates a new error message with a detailed error message. The
|
||||
// message will be prefixed with the contents of the supplied error type.
|
||||
// Additionally, more errors, that provide more context can be supplied which
|
||||
// will be appended to the message. Since we cannot use of Go 1.20's possibility
|
||||
// to include more than one %w formatting directive in [fmt.Errorf], we have to
|
||||
// emulate that.
|
||||
//
|
||||
// For example,
|
||||
//
|
||||
// newError("no keyfunc was provided", ErrTokenUnverifiable)
|
||||
//
|
||||
// will produce the error string
|
||||
//
|
||||
// "token is unverifiable: no keyfunc was provided"
|
||||
func newError(message string, err error, more ...error) error {
|
||||
// We cannot wrap multiple errors here with %w, so we have to be a little
|
||||
// bit creative. Basically, we are using %s instead of %w to produce the
|
||||
// same error message and then throw the result into a custom error struct.
|
||||
var format string
|
||||
var args []any
|
||||
if message != "" {
|
||||
format = "%s: %s"
|
||||
args = []any{err, message}
|
||||
} else {
|
||||
format = "%s"
|
||||
args = []any{err}
|
||||
}
|
||||
errs := []error{err}
|
||||
|
||||
for _, e := range more {
|
||||
format += ": %s"
|
||||
args = append(args, e)
|
||||
errs = append(errs, e)
|
||||
}
|
||||
|
||||
err = &wrappedErrors{
|
||||
msg: fmt.Sprintf(format, args...),
|
||||
joinedError: joinedError{errs: errs},
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_joinErrors(t *testing.T) {
|
||||
type args struct {
|
||||
errs []error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErrors []error
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "multiple errors",
|
||||
args: args{
|
||||
errs: []error{ErrTokenNotValidYet, ErrTokenExpired},
|
||||
},
|
||||
wantErrors: []error{ErrTokenNotValidYet, ErrTokenExpired},
|
||||
wantMessage: "token is not valid yet, token is expired",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := joinErrors(tt.args.errs...)
|
||||
for _, wantErr := range tt.wantErrors {
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Errorf("joinErrors() error = %v, does not contain %v", err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
if err.Error() != tt.wantMessage {
|
||||
t.Errorf("joinErrors() error.Error() = %v, wantMessage %v", err, tt.wantMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newError(t *testing.T) {
|
||||
type args struct {
|
||||
message string
|
||||
err error
|
||||
more []error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErrors []error
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "single error",
|
||||
args: args{message: "something is wrong", err: ErrTokenMalformed},
|
||||
wantMessage: "token is malformed: something is wrong",
|
||||
wantErrors: []error{ErrTokenMalformed},
|
||||
},
|
||||
{
|
||||
name: "two errors",
|
||||
args: args{message: "something is wrong", err: ErrTokenMalformed, more: []error{io.ErrUnexpectedEOF}},
|
||||
wantMessage: "token is malformed: something is wrong: unexpected EOF",
|
||||
wantErrors: []error{ErrTokenMalformed},
|
||||
},
|
||||
{
|
||||
name: "two errors, no detail",
|
||||
args: args{message: "", err: ErrTokenInvalidClaims, more: []error{ErrTokenExpired}},
|
||||
wantMessage: "token has invalid claims: token is expired",
|
||||
wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired},
|
||||
},
|
||||
{
|
||||
name: "two errors, no detail and join error",
|
||||
args: args{message: "", err: ErrTokenInvalidClaims, more: []error{joinErrors(ErrTokenExpired, ErrTokenNotValidYet)}},
|
||||
wantMessage: "token has invalid claims: token is expired, token is not valid yet",
|
||||
wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired, ErrTokenNotValidYet},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := newError(tt.args.message, tt.args.err, tt.args.more...)
|
||||
for _, wantErr := range tt.wantErrors {
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Errorf("newError() error = %v, does not contain %v", err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
if err.Error() != tt.wantMessage {
|
||||
t.Errorf("newError() error.Error() = %v, wantMessage %v", err, tt.wantMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ package jwt
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// MapClaims is a claims type that uses the map[string]interface{} for JSON
|
||||
|
@ -60,7 +61,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
|
|||
return newNumericDateFromSeconds(v), nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidType
|
||||
return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
|
||||
}
|
||||
|
||||
// parseClaimsString tries to parse a key in the map claims type as a
|
||||
|
@ -76,7 +77,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
|
|||
for _, a := range v {
|
||||
vs, ok := a.(string)
|
||||
if !ok {
|
||||
return nil, ErrInvalidType
|
||||
return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
|
||||
}
|
||||
cs = append(cs, vs)
|
||||
}
|
||||
|
@ -101,7 +102,7 @@ func (m MapClaims) parseString(key string) (string, error) {
|
|||
|
||||
iss, ok = raw.(string)
|
||||
if !ok {
|
||||
return "", ErrInvalidType
|
||||
return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
|
||||
}
|
||||
|
||||
return iss, nil
|
||||
|
|
7
none.go
7
none.go
|
@ -13,7 +13,7 @@ type unsafeNoneMagicConstant string
|
|||
|
||||
func init() {
|
||||
SigningMethodNone = &signingMethodNone{}
|
||||
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid)
|
||||
NoneSignatureTypeDisallowedError = newError("'none' signature type is not allowed", ErrTokenUnverifiable)
|
||||
|
||||
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
|
||||
return SigningMethodNone
|
||||
|
@ -33,10 +33,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac
|
|||
}
|
||||
// If signing method is none, signature must be an empty string
|
||||
if signature != "" {
|
||||
return NewValidationError(
|
||||
"'none' signing method with non-empty signature",
|
||||
ValidationErrorSignatureInvalid,
|
||||
)
|
||||
return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
|
||||
}
|
||||
|
||||
// Accept 'none' signing method.
|
||||
|
|
53
parser.go
53
parser.go
|
@ -65,7 +65,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
|||
}
|
||||
if !signingMethodValid {
|
||||
// signing method is not in the listed set
|
||||
return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid)
|
||||
return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -73,17 +73,17 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
|||
var key interface{}
|
||||
if keyFunc == nil {
|
||||
// keyFunc was not provided. short circuiting validation
|
||||
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable)
|
||||
return token, newError("no keyfunc was provided", ErrTokenUnverifiable)
|
||||
}
|
||||
if key, err = keyFunc(token); err != nil {
|
||||
// keyFunc returned an error
|
||||
if ve, ok := err.(*ValidationError); ok {
|
||||
return token, ve
|
||||
}
|
||||
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
|
||||
return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
|
||||
}
|
||||
|
||||
vErr := &ValidationError{}
|
||||
// Perform signature validation
|
||||
token.Signature = parts[2]
|
||||
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
|
||||
return token, newError("", ErrTokenSignatureInvalid, err)
|
||||
}
|
||||
|
||||
// Validate Claims
|
||||
if !p.skipClaimsValidation {
|
||||
|
@ -93,29 +93,14 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
|||
}
|
||||
|
||||
if err := p.validator.Validate(claims); err != nil {
|
||||
// If the Claims Valid returned an error, check if it is a validation error,
|
||||
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
|
||||
if e, ok := err.(*ValidationError); !ok {
|
||||
vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid}
|
||||
} else {
|
||||
vErr = e
|
||||
}
|
||||
return token, newError("", ErrTokenInvalidClaims, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Perform validation
|
||||
token.Signature = parts[2]
|
||||
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
|
||||
vErr.Inner = err
|
||||
vErr.Errors |= ValidationErrorSignatureInvalid
|
||||
}
|
||||
|
||||
if vErr.valid() {
|
||||
// No errors so far, token is valid.
|
||||
token.Valid = true
|
||||
return token, nil
|
||||
}
|
||||
|
||||
return token, vErr
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ParseUnverified parses the token but doesn't validate the signature.
|
||||
|
@ -127,7 +112,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
|||
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
|
||||
parts = strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
|
||||
return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed)
|
||||
}
|
||||
|
||||
token = &Token{Raw: tokenString}
|
||||
|
@ -136,12 +121,12 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
|
|||
var headerBytes []byte
|
||||
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
|
||||
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
|
||||
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
|
||||
return token, parts, newError("tokenstring should not contain 'bearer '", ErrTokenMalformed)
|
||||
}
|
||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||
return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err)
|
||||
}
|
||||
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
|
||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||
return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err)
|
||||
}
|
||||
|
||||
// parse Claims
|
||||
|
@ -149,7 +134,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
|
|||
token.Claims = claims
|
||||
|
||||
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
|
||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
|
||||
}
|
||||
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
|
||||
if p.useJSONNumber {
|
||||
|
@ -163,16 +148,16 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
|
|||
}
|
||||
// Handle decode error
|
||||
if err != nil {
|
||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
||||
return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
|
||||
}
|
||||
|
||||
// Lookup signature method
|
||||
if method, ok := token.Header["alg"].(string); ok {
|
||||
if token.Method = GetSigningMethod(method); token.Method == nil {
|
||||
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
|
||||
return token, parts, newError("signing method (alg) is unavailable", ErrTokenUnverifiable)
|
||||
}
|
||||
} else {
|
||||
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
|
||||
return token, parts, newError("signing method (alg) is unspecified", ErrTokenUnverifiable)
|
||||
}
|
||||
|
||||
return token, parts, nil
|
||||
|
|
|
@ -51,7 +51,6 @@ var jwtTestData = []struct {
|
|||
keyfunc jwt.Keyfunc
|
||||
claims jwt.Claims
|
||||
valid bool
|
||||
errors uint32
|
||||
err []error
|
||||
parser *jwt.Parser
|
||||
signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose
|
||||
|
@ -62,7 +61,16 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
nil,
|
||||
false,
|
||||
jwt.ValidationErrorMalformed,
|
||||
[]error{jwt.ErrTokenMalformed},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
},
|
||||
{
|
||||
"invalid JSON claim",
|
||||
"eyJhbGciOiJSUzI1NiIsInppcCI6IkRFRiJ9.eNqqVkqtKFCyMjQ1s7Q0sbA0MtFRyk3NTUot8kxRslIKLbZQggn4JeamAoUcfRz99HxcXRWeze172tr4bFq7Ui0AAAD__w.jBXD4LT4aq4oXTgDoPkiV6n4QdSZPZI1Z4J8MWQC42aHK0oXwcovEU06dVbtB81TF-2byuu0-qi8J0GUttODT67k6gCl6DV_iuCOV7gczwTcvKslotUvXzoJ2wa0QuujnjxLEE50r0p6k0tsv_9OIFSUZzDksJFYNPlJH2eFG55DROx4TsOz98az37SujZi9GGbTc9SLgzFHPrHMrovRZ5qLC_w4JrdtsLzBBI11OQJgRYwV8fQf4O8IsMkHtetjkN7dKgUkJtRarNWOk76rpTPppLypiLU4_J0-wrElLMh1TzUVZW6Fz2cDHDDBACJgMmKQ2pOFEDK_vYZN74dLCF5GiTZV6DbXhNxO7lqT7JUN4a3p2z96G7WNRjblf2qZeuYdQvkIsiK-rCbSIE836XeY5gaBgkOzuEvzl_tMrpRmb5Oox1ibOfVT2KBh9Lvqsb1XbQjCio2CLE2ViCLqoe0AaRqlUyrk3n8BIG-r0IW4dcw96CEryEMIjsjVp9mtPXamJzf391kt8Rf3iRBqwv3zP7Plg1ResXbmsFUgOflAUPcYmfLug4W3W52ntcUlTHAKXrNfaJL9QQiYAaDukG-ZHDytsOWTuuXw7lVxjt-XYi1VbRAIjh1aIYSELEmEpE4Ny74htQtywYXMQNfJpB0nNn8IiWakgcYYMJ0TmKM",
|
||||
defaultKeyFunc,
|
||||
nil,
|
||||
false,
|
||||
[]error{jwt.ErrTokenMalformed},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -73,7 +81,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
nil,
|
||||
false,
|
||||
jwt.ValidationErrorMalformed,
|
||||
[]error{jwt.ErrTokenMalformed},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -84,7 +91,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -95,7 +101,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
|
||||
false,
|
||||
jwt.ValidationErrorExpired,
|
||||
[]error{jwt.ErrTokenExpired},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -106,7 +111,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
|
||||
false,
|
||||
jwt.ValidationErrorNotValidYet,
|
||||
[]error{jwt.ErrTokenNotValidYet},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -117,8 +121,7 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
|
||||
false,
|
||||
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
|
||||
[]error{jwt.ErrTokenNotValidYet},
|
||||
[]error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
},
|
||||
|
@ -128,7 +131,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
false,
|
||||
jwt.ValidationErrorSignatureInvalid,
|
||||
[]error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -139,7 +141,6 @@ var jwtTestData = []struct {
|
|||
nilKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
false,
|
||||
jwt.ValidationErrorUnverifiable,
|
||||
[]error{jwt.ErrTokenUnverifiable},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -150,7 +151,6 @@ var jwtTestData = []struct {
|
|||
emptyKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
false,
|
||||
jwt.ValidationErrorSignatureInvalid,
|
||||
[]error{jwt.ErrTokenSignatureInvalid},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -161,7 +161,6 @@ var jwtTestData = []struct {
|
|||
errorKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
false,
|
||||
jwt.ValidationErrorUnverifiable,
|
||||
[]error{jwt.ErrTokenUnverifiable, errKeyFuncError},
|
||||
nil,
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -172,7 +171,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
false,
|
||||
jwt.ValidationErrorSignatureInvalid,
|
||||
[]error{jwt.ErrTokenSignatureInvalid},
|
||||
jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -183,7 +181,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -194,7 +191,6 @@ var jwtTestData = []struct {
|
|||
ecdsaKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
false,
|
||||
jwt.ValidationErrorSignatureInvalid,
|
||||
[]error{jwt.ErrTokenSignatureInvalid},
|
||||
jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})),
|
||||
jwt.SigningMethodES256,
|
||||
|
@ -205,7 +201,6 @@ var jwtTestData = []struct {
|
|||
ecdsaKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar"},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})),
|
||||
jwt.SigningMethodES256,
|
||||
|
@ -216,7 +211,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": json.Number("123.4")},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -227,7 +221,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
|
||||
false,
|
||||
jwt.ValidationErrorExpired,
|
||||
[]error{jwt.ErrTokenExpired},
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -238,7 +231,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
|
||||
false,
|
||||
jwt.ValidationErrorNotValidYet,
|
||||
[]error{jwt.ErrTokenNotValidYet},
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -249,8 +241,7 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
|
||||
false,
|
||||
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
|
||||
[]error{jwt.ErrTokenNotValidYet},
|
||||
[]error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired},
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
},
|
||||
|
@ -260,7 +251,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -273,7 +263,6 @@ var jwtTestData = []struct {
|
|||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)),
|
||||
},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -286,7 +275,6 @@ var jwtTestData = []struct {
|
|||
Audience: jwt.ClaimStrings{"test"},
|
||||
},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -299,7 +287,6 @@ var jwtTestData = []struct {
|
|||
Audience: jwt.ClaimStrings{"test", "test"},
|
||||
},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -312,7 +299,6 @@ var jwtTestData = []struct {
|
|||
Audience: nil, // because of the unmarshal error, this will be empty
|
||||
},
|
||||
false,
|
||||
jwt.ValidationErrorMalformed,
|
||||
[]error{jwt.ErrTokenMalformed},
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -325,7 +311,6 @@ var jwtTestData = []struct {
|
|||
Audience: nil, // because of the unmarshal error, this will be empty
|
||||
},
|
||||
false,
|
||||
jwt.ValidationErrorMalformed,
|
||||
[]error{jwt.ErrTokenMalformed},
|
||||
jwt.NewParser(jwt.WithJSONNumber()),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -336,7 +321,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
&jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))},
|
||||
false,
|
||||
jwt.ValidationErrorNotValidYet,
|
||||
[]error{jwt.ErrTokenNotValidYet},
|
||||
jwt.NewParser(jwt.WithLeeway(time.Minute)),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -347,7 +331,6 @@ var jwtTestData = []struct {
|
|||
defaultKeyFunc,
|
||||
&jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))},
|
||||
true,
|
||||
0,
|
||||
nil,
|
||||
jwt.NewParser(jwt.WithLeeway(2 * time.Minute)),
|
||||
jwt.SigningMethodRS256,
|
||||
|
@ -381,7 +364,6 @@ func TestParser_Parse(t *testing.T) {
|
|||
|
||||
// Parse the token
|
||||
var token *jwt.Token
|
||||
var ve *jwt.ValidationError
|
||||
var err error
|
||||
var parser = data.parser
|
||||
if parser == nil {
|
||||
|
@ -417,23 +399,6 @@ func TestParser_Parse(t *testing.T) {
|
|||
t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name)
|
||||
}
|
||||
|
||||
if data.errors != 0 {
|
||||
if err == nil {
|
||||
t.Errorf("[%v] Expecting error. Didn't get one.", data.name)
|
||||
} else {
|
||||
if errors.As(err, &ve) {
|
||||
// compare the bitfield part of the error
|
||||
if e := ve.Errors; e != data.errors {
|
||||
t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors)
|
||||
}
|
||||
|
||||
if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError {
|
||||
t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if data.err != nil {
|
||||
if err == nil {
|
||||
t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name)
|
||||
|
@ -467,7 +432,7 @@ func TestParser_ParseUnverified(t *testing.T) {
|
|||
// Iterate over test data set and run tests
|
||||
for _, data := range jwtTestData {
|
||||
// Skip test data, that intentionally contains malformed tokens, as they would lead to an error
|
||||
if data.errors&jwt.ValidationErrorMalformed != 0 {
|
||||
if len(data.err) == 1 && errors.Is(data.err[0], jwt.ErrTokenMalformed) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
225
validator.go
225
validator.go
|
@ -2,9 +2,32 @@ package jwt
|
|||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClaimsValidator is an interface that can be implemented by custom claims who
|
||||
// wish to execute any additional claims validation based on
|
||||
// application-specific logic. The Validate function is then executed in
|
||||
// addition to the regular claims validation and any error returned is appended
|
||||
// to the final validation result.
|
||||
//
|
||||
// type MyCustomClaims struct {
|
||||
// Foo string `json:"foo"`
|
||||
// jwt.RegisteredClaims
|
||||
// }
|
||||
//
|
||||
// func (m MyCustomClaims) Validate() error {
|
||||
// if m.Foo != "bar" {
|
||||
// return errors.New("must be foobar")
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
type ClaimsValidator interface {
|
||||
Claims
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// validator is the core of the new Validation API. It is automatically used by
|
||||
// a [Parser] during parsing and can be modified with various parser options.
|
||||
//
|
||||
|
@ -47,10 +70,13 @@ func newValidator(opts ...ParserOption) *validator {
|
|||
}
|
||||
|
||||
// Validate validates the given claims. It will also perform any custom
|
||||
// validation if claims implements the CustomValidator interface.
|
||||
// validation if claims implements the [ClaimsValidator] interface.
|
||||
func (v *validator) Validate(claims Claims) error {
|
||||
var now time.Time
|
||||
vErr := new(ValidationError)
|
||||
var (
|
||||
now time.Time
|
||||
errs []error = make([]error, 0, 6)
|
||||
err error
|
||||
)
|
||||
|
||||
// Check, if we have a time func
|
||||
if v.timeFunc != nil {
|
||||
|
@ -60,140 +86,139 @@ func (v *validator) Validate(claims Claims) error {
|
|||
}
|
||||
|
||||
// We always need to check the expiration time, but usage of the claim
|
||||
// itself is OPTIONAL
|
||||
if !v.VerifyExpiresAt(claims, now, false) {
|
||||
vErr.Inner = ErrTokenExpired
|
||||
vErr.Errors |= ValidationErrorExpired
|
||||
// itself is OPTIONAL.
|
||||
if err = v.verifyExpiresAt(claims, now, false); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
// We always need to check not-before, but usage of the claim itself is
|
||||
// OPTIONAL
|
||||
if !v.VerifyNotBefore(claims, now, false) {
|
||||
vErr.Inner = ErrTokenNotValidYet
|
||||
vErr.Errors |= ValidationErrorNotValidYet
|
||||
// OPTIONAL.
|
||||
if err = v.verifyNotBefore(claims, now, false); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
// Check issued-at if the option is enabled
|
||||
if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) {
|
||||
vErr.Inner = ErrTokenUsedBeforeIssued
|
||||
vErr.Errors |= ValidationErrorIssuedAt
|
||||
if v.verifyIat {
|
||||
if err = v.verifyIssuedAt(claims, now, false); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If we have an expected audience, we also require the audience claim
|
||||
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) {
|
||||
vErr.Inner = ErrTokenInvalidAudience
|
||||
vErr.Errors |= ValidationErrorAudience
|
||||
if v.expectedAud != "" {
|
||||
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If we have an expected issuer, we also require the issuer claim
|
||||
if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) {
|
||||
vErr.Inner = ErrTokenInvalidIssuer
|
||||
vErr.Errors |= ValidationErrorIssuer
|
||||
if v.expectedIss != "" {
|
||||
if err = v.verifyIssuer(claims, v.expectedIss, true); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If we have an expected subject, we also require the subject claim
|
||||
if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) {
|
||||
vErr.Inner = ErrTokenInvalidSubject
|
||||
vErr.Errors |= ValidationErrorSubject
|
||||
if v.expectedSub != "" {
|
||||
if err = v.verifySubject(claims, v.expectedSub, true); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, we want to give the claim itself some possibility to do some
|
||||
// additional custom validation based on a custom Validate function.
|
||||
cvt, ok := claims.(interface {
|
||||
Validate() error
|
||||
})
|
||||
cvt, ok := claims.(ClaimsValidator)
|
||||
if ok {
|
||||
if err := cvt.Validate(); err != nil {
|
||||
vErr.Inner = err
|
||||
vErr.Errors |= ValidationErrorClaimsInvalid
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if vErr.valid() {
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return vErr
|
||||
return joinErrors(errs...)
|
||||
}
|
||||
|
||||
// VerifyExpiresAt compares the exp claim in claims against cmp. This function
|
||||
// will return true if cmp < exp. Additional leeway is taken into account.
|
||||
// verifyExpiresAt compares the exp claim in claims against cmp. This function
|
||||
// will succeed if cmp < exp. Additional leeway is taken into account.
|
||||
//
|
||||
// If exp is not set, it will return true if the claim is not required,
|
||||
// otherwise false will be returned.
|
||||
// If exp is not set, it will succeed if the claim is not required,
|
||||
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||
//
|
||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||
// the wrong type, false will be returned.
|
||||
func (v *validator) VerifyExpiresAt(claims Claims, cmp time.Time, required bool) bool {
|
||||
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||
func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error {
|
||||
exp, err := claims.GetExpirationTime()
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
|
||||
if exp != nil {
|
||||
return cmp.Before((exp.Time).Add(+v.leeway))
|
||||
} else {
|
||||
return !required
|
||||
}
|
||||
if exp == nil {
|
||||
return errorIfRequired(required, "exp")
|
||||
}
|
||||
|
||||
// VerifyIssuedAt compares the iat claim in claims against cmp. This function
|
||||
// will return true if cmp >= iat. Additional leeway is taken into account.
|
||||
return errorIfFalse(cmp.Before((exp.Time).Add(+v.leeway)), ErrTokenExpired)
|
||||
}
|
||||
|
||||
// verifyIssuedAt compares the iat claim in claims against cmp. This function
|
||||
// will succeed if cmp >= iat. Additional leeway is taken into account.
|
||||
//
|
||||
// If iat is not set, it will return true if the claim is not required,
|
||||
// otherwise false will be returned.
|
||||
// If iat is not set, it will succeed if the claim is not required,
|
||||
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||
//
|
||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||
// the wrong type, false will be returned.
|
||||
func (v *validator) VerifyIssuedAt(claims Claims, cmp time.Time, required bool) bool {
|
||||
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||
func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error {
|
||||
iat, err := claims.GetIssuedAt()
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
|
||||
if iat != nil {
|
||||
return !cmp.Before(iat.Add(-v.leeway))
|
||||
} else {
|
||||
return !required
|
||||
}
|
||||
if iat == nil {
|
||||
return errorIfRequired(required, "iat")
|
||||
}
|
||||
|
||||
// VerifyNotBefore compares the nbf claim in claims against cmp. This function
|
||||
return errorIfFalse(!cmp.Before(iat.Add(-v.leeway)), ErrTokenUsedBeforeIssued)
|
||||
}
|
||||
|
||||
// verifyNotBefore compares the nbf claim in claims against cmp. This function
|
||||
// will return true if cmp >= nbf. Additional leeway is taken into account.
|
||||
//
|
||||
// If nbf is not set, it will return true if the claim is not required,
|
||||
// otherwise false will be returned.
|
||||
// If nbf is not set, it will succeed if the claim is not required,
|
||||
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||
//
|
||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||
// the wrong type, false will be returned.
|
||||
func (v *validator) VerifyNotBefore(claims Claims, cmp time.Time, required bool) bool {
|
||||
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||
func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error {
|
||||
nbf, err := claims.GetNotBefore()
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
|
||||
if nbf != nil {
|
||||
return !cmp.Before(nbf.Add(-v.leeway))
|
||||
} else {
|
||||
return !required
|
||||
}
|
||||
if nbf == nil {
|
||||
return errorIfRequired(required, "nbf")
|
||||
}
|
||||
|
||||
// VerifyAudience compares the aud claim against cmp.
|
||||
return errorIfFalse(!cmp.Before(nbf.Add(-v.leeway)), ErrTokenNotValidYet)
|
||||
}
|
||||
|
||||
// verifyAudience compares the aud claim against cmp.
|
||||
//
|
||||
// If aud is not set or an empty list, it will return true if the claim is not
|
||||
// required, otherwise false will be returned.
|
||||
// If aud is not set or an empty list, it will succeed if the claim is not required,
|
||||
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||
//
|
||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||
// the wrong type, false will be returned.
|
||||
func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) bool {
|
||||
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||
func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error {
|
||||
aud, err := claims.GetAudience()
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
|
||||
if len(aud) == 0 {
|
||||
return !required
|
||||
return errorIfRequired(required, "aud")
|
||||
}
|
||||
|
||||
// use a var here to keep constant time compare when looping over a number of claims
|
||||
|
@ -209,48 +234,68 @@ func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) boo
|
|||
|
||||
// case where "" is sent in one or many aud claims
|
||||
if stringClaims == "" {
|
||||
return !required
|
||||
return errorIfRequired(required, "aud")
|
||||
}
|
||||
|
||||
return result
|
||||
return errorIfFalse(result, ErrTokenInvalidAudience)
|
||||
}
|
||||
|
||||
// VerifyIssuer compares the iss claim in claims against cmp.
|
||||
// verifyIssuer compares the iss claim in claims against cmp.
|
||||
//
|
||||
// If iss is not set, it will return true if the claim is not required,
|
||||
// otherwise false will be returned.
|
||||
// If iss is not set, it will succeed if the claim is not required,
|
||||
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||
//
|
||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||
// the wrong type, false will be returned.
|
||||
func (v *validator) VerifyIssuer(claims Claims, cmp string, required bool) bool {
|
||||
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||
func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error {
|
||||
iss, err := claims.GetIssuer()
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
|
||||
if iss == "" {
|
||||
return !required
|
||||
return errorIfRequired(required, "iss")
|
||||
}
|
||||
|
||||
return iss == cmp
|
||||
return errorIfFalse(iss == cmp, ErrTokenInvalidIssuer)
|
||||
}
|
||||
|
||||
// VerifySubject compares the sub claim against cmp.
|
||||
// verifySubject compares the sub claim against cmp.
|
||||
//
|
||||
// If sub is not set, it will return true if the claim is not required,
|
||||
// otherwise false will be returned.
|
||||
// If sub is not set, it will succeed if the claim is not required,
|
||||
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||
//
|
||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||
// the wrong type, false will be returned.
|
||||
func (v *validator) VerifySubject(claims Claims, cmp string, required bool) bool {
|
||||
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||
func (v *validator) verifySubject(claims Claims, cmp string, required bool) error {
|
||||
sub, err := claims.GetSubject()
|
||||
if err != nil {
|
||||
return false
|
||||
return err
|
||||
}
|
||||
|
||||
if sub == "" {
|
||||
return !required
|
||||
return errorIfRequired(required, "sub")
|
||||
}
|
||||
|
||||
return sub == cmp
|
||||
return errorIfFalse(sub == cmp, ErrTokenInvalidSubject)
|
||||
}
|
||||
|
||||
// errorIfFalse returns the error specified in err, if the value is true.
|
||||
// Otherwise, nil is returned.
|
||||
func errorIfFalse(value bool, err error) error {
|
||||
if value {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// errorIfRequired returns an ErrTokenRequiredClaimMissing error if required is
|
||||
// true. Otherwise, nil is returned.
|
||||
func errorIfRequired(required bool, claim string) error {
|
||||
if required {
|
||||
return newError(fmt.Sprintf("%s claim is required", claim), ErrTokenRequiredClaimMissing)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,261 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrFooBar = errors.New("must be foobar")
|
||||
|
||||
type MyCustomClaims struct {
|
||||
Foo string `json:"foo"`
|
||||
RegisteredClaims
|
||||
}
|
||||
|
||||
func (m MyCustomClaims) Validate() error {
|
||||
if m.Foo != "bar" {
|
||||
return ErrFooBar
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_validator_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
leeway time.Duration
|
||||
timeFunc func() time.Time
|
||||
verifyIat bool
|
||||
expectedAud string
|
||||
expectedIss string
|
||||
expectedSub string
|
||||
}
|
||||
type args struct {
|
||||
claims Claims
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "expected iss mismatch",
|
||||
fields: fields{expectedIss: "me"},
|
||||
args: args{RegisteredClaims{Issuer: "not_me"}},
|
||||
wantErr: ErrTokenInvalidIssuer,
|
||||
},
|
||||
{
|
||||
name: "expected iss is missing",
|
||||
fields: fields{expectedIss: "me"},
|
||||
args: args{RegisteredClaims{}},
|
||||
wantErr: ErrTokenRequiredClaimMissing,
|
||||
},
|
||||
{
|
||||
name: "expected sub mismatch",
|
||||
fields: fields{expectedSub: "me"},
|
||||
args: args{RegisteredClaims{Subject: "not-me"}},
|
||||
wantErr: ErrTokenInvalidSubject,
|
||||
},
|
||||
{
|
||||
name: "expected sub is missing",
|
||||
fields: fields{expectedSub: "me"},
|
||||
args: args{RegisteredClaims{}},
|
||||
wantErr: ErrTokenRequiredClaimMissing,
|
||||
},
|
||||
{
|
||||
name: "custom validator",
|
||||
fields: fields{},
|
||||
args: args{MyCustomClaims{Foo: "not-bar"}},
|
||||
wantErr: ErrFooBar,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := &validator{
|
||||
leeway: tt.fields.leeway,
|
||||
timeFunc: tt.fields.timeFunc,
|
||||
verifyIat: tt.fields.verifyIat,
|
||||
expectedAud: tt.fields.expectedAud,
|
||||
expectedIss: tt.fields.expectedIss,
|
||||
expectedSub: tt.fields.expectedSub,
|
||||
}
|
||||
if err := v.Validate(tt.args.claims); (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("validator.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validator_verifyExpiresAt(t *testing.T) {
|
||||
type fields struct {
|
||||
leeway time.Duration
|
||||
timeFunc func() time.Time
|
||||
}
|
||||
type args struct {
|
||||
claims Claims
|
||||
cmp time.Time
|
||||
required bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "good claim",
|
||||
fields: fields{timeFunc: time.Now},
|
||||
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(time.Now().Add(10 * time.Minute))}},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "claims with invalid type",
|
||||
fields: fields{},
|
||||
args: args{claims: MapClaims{"exp": "string"}},
|
||||
wantErr: ErrInvalidType,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := &validator{
|
||||
leeway: tt.fields.leeway,
|
||||
timeFunc: tt.fields.timeFunc,
|
||||
}
|
||||
|
||||
err := v.verifyExpiresAt(tt.args.claims, tt.args.cmp, tt.args.required)
|
||||
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("validator.verifyExpiresAt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validator_verifyIssuer(t *testing.T) {
|
||||
type fields struct {
|
||||
expectedIss string
|
||||
}
|
||||
type args struct {
|
||||
claims Claims
|
||||
cmp string
|
||||
required bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "good claim",
|
||||
fields: fields{expectedIss: "me"},
|
||||
args: args{claims: MapClaims{"iss": "me"}, cmp: "me"},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "claims with invalid type",
|
||||
fields: fields{expectedIss: "me"},
|
||||
args: args{claims: MapClaims{"iss": 1}, cmp: "me"},
|
||||
wantErr: ErrInvalidType,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := &validator{
|
||||
expectedIss: tt.fields.expectedIss,
|
||||
}
|
||||
err := v.verifyIssuer(tt.args.claims, tt.args.cmp, tt.args.required)
|
||||
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("validator.verifyIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validator_verifySubject(t *testing.T) {
|
||||
type fields struct {
|
||||
expectedSub string
|
||||
}
|
||||
type args struct {
|
||||
claims Claims
|
||||
cmp string
|
||||
required bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "good claim",
|
||||
fields: fields{expectedSub: "me"},
|
||||
args: args{claims: MapClaims{"sub": "me"}, cmp: "me"},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "claims with invalid type",
|
||||
fields: fields{expectedSub: "me"},
|
||||
args: args{claims: MapClaims{"sub": 1}, cmp: "me"},
|
||||
wantErr: ErrInvalidType,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := &validator{
|
||||
expectedSub: tt.fields.expectedSub,
|
||||
}
|
||||
err := v.verifySubject(tt.args.claims, tt.args.cmp, tt.args.required)
|
||||
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("validator.verifySubject() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validator_verifyIssuedAt(t *testing.T) {
|
||||
type fields struct {
|
||||
leeway time.Duration
|
||||
timeFunc func() time.Time
|
||||
verifyIat bool
|
||||
}
|
||||
type args struct {
|
||||
claims Claims
|
||||
cmp time.Time
|
||||
required bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "good claim without iat",
|
||||
fields: fields{verifyIat: true},
|
||||
args: args{claims: MapClaims{}, required: false},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "good claim with iat",
|
||||
fields: fields{verifyIat: true},
|
||||
args: args{
|
||||
claims: RegisteredClaims{IssuedAt: NewNumericDate(time.Now())},
|
||||
cmp: time.Now().Add(10 * time.Minute),
|
||||
required: false,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := &validator{
|
||||
leeway: tt.fields.leeway,
|
||||
timeFunc: tt.fields.timeFunc,
|
||||
verifyIat: tt.fields.verifyIat,
|
||||
}
|
||||
if err := v.verifyIssuedAt(tt.args.claims, tt.args.cmp, tt.args.required); (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("validator.verifyIssuedAt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue