mirror of https://github.com/golang-jwt/jwt.git
More consistent way of handling validation errors (#274)
This commit is contained in:
parent
4e6e1ba2bb
commit
28dc52370e
|
@ -25,7 +25,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
go: [1.17, 1.18, 1.19]
|
go: ["1.18", "1.19", "1.20"]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
112
errors.go
112
errors.go
|
@ -2,18 +2,17 @@ package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error constants
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidKey = errors.New("key is invalid")
|
ErrInvalidKey = errors.New("key is invalid")
|
||||||
ErrInvalidKeyType = errors.New("key is of invalid type")
|
ErrInvalidKeyType = errors.New("key is of invalid type")
|
||||||
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
|
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
|
||||||
|
|
||||||
ErrTokenMalformed = errors.New("token is malformed")
|
ErrTokenMalformed = errors.New("token is malformed")
|
||||||
ErrTokenUnverifiable = errors.New("token is unverifiable")
|
ErrTokenUnverifiable = errors.New("token is unverifiable")
|
||||||
ErrTokenSignatureInvalid = errors.New("token signature is invalid")
|
ErrTokenSignatureInvalid = errors.New("token signature is invalid")
|
||||||
|
ErrTokenRequiredClaimMissing = errors.New("token is missing required claim")
|
||||||
ErrTokenInvalidAudience = errors.New("token has invalid audience")
|
ErrTokenInvalidAudience = errors.New("token has invalid audience")
|
||||||
ErrTokenExpired = errors.New("token is expired")
|
ErrTokenExpired = errors.New("token is expired")
|
||||||
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
|
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
|
||||||
|
@ -22,100 +21,29 @@ var (
|
||||||
ErrTokenNotValidYet = errors.New("token is not valid yet")
|
ErrTokenNotValidYet = errors.New("token is not valid yet")
|
||||||
ErrTokenInvalidId = errors.New("token has invalid id")
|
ErrTokenInvalidId = errors.New("token has invalid id")
|
||||||
ErrTokenInvalidClaims = errors.New("token has invalid claims")
|
ErrTokenInvalidClaims = errors.New("token has invalid claims")
|
||||||
|
|
||||||
ErrInvalidType = errors.New("invalid type for claim")
|
ErrInvalidType = errors.New("invalid type for claim")
|
||||||
)
|
)
|
||||||
|
|
||||||
// The errors that might occur when parsing and validating a token
|
// joinedError is an error type that works similar to what [errors.Join]
|
||||||
const (
|
// produces, with the exception that it has a nice error string; mainly its
|
||||||
ValidationErrorMalformed uint32 = 1 << iota // Token is malformed
|
// error messages are concatenated using a comma, rather than a newline.
|
||||||
ValidationErrorUnverifiable // Token could not be verified because of signing problems
|
type joinedError struct {
|
||||||
ValidationErrorSignatureInvalid // Signature validation failed
|
errs []error
|
||||||
|
|
||||||
// Registered Claim validation errors
|
|
||||||
ValidationErrorAudience // AUD validation failed
|
|
||||||
ValidationErrorExpired // EXP validation failed
|
|
||||||
ValidationErrorIssuedAt // IAT validation failed
|
|
||||||
ValidationErrorIssuer // ISS validation failed
|
|
||||||
ValidationErrorSubject // SUB validation failed
|
|
||||||
ValidationErrorNotValidYet // NBF validation failed
|
|
||||||
ValidationErrorId // JTI validation failed
|
|
||||||
ValidationErrorClaimsInvalid // Generic claims validation error
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewValidationError is a helper for constructing a ValidationError with a string error message
|
|
||||||
func NewValidationError(errorText string, errorFlags uint32) *ValidationError {
|
|
||||||
return &ValidationError{
|
|
||||||
text: errorText,
|
|
||||||
Errors: errorFlags,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidationError represents an error from Parse if token is not valid
|
func (je joinedError) Error() string {
|
||||||
type ValidationError struct {
|
msg := []string{}
|
||||||
// Inner stores the error returned by external dependencies, e.g.: KeyFunc
|
for _, err := range je.errs {
|
||||||
Inner error
|
msg = append(msg, err.Error())
|
||||||
// Errors is a bit-field. See ValidationError... constants
|
|
||||||
Errors uint32
|
|
||||||
// Text can be used for errors that do not have a valid error just have text
|
|
||||||
text string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error is the implementation of the err interface.
|
|
||||||
func (e ValidationError) Error() string {
|
|
||||||
if e.Inner != nil {
|
|
||||||
return e.Inner.Error()
|
|
||||||
} else if e.text != "" {
|
|
||||||
return e.text
|
|
||||||
} else {
|
|
||||||
return "token is invalid"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unwrap gives errors.Is and errors.As access to the inner error.
|
|
||||||
func (e *ValidationError) Unwrap() error {
|
|
||||||
return e.Inner
|
|
||||||
}
|
|
||||||
|
|
||||||
// No errors
|
|
||||||
func (e *ValidationError) valid() bool {
|
|
||||||
return e.Errors == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Is checks if this ValidationError is of the supplied error. We are first
|
|
||||||
// checking for the exact error message by comparing the inner error message. If
|
|
||||||
// that fails, we compare using the error flags. This way we can use custom
|
|
||||||
// error messages (mainly for backwards compatibility) and still leverage
|
|
||||||
// errors.Is using the global error variables.
|
|
||||||
func (e *ValidationError) Is(err error) bool {
|
|
||||||
// Check, if our inner error is a direct match
|
|
||||||
if errors.Is(errors.Unwrap(e), err) {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, we need to match using our error flags
|
return strings.Join(msg, ", ")
|
||||||
switch err {
|
}
|
||||||
case ErrTokenMalformed:
|
|
||||||
return e.Errors&ValidationErrorMalformed != 0
|
// joinErrors joins together multiple errors. Useful for scenarios where
|
||||||
case ErrTokenUnverifiable:
|
// multiple errors next to each other occur, e.g., in claims validation.
|
||||||
return e.Errors&ValidationErrorUnverifiable != 0
|
func joinErrors(errs ...error) error {
|
||||||
case ErrTokenSignatureInvalid:
|
return &joinedError{
|
||||||
return e.Errors&ValidationErrorSignatureInvalid != 0
|
errs: errs,
|
||||||
case ErrTokenInvalidAudience:
|
}
|
||||||
return e.Errors&ValidationErrorAudience != 0
|
|
||||||
case ErrTokenExpired:
|
|
||||||
return e.Errors&ValidationErrorExpired != 0
|
|
||||||
case ErrTokenUsedBeforeIssued:
|
|
||||||
return e.Errors&ValidationErrorIssuedAt != 0
|
|
||||||
case ErrTokenInvalidIssuer:
|
|
||||||
return e.Errors&ValidationErrorIssuer != 0
|
|
||||||
case ErrTokenNotValidYet:
|
|
||||||
return e.Errors&ValidationErrorNotValidYet != 0
|
|
||||||
case ErrTokenInvalidId:
|
|
||||||
return e.Errors&ValidationErrorId != 0
|
|
||||||
case ErrTokenInvalidClaims:
|
|
||||||
return e.Errors&ValidationErrorClaimsInvalid != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
//go:build go1.20
|
||||||
|
// +build go1.20
|
||||||
|
|
||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Unwrap implements the multiple error unwrapping for this error type, which is
|
||||||
|
// possible in Go 1.20.
|
||||||
|
func (je joinedError) Unwrap() []error {
|
||||||
|
return je.errs
|
||||||
|
}
|
||||||
|
|
||||||
|
// newError creates a new error message with a detailed error message. The
|
||||||
|
// message will be prefixed with the contents of the supplied error type.
|
||||||
|
// Additionally, more errors, that provide more context can be supplied which
|
||||||
|
// will be appended to the message. This makes use of Go 1.20's possibility to
|
||||||
|
// include more than one %w formatting directive in [fmt.Errorf].
|
||||||
|
//
|
||||||
|
// For example,
|
||||||
|
//
|
||||||
|
// newError("no keyfunc was provided", ErrTokenUnverifiable)
|
||||||
|
//
|
||||||
|
// will produce the error string
|
||||||
|
//
|
||||||
|
// "token is unverifiable: no keyfunc was provided"
|
||||||
|
func newError(message string, err error, more ...error) error {
|
||||||
|
var format string
|
||||||
|
var args []any
|
||||||
|
if message != "" {
|
||||||
|
format = "%w: %s"
|
||||||
|
args = []any{err, message}
|
||||||
|
} else {
|
||||||
|
format = "%w"
|
||||||
|
args = []any{err}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range more {
|
||||||
|
format += ": %w"
|
||||||
|
args = append(args, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fmt.Errorf(format, args...)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
//go:build !go1.20
|
||||||
|
// +build !go1.20
|
||||||
|
|
||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Is implements checking for multiple errors using [errors.Is], since multiple
|
||||||
|
// error unwrapping is not possible in versions less than Go 1.20.
|
||||||
|
func (je joinedError) Is(err error) bool {
|
||||||
|
for _, e := range je.errs {
|
||||||
|
if errors.Is(e, err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrappedErrors is a workaround for wrapping multiple errors in environments
|
||||||
|
// where Go 1.20 is not available. It basically uses the already implemented
|
||||||
|
// functionatlity of joinedError to handle multiple errors with supplies a
|
||||||
|
// custom error message that is identical to the one we produce in Go 1.20 using
|
||||||
|
// multiple %w directives.
|
||||||
|
type wrappedErrors struct {
|
||||||
|
msg string
|
||||||
|
joinedError
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns the stored error string
|
||||||
|
func (we wrappedErrors) Error() string {
|
||||||
|
return we.msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// newError creates a new error message with a detailed error message. The
|
||||||
|
// message will be prefixed with the contents of the supplied error type.
|
||||||
|
// Additionally, more errors, that provide more context can be supplied which
|
||||||
|
// will be appended to the message. Since we cannot use of Go 1.20's possibility
|
||||||
|
// to include more than one %w formatting directive in [fmt.Errorf], we have to
|
||||||
|
// emulate that.
|
||||||
|
//
|
||||||
|
// For example,
|
||||||
|
//
|
||||||
|
// newError("no keyfunc was provided", ErrTokenUnverifiable)
|
||||||
|
//
|
||||||
|
// will produce the error string
|
||||||
|
//
|
||||||
|
// "token is unverifiable: no keyfunc was provided"
|
||||||
|
func newError(message string, err error, more ...error) error {
|
||||||
|
// We cannot wrap multiple errors here with %w, so we have to be a little
|
||||||
|
// bit creative. Basically, we are using %s instead of %w to produce the
|
||||||
|
// same error message and then throw the result into a custom error struct.
|
||||||
|
var format string
|
||||||
|
var args []any
|
||||||
|
if message != "" {
|
||||||
|
format = "%s: %s"
|
||||||
|
args = []any{err, message}
|
||||||
|
} else {
|
||||||
|
format = "%s"
|
||||||
|
args = []any{err}
|
||||||
|
}
|
||||||
|
errs := []error{err}
|
||||||
|
|
||||||
|
for _, e := range more {
|
||||||
|
format += ": %s"
|
||||||
|
args = append(args, e)
|
||||||
|
errs = append(errs, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = &wrappedErrors{
|
||||||
|
msg: fmt.Sprintf(format, args...),
|
||||||
|
joinedError: joinedError{errs: errs},
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,95 @@
|
||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_joinErrors(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
errs []error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErrors []error
|
||||||
|
wantMessage string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "multiple errors",
|
||||||
|
args: args{
|
||||||
|
errs: []error{ErrTokenNotValidYet, ErrTokenExpired},
|
||||||
|
},
|
||||||
|
wantErrors: []error{ErrTokenNotValidYet, ErrTokenExpired},
|
||||||
|
wantMessage: "token is not valid yet, token is expired",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := joinErrors(tt.args.errs...)
|
||||||
|
for _, wantErr := range tt.wantErrors {
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Errorf("joinErrors() error = %v, does not contain %v", err, wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err.Error() != tt.wantMessage {
|
||||||
|
t.Errorf("joinErrors() error.Error() = %v, wantMessage %v", err, tt.wantMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_newError(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
message string
|
||||||
|
err error
|
||||||
|
more []error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErrors []error
|
||||||
|
wantMessage string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single error",
|
||||||
|
args: args{message: "something is wrong", err: ErrTokenMalformed},
|
||||||
|
wantMessage: "token is malformed: something is wrong",
|
||||||
|
wantErrors: []error{ErrTokenMalformed},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two errors",
|
||||||
|
args: args{message: "something is wrong", err: ErrTokenMalformed, more: []error{io.ErrUnexpectedEOF}},
|
||||||
|
wantMessage: "token is malformed: something is wrong: unexpected EOF",
|
||||||
|
wantErrors: []error{ErrTokenMalformed},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two errors, no detail",
|
||||||
|
args: args{message: "", err: ErrTokenInvalidClaims, more: []error{ErrTokenExpired}},
|
||||||
|
wantMessage: "token has invalid claims: token is expired",
|
||||||
|
wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two errors, no detail and join error",
|
||||||
|
args: args{message: "", err: ErrTokenInvalidClaims, more: []error{joinErrors(ErrTokenExpired, ErrTokenNotValidYet)}},
|
||||||
|
wantMessage: "token has invalid claims: token is expired, token is not valid yet",
|
||||||
|
wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired, ErrTokenNotValidYet},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := newError(tt.args.message, tt.args.err, tt.args.more...)
|
||||||
|
for _, wantErr := range tt.wantErrors {
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Errorf("newError() error = %v, does not contain %v", err, wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err.Error() != tt.wantMessage {
|
||||||
|
t.Errorf("newError() error.Error() = %v, wantMessage %v", err, tt.wantMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
2
go.mod
2
go.mod
|
@ -1,3 +1,3 @@
|
||||||
module github.com/golang-jwt/jwt/v5
|
module github.com/golang-jwt/jwt/v5
|
||||||
|
|
||||||
go 1.16
|
go 1.18
|
||||||
|
|
|
@ -2,6 +2,7 @@ package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MapClaims is a claims type that uses the map[string]interface{} for JSON
|
// MapClaims is a claims type that uses the map[string]interface{} for JSON
|
||||||
|
@ -60,7 +61,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
|
||||||
return newNumericDateFromSeconds(v), nil
|
return newNumericDateFromSeconds(v), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, ErrInvalidType
|
return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseClaimsString tries to parse a key in the map claims type as a
|
// parseClaimsString tries to parse a key in the map claims type as a
|
||||||
|
@ -76,7 +77,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
|
||||||
for _, a := range v {
|
for _, a := range v {
|
||||||
vs, ok := a.(string)
|
vs, ok := a.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrInvalidType
|
return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
|
||||||
}
|
}
|
||||||
cs = append(cs, vs)
|
cs = append(cs, vs)
|
||||||
}
|
}
|
||||||
|
@ -101,7 +102,7 @@ func (m MapClaims) parseString(key string) (string, error) {
|
||||||
|
|
||||||
iss, ok = raw.(string)
|
iss, ok = raw.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", ErrInvalidType
|
return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
|
||||||
}
|
}
|
||||||
|
|
||||||
return iss, nil
|
return iss, nil
|
||||||
|
|
7
none.go
7
none.go
|
@ -13,7 +13,7 @@ type unsafeNoneMagicConstant string
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
SigningMethodNone = &signingMethodNone{}
|
SigningMethodNone = &signingMethodNone{}
|
||||||
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid)
|
NoneSignatureTypeDisallowedError = newError("'none' signature type is not allowed", ErrTokenUnverifiable)
|
||||||
|
|
||||||
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
|
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
|
||||||
return SigningMethodNone
|
return SigningMethodNone
|
||||||
|
@ -33,10 +33,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac
|
||||||
}
|
}
|
||||||
// If signing method is none, signature must be an empty string
|
// If signing method is none, signature must be an empty string
|
||||||
if signature != "" {
|
if signature != "" {
|
||||||
return NewValidationError(
|
return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
|
||||||
"'none' signing method with non-empty signature",
|
|
||||||
ValidationErrorSignatureInvalid,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept 'none' signing method.
|
// Accept 'none' signing method.
|
||||||
|
|
53
parser.go
53
parser.go
|
@ -65,7 +65,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
||||||
}
|
}
|
||||||
if !signingMethodValid {
|
if !signingMethodValid {
|
||||||
// signing method is not in the listed set
|
// signing method is not in the listed set
|
||||||
return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid)
|
return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,17 +73,17 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
||||||
var key interface{}
|
var key interface{}
|
||||||
if keyFunc == nil {
|
if keyFunc == nil {
|
||||||
// keyFunc was not provided. short circuiting validation
|
// keyFunc was not provided. short circuiting validation
|
||||||
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable)
|
return token, newError("no keyfunc was provided", ErrTokenUnverifiable)
|
||||||
}
|
}
|
||||||
if key, err = keyFunc(token); err != nil {
|
if key, err = keyFunc(token); err != nil {
|
||||||
// keyFunc returned an error
|
return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err)
|
||||||
if ve, ok := err.(*ValidationError); ok {
|
|
||||||
return token, ve
|
|
||||||
}
|
|
||||||
return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vErr := &ValidationError{}
|
// Perform signature validation
|
||||||
|
token.Signature = parts[2]
|
||||||
|
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
|
||||||
|
return token, newError("", ErrTokenSignatureInvalid, err)
|
||||||
|
}
|
||||||
|
|
||||||
// Validate Claims
|
// Validate Claims
|
||||||
if !p.skipClaimsValidation {
|
if !p.skipClaimsValidation {
|
||||||
|
@ -93,29 +93,14 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.validator.Validate(claims); err != nil {
|
if err := p.validator.Validate(claims); err != nil {
|
||||||
// If the Claims Valid returned an error, check if it is a validation error,
|
return token, newError("", ErrTokenInvalidClaims, err)
|
||||||
// If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set
|
|
||||||
if e, ok := err.(*ValidationError); !ok {
|
|
||||||
vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid}
|
|
||||||
} else {
|
|
||||||
vErr = e
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform validation
|
// No errors so far, token is valid.
|
||||||
token.Signature = parts[2]
|
|
||||||
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
|
|
||||||
vErr.Inner = err
|
|
||||||
vErr.Errors |= ValidationErrorSignatureInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
if vErr.valid() {
|
|
||||||
token.Valid = true
|
token.Valid = true
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, vErr
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseUnverified parses the token but doesn't validate the signature.
|
// ParseUnverified parses the token but doesn't validate the signature.
|
||||||
|
@ -127,7 +112,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
||||||
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
|
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
|
||||||
parts = strings.Split(tokenString, ".")
|
parts = strings.Split(tokenString, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
|
return nil, parts, newError("token contains an invalid number of segments", ErrTokenMalformed)
|
||||||
}
|
}
|
||||||
|
|
||||||
token = &Token{Raw: tokenString}
|
token = &Token{Raw: tokenString}
|
||||||
|
@ -136,12 +121,12 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
|
||||||
var headerBytes []byte
|
var headerBytes []byte
|
||||||
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
|
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
|
||||||
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
|
if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {
|
||||||
return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed)
|
return token, parts, newError("tokenstring should not contain 'bearer '", ErrTokenMalformed)
|
||||||
}
|
}
|
||||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err)
|
||||||
}
|
}
|
||||||
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
|
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
|
||||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse Claims
|
// parse Claims
|
||||||
|
@ -149,7 +134,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
|
||||||
token.Claims = claims
|
token.Claims = claims
|
||||||
|
|
||||||
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
|
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
|
||||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
|
||||||
}
|
}
|
||||||
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
|
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
|
||||||
if p.useJSONNumber {
|
if p.useJSONNumber {
|
||||||
|
@ -163,16 +148,16 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
|
||||||
}
|
}
|
||||||
// Handle decode error
|
// Handle decode error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed}
|
return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup signature method
|
// Lookup signature method
|
||||||
if method, ok := token.Header["alg"].(string); ok {
|
if method, ok := token.Header["alg"].(string); ok {
|
||||||
if token.Method = GetSigningMethod(method); token.Method == nil {
|
if token.Method = GetSigningMethod(method); token.Method == nil {
|
||||||
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
|
return token, parts, newError("signing method (alg) is unavailable", ErrTokenUnverifiable)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable)
|
return token, parts, newError("signing method (alg) is unspecified", ErrTokenUnverifiable)
|
||||||
}
|
}
|
||||||
|
|
||||||
return token, parts, nil
|
return token, parts, nil
|
||||||
|
|
|
@ -51,7 +51,6 @@ var jwtTestData = []struct {
|
||||||
keyfunc jwt.Keyfunc
|
keyfunc jwt.Keyfunc
|
||||||
claims jwt.Claims
|
claims jwt.Claims
|
||||||
valid bool
|
valid bool
|
||||||
errors uint32
|
|
||||||
err []error
|
err []error
|
||||||
parser *jwt.Parser
|
parser *jwt.Parser
|
||||||
signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose
|
signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose
|
||||||
|
@ -62,7 +61,16 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
nil,
|
nil,
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorMalformed,
|
[]error{jwt.ErrTokenMalformed},
|
||||||
|
nil,
|
||||||
|
jwt.SigningMethodRS256,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid JSON claim",
|
||||||
|
"eyJhbGciOiJSUzI1NiIsInppcCI6IkRFRiJ9.eNqqVkqtKFCyMjQ1s7Q0sbA0MtFRyk3NTUot8kxRslIKLbZQggn4JeamAoUcfRz99HxcXRWeze172tr4bFq7Ui0AAAD__w.jBXD4LT4aq4oXTgDoPkiV6n4QdSZPZI1Z4J8MWQC42aHK0oXwcovEU06dVbtB81TF-2byuu0-qi8J0GUttODT67k6gCl6DV_iuCOV7gczwTcvKslotUvXzoJ2wa0QuujnjxLEE50r0p6k0tsv_9OIFSUZzDksJFYNPlJH2eFG55DROx4TsOz98az37SujZi9GGbTc9SLgzFHPrHMrovRZ5qLC_w4JrdtsLzBBI11OQJgRYwV8fQf4O8IsMkHtetjkN7dKgUkJtRarNWOk76rpTPppLypiLU4_J0-wrElLMh1TzUVZW6Fz2cDHDDBACJgMmKQ2pOFEDK_vYZN74dLCF5GiTZV6DbXhNxO7lqT7JUN4a3p2z96G7WNRjblf2qZeuYdQvkIsiK-rCbSIE836XeY5gaBgkOzuEvzl_tMrpRmb5Oox1ibOfVT2KBh9Lvqsb1XbQjCio2CLE2ViCLqoe0AaRqlUyrk3n8BIG-r0IW4dcw96CEryEMIjsjVp9mtPXamJzf391kt8Rf3iRBqwv3zP7Plg1ResXbmsFUgOflAUPcYmfLug4W3W52ntcUlTHAKXrNfaJL9QQiYAaDukG-ZHDytsOWTuuXw7lVxjt-XYi1VbRAIjh1aIYSELEmEpE4Ny74htQtywYXMQNfJpB0nNn8IiWakgcYYMJ0TmKM",
|
||||||
|
defaultKeyFunc,
|
||||||
|
nil,
|
||||||
|
false,
|
||||||
[]error{jwt.ErrTokenMalformed},
|
[]error{jwt.ErrTokenMalformed},
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -73,7 +81,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
nil,
|
nil,
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorMalformed,
|
|
||||||
[]error{jwt.ErrTokenMalformed},
|
[]error{jwt.ErrTokenMalformed},
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -84,7 +91,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -95,7 +101,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
|
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorExpired,
|
|
||||||
[]error{jwt.ErrTokenExpired},
|
[]error{jwt.ErrTokenExpired},
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -106,7 +111,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
|
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorNotValidYet,
|
|
||||||
[]error{jwt.ErrTokenNotValidYet},
|
[]error{jwt.ErrTokenNotValidYet},
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -117,8 +121,7 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
|
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
|
[]error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired},
|
||||||
[]error{jwt.ErrTokenNotValidYet},
|
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
},
|
},
|
||||||
|
@ -128,7 +131,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorSignatureInvalid,
|
|
||||||
[]error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification},
|
[]error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification},
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -139,7 +141,6 @@ var jwtTestData = []struct {
|
||||||
nilKeyFunc,
|
nilKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorUnverifiable,
|
|
||||||
[]error{jwt.ErrTokenUnverifiable},
|
[]error{jwt.ErrTokenUnverifiable},
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -150,7 +151,6 @@ var jwtTestData = []struct {
|
||||||
emptyKeyFunc,
|
emptyKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorSignatureInvalid,
|
|
||||||
[]error{jwt.ErrTokenSignatureInvalid},
|
[]error{jwt.ErrTokenSignatureInvalid},
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -161,7 +161,6 @@ var jwtTestData = []struct {
|
||||||
errorKeyFunc,
|
errorKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorUnverifiable,
|
|
||||||
[]error{jwt.ErrTokenUnverifiable, errKeyFuncError},
|
[]error{jwt.ErrTokenUnverifiable, errKeyFuncError},
|
||||||
nil,
|
nil,
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -172,7 +171,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorSignatureInvalid,
|
|
||||||
[]error{jwt.ErrTokenSignatureInvalid},
|
[]error{jwt.ErrTokenSignatureInvalid},
|
||||||
jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})),
|
jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -183,7 +181,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})),
|
jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -194,7 +191,6 @@ var jwtTestData = []struct {
|
||||||
ecdsaKeyFunc,
|
ecdsaKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorSignatureInvalid,
|
|
||||||
[]error{jwt.ErrTokenSignatureInvalid},
|
[]error{jwt.ErrTokenSignatureInvalid},
|
||||||
jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})),
|
jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})),
|
||||||
jwt.SigningMethodES256,
|
jwt.SigningMethodES256,
|
||||||
|
@ -205,7 +201,6 @@ var jwtTestData = []struct {
|
||||||
ecdsaKeyFunc,
|
ecdsaKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar"},
|
jwt.MapClaims{"foo": "bar"},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})),
|
jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})),
|
||||||
jwt.SigningMethodES256,
|
jwt.SigningMethodES256,
|
||||||
|
@ -216,7 +211,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": json.Number("123.4")},
|
jwt.MapClaims{"foo": json.Number("123.4")},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -227,7 +221,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
|
jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorExpired,
|
|
||||||
[]error{jwt.ErrTokenExpired},
|
[]error{jwt.ErrTokenExpired},
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -238,7 +231,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
|
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorNotValidYet,
|
|
||||||
[]error{jwt.ErrTokenNotValidYet},
|
[]error{jwt.ErrTokenNotValidYet},
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -249,8 +241,7 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
|
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
|
[]error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired},
|
||||||
[]error{jwt.ErrTokenNotValidYet},
|
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
},
|
},
|
||||||
|
@ -260,7 +251,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
|
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()),
|
jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -273,7 +263,6 @@ var jwtTestData = []struct {
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)),
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)),
|
||||||
},
|
},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -286,7 +275,6 @@ var jwtTestData = []struct {
|
||||||
Audience: jwt.ClaimStrings{"test"},
|
Audience: jwt.ClaimStrings{"test"},
|
||||||
},
|
},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -299,7 +287,6 @@ var jwtTestData = []struct {
|
||||||
Audience: jwt.ClaimStrings{"test", "test"},
|
Audience: jwt.ClaimStrings{"test", "test"},
|
||||||
},
|
},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -312,7 +299,6 @@ var jwtTestData = []struct {
|
||||||
Audience: nil, // because of the unmarshal error, this will be empty
|
Audience: nil, // because of the unmarshal error, this will be empty
|
||||||
},
|
},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorMalformed,
|
|
||||||
[]error{jwt.ErrTokenMalformed},
|
[]error{jwt.ErrTokenMalformed},
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -325,7 +311,6 @@ var jwtTestData = []struct {
|
||||||
Audience: nil, // because of the unmarshal error, this will be empty
|
Audience: nil, // because of the unmarshal error, this will be empty
|
||||||
},
|
},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorMalformed,
|
|
||||||
[]error{jwt.ErrTokenMalformed},
|
[]error{jwt.ErrTokenMalformed},
|
||||||
jwt.NewParser(jwt.WithJSONNumber()),
|
jwt.NewParser(jwt.WithJSONNumber()),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -336,7 +321,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
&jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))},
|
&jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))},
|
||||||
false,
|
false,
|
||||||
jwt.ValidationErrorNotValidYet,
|
|
||||||
[]error{jwt.ErrTokenNotValidYet},
|
[]error{jwt.ErrTokenNotValidYet},
|
||||||
jwt.NewParser(jwt.WithLeeway(time.Minute)),
|
jwt.NewParser(jwt.WithLeeway(time.Minute)),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -347,7 +331,6 @@ var jwtTestData = []struct {
|
||||||
defaultKeyFunc,
|
defaultKeyFunc,
|
||||||
&jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))},
|
&jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))},
|
||||||
true,
|
true,
|
||||||
0,
|
|
||||||
nil,
|
nil,
|
||||||
jwt.NewParser(jwt.WithLeeway(2 * time.Minute)),
|
jwt.NewParser(jwt.WithLeeway(2 * time.Minute)),
|
||||||
jwt.SigningMethodRS256,
|
jwt.SigningMethodRS256,
|
||||||
|
@ -381,7 +364,6 @@ func TestParser_Parse(t *testing.T) {
|
||||||
|
|
||||||
// Parse the token
|
// Parse the token
|
||||||
var token *jwt.Token
|
var token *jwt.Token
|
||||||
var ve *jwt.ValidationError
|
|
||||||
var err error
|
var err error
|
||||||
var parser = data.parser
|
var parser = data.parser
|
||||||
if parser == nil {
|
if parser == nil {
|
||||||
|
@ -417,23 +399,6 @@ func TestParser_Parse(t *testing.T) {
|
||||||
t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name)
|
t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if data.errors != 0 {
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("[%v] Expecting error. Didn't get one.", data.name)
|
|
||||||
} else {
|
|
||||||
if errors.As(err, &ve) {
|
|
||||||
// compare the bitfield part of the error
|
|
||||||
if e := ve.Errors; e != data.errors {
|
|
||||||
t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError {
|
|
||||||
t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.err != nil {
|
if data.err != nil {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name)
|
t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name)
|
||||||
|
@ -467,7 +432,7 @@ func TestParser_ParseUnverified(t *testing.T) {
|
||||||
// Iterate over test data set and run tests
|
// Iterate over test data set and run tests
|
||||||
for _, data := range jwtTestData {
|
for _, data := range jwtTestData {
|
||||||
// Skip test data, that intentionally contains malformed tokens, as they would lead to an error
|
// Skip test data, that intentionally contains malformed tokens, as they would lead to an error
|
||||||
if data.errors&jwt.ValidationErrorMalformed != 0 {
|
if len(data.err) == 1 && errors.Is(data.err[0], jwt.ErrTokenMalformed) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
219
validator.go
219
validator.go
|
@ -2,9 +2,32 @@ package jwt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ClaimsValidator is an interface that can be implemented by custom claims who
|
||||||
|
// wish to execute any additional claims validation based on
|
||||||
|
// application-specific logic. The Validate function is then executed in
|
||||||
|
// addition to the regular claims validation and any error returned is appended
|
||||||
|
// to the final validation result.
|
||||||
|
//
|
||||||
|
// type MyCustomClaims struct {
|
||||||
|
// Foo string `json:"foo"`
|
||||||
|
// jwt.RegisteredClaims
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func (m MyCustomClaims) Validate() error {
|
||||||
|
// if m.Foo != "bar" {
|
||||||
|
// return errors.New("must be foobar")
|
||||||
|
// }
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
type ClaimsValidator interface {
|
||||||
|
Claims
|
||||||
|
Validate() error
|
||||||
|
}
|
||||||
|
|
||||||
// validator is the core of the new Validation API. It is automatically used by
|
// validator is the core of the new Validation API. It is automatically used by
|
||||||
// a [Parser] during parsing and can be modified with various parser options.
|
// a [Parser] during parsing and can be modified with various parser options.
|
||||||
//
|
//
|
||||||
|
@ -47,10 +70,13 @@ func newValidator(opts ...ParserOption) *validator {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates the given claims. It will also perform any custom
|
// Validate validates the given claims. It will also perform any custom
|
||||||
// validation if claims implements the CustomValidator interface.
|
// validation if claims implements the [ClaimsValidator] interface.
|
||||||
func (v *validator) Validate(claims Claims) error {
|
func (v *validator) Validate(claims Claims) error {
|
||||||
var now time.Time
|
var (
|
||||||
vErr := new(ValidationError)
|
now time.Time
|
||||||
|
errs []error = make([]error, 0, 6)
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
// Check, if we have a time func
|
// Check, if we have a time func
|
||||||
if v.timeFunc != nil {
|
if v.timeFunc != nil {
|
||||||
|
@ -60,140 +86,139 @@ func (v *validator) Validate(claims Claims) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We always need to check the expiration time, but usage of the claim
|
// We always need to check the expiration time, but usage of the claim
|
||||||
// itself is OPTIONAL
|
// itself is OPTIONAL.
|
||||||
if !v.VerifyExpiresAt(claims, now, false) {
|
if err = v.verifyExpiresAt(claims, now, false); err != nil {
|
||||||
vErr.Inner = ErrTokenExpired
|
errs = append(errs, err)
|
||||||
vErr.Errors |= ValidationErrorExpired
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// We always need to check not-before, but usage of the claim itself is
|
// We always need to check not-before, but usage of the claim itself is
|
||||||
// OPTIONAL
|
// OPTIONAL.
|
||||||
if !v.VerifyNotBefore(claims, now, false) {
|
if err = v.verifyNotBefore(claims, now, false); err != nil {
|
||||||
vErr.Inner = ErrTokenNotValidYet
|
errs = append(errs, err)
|
||||||
vErr.Errors |= ValidationErrorNotValidYet
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check issued-at if the option is enabled
|
// Check issued-at if the option is enabled
|
||||||
if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) {
|
if v.verifyIat {
|
||||||
vErr.Inner = ErrTokenUsedBeforeIssued
|
if err = v.verifyIssuedAt(claims, now, false); err != nil {
|
||||||
vErr.Errors |= ValidationErrorIssuedAt
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have an expected audience, we also require the audience claim
|
// If we have an expected audience, we also require the audience claim
|
||||||
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) {
|
if v.expectedAud != "" {
|
||||||
vErr.Inner = ErrTokenInvalidAudience
|
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
|
||||||
vErr.Errors |= ValidationErrorAudience
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have an expected issuer, we also require the issuer claim
|
// If we have an expected issuer, we also require the issuer claim
|
||||||
if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) {
|
if v.expectedIss != "" {
|
||||||
vErr.Inner = ErrTokenInvalidIssuer
|
if err = v.verifyIssuer(claims, v.expectedIss, true); err != nil {
|
||||||
vErr.Errors |= ValidationErrorIssuer
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have an expected subject, we also require the subject claim
|
// If we have an expected subject, we also require the subject claim
|
||||||
if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) {
|
if v.expectedSub != "" {
|
||||||
vErr.Inner = ErrTokenInvalidSubject
|
if err = v.verifySubject(claims, v.expectedSub, true); err != nil {
|
||||||
vErr.Errors |= ValidationErrorSubject
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, we want to give the claim itself some possibility to do some
|
// Finally, we want to give the claim itself some possibility to do some
|
||||||
// additional custom validation based on a custom Validate function.
|
// additional custom validation based on a custom Validate function.
|
||||||
cvt, ok := claims.(interface {
|
cvt, ok := claims.(ClaimsValidator)
|
||||||
Validate() error
|
|
||||||
})
|
|
||||||
if ok {
|
if ok {
|
||||||
if err := cvt.Validate(); err != nil {
|
if err := cvt.Validate(); err != nil {
|
||||||
vErr.Inner = err
|
errs = append(errs, err)
|
||||||
vErr.Errors |= ValidationErrorClaimsInvalid
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if vErr.valid() {
|
if len(errs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return vErr
|
return joinErrors(errs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyExpiresAt compares the exp claim in claims against cmp. This function
|
// verifyExpiresAt compares the exp claim in claims against cmp. This function
|
||||||
// will return true if cmp < exp. Additional leeway is taken into account.
|
// will succeed if cmp < exp. Additional leeway is taken into account.
|
||||||
//
|
//
|
||||||
// If exp is not set, it will return true if the claim is not required,
|
// If exp is not set, it will succeed if the claim is not required,
|
||||||
// otherwise false will be returned.
|
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||||
//
|
//
|
||||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||||
// the wrong type, false will be returned.
|
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||||
func (v *validator) VerifyExpiresAt(claims Claims, cmp time.Time, required bool) bool {
|
func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error {
|
||||||
exp, err := claims.GetExpirationTime()
|
exp, err := claims.GetExpirationTime()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if exp != nil {
|
if exp == nil {
|
||||||
return cmp.Before((exp.Time).Add(+v.leeway))
|
return errorIfRequired(required, "exp")
|
||||||
} else {
|
|
||||||
return !required
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return errorIfFalse(cmp.Before((exp.Time).Add(+v.leeway)), ErrTokenExpired)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyIssuedAt compares the iat claim in claims against cmp. This function
|
// verifyIssuedAt compares the iat claim in claims against cmp. This function
|
||||||
// will return true if cmp >= iat. Additional leeway is taken into account.
|
// will succeed if cmp >= iat. Additional leeway is taken into account.
|
||||||
//
|
//
|
||||||
// If iat is not set, it will return true if the claim is not required,
|
// If iat is not set, it will succeed if the claim is not required,
|
||||||
// otherwise false will be returned.
|
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||||
//
|
//
|
||||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||||
// the wrong type, false will be returned.
|
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||||
func (v *validator) VerifyIssuedAt(claims Claims, cmp time.Time, required bool) bool {
|
func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error {
|
||||||
iat, err := claims.GetIssuedAt()
|
iat, err := claims.GetIssuedAt()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if iat != nil {
|
if iat == nil {
|
||||||
return !cmp.Before(iat.Add(-v.leeway))
|
return errorIfRequired(required, "iat")
|
||||||
} else {
|
|
||||||
return !required
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return errorIfFalse(!cmp.Before(iat.Add(-v.leeway)), ErrTokenUsedBeforeIssued)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyNotBefore compares the nbf claim in claims against cmp. This function
|
// verifyNotBefore compares the nbf claim in claims against cmp. This function
|
||||||
// will return true if cmp >= nbf. Additional leeway is taken into account.
|
// will return true if cmp >= nbf. Additional leeway is taken into account.
|
||||||
//
|
//
|
||||||
// If nbf is not set, it will return true if the claim is not required,
|
// If nbf is not set, it will succeed if the claim is not required,
|
||||||
// otherwise false will be returned.
|
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||||
//
|
//
|
||||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||||
// the wrong type, false will be returned.
|
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||||
func (v *validator) VerifyNotBefore(claims Claims, cmp time.Time, required bool) bool {
|
func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error {
|
||||||
nbf, err := claims.GetNotBefore()
|
nbf, err := claims.GetNotBefore()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if nbf != nil {
|
if nbf == nil {
|
||||||
return !cmp.Before(nbf.Add(-v.leeway))
|
return errorIfRequired(required, "nbf")
|
||||||
} else {
|
|
||||||
return !required
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return errorIfFalse(!cmp.Before(nbf.Add(-v.leeway)), ErrTokenNotValidYet)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyAudience compares the aud claim against cmp.
|
// verifyAudience compares the aud claim against cmp.
|
||||||
//
|
//
|
||||||
// If aud is not set or an empty list, it will return true if the claim is not
|
// If aud is not set or an empty list, it will succeed if the claim is not required,
|
||||||
// required, otherwise false will be returned.
|
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||||
//
|
//
|
||||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||||
// the wrong type, false will be returned.
|
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||||
func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) bool {
|
func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error {
|
||||||
aud, err := claims.GetAudience()
|
aud, err := claims.GetAudience()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(aud) == 0 {
|
if len(aud) == 0 {
|
||||||
return !required
|
return errorIfRequired(required, "aud")
|
||||||
}
|
}
|
||||||
|
|
||||||
// use a var here to keep constant time compare when looping over a number of claims
|
// use a var here to keep constant time compare when looping over a number of claims
|
||||||
|
@ -209,48 +234,68 @@ func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) boo
|
||||||
|
|
||||||
// case where "" is sent in one or many aud claims
|
// case where "" is sent in one or many aud claims
|
||||||
if stringClaims == "" {
|
if stringClaims == "" {
|
||||||
return !required
|
return errorIfRequired(required, "aud")
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return errorIfFalse(result, ErrTokenInvalidAudience)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyIssuer compares the iss claim in claims against cmp.
|
// verifyIssuer compares the iss claim in claims against cmp.
|
||||||
//
|
//
|
||||||
// If iss is not set, it will return true if the claim is not required,
|
// If iss is not set, it will succeed if the claim is not required,
|
||||||
// otherwise false will be returned.
|
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||||
//
|
//
|
||||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||||
// the wrong type, false will be returned.
|
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||||
func (v *validator) VerifyIssuer(claims Claims, cmp string, required bool) bool {
|
func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error {
|
||||||
iss, err := claims.GetIssuer()
|
iss, err := claims.GetIssuer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if iss == "" {
|
if iss == "" {
|
||||||
return !required
|
return errorIfRequired(required, "iss")
|
||||||
}
|
}
|
||||||
|
|
||||||
return iss == cmp
|
return errorIfFalse(iss == cmp, ErrTokenInvalidIssuer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifySubject compares the sub claim against cmp.
|
// verifySubject compares the sub claim against cmp.
|
||||||
//
|
//
|
||||||
// If sub is not set, it will return true if the claim is not required,
|
// If sub is not set, it will succeed if the claim is not required,
|
||||||
// otherwise false will be returned.
|
// otherwise ErrTokenRequiredClaimMissing will be returned.
|
||||||
//
|
//
|
||||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||||
// the wrong type, false will be returned.
|
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||||
func (v *validator) VerifySubject(claims Claims, cmp string, required bool) bool {
|
func (v *validator) verifySubject(claims Claims, cmp string, required bool) error {
|
||||||
sub, err := claims.GetSubject()
|
sub, err := claims.GetSubject()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if sub == "" {
|
if sub == "" {
|
||||||
return !required
|
return errorIfRequired(required, "sub")
|
||||||
}
|
}
|
||||||
|
|
||||||
return sub == cmp
|
return errorIfFalse(sub == cmp, ErrTokenInvalidSubject)
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorIfFalse returns the error specified in err, if the value is true.
|
||||||
|
// Otherwise, nil is returned.
|
||||||
|
func errorIfFalse(value bool, err error) error {
|
||||||
|
if value {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorIfRequired returns an ErrTokenRequiredClaimMissing error if required is
|
||||||
|
// true. Otherwise, nil is returned.
|
||||||
|
func errorIfRequired(required bool, claim string) error {
|
||||||
|
if required {
|
||||||
|
return newError(fmt.Sprintf("%s claim is required", claim), ErrTokenRequiredClaimMissing)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,261 @@
|
||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrFooBar = errors.New("must be foobar")
|
||||||
|
|
||||||
|
type MyCustomClaims struct {
|
||||||
|
Foo string `json:"foo"`
|
||||||
|
RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MyCustomClaims) Validate() error {
|
||||||
|
if m.Foo != "bar" {
|
||||||
|
return ErrFooBar
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validator_Validate(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
leeway time.Duration
|
||||||
|
timeFunc func() time.Time
|
||||||
|
verifyIat bool
|
||||||
|
expectedAud string
|
||||||
|
expectedIss string
|
||||||
|
expectedSub string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
claims Claims
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "expected iss mismatch",
|
||||||
|
fields: fields{expectedIss: "me"},
|
||||||
|
args: args{RegisteredClaims{Issuer: "not_me"}},
|
||||||
|
wantErr: ErrTokenInvalidIssuer,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expected iss is missing",
|
||||||
|
fields: fields{expectedIss: "me"},
|
||||||
|
args: args{RegisteredClaims{}},
|
||||||
|
wantErr: ErrTokenRequiredClaimMissing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expected sub mismatch",
|
||||||
|
fields: fields{expectedSub: "me"},
|
||||||
|
args: args{RegisteredClaims{Subject: "not-me"}},
|
||||||
|
wantErr: ErrTokenInvalidSubject,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expected sub is missing",
|
||||||
|
fields: fields{expectedSub: "me"},
|
||||||
|
args: args{RegisteredClaims{}},
|
||||||
|
wantErr: ErrTokenRequiredClaimMissing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom validator",
|
||||||
|
fields: fields{},
|
||||||
|
args: args{MyCustomClaims{Foo: "not-bar"}},
|
||||||
|
wantErr: ErrFooBar,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := &validator{
|
||||||
|
leeway: tt.fields.leeway,
|
||||||
|
timeFunc: tt.fields.timeFunc,
|
||||||
|
verifyIat: tt.fields.verifyIat,
|
||||||
|
expectedAud: tt.fields.expectedAud,
|
||||||
|
expectedIss: tt.fields.expectedIss,
|
||||||
|
expectedSub: tt.fields.expectedSub,
|
||||||
|
}
|
||||||
|
if err := v.Validate(tt.args.claims); (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Errorf("validator.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validator_verifyExpiresAt(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
leeway time.Duration
|
||||||
|
timeFunc func() time.Time
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
claims Claims
|
||||||
|
cmp time.Time
|
||||||
|
required bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "good claim",
|
||||||
|
fields: fields{timeFunc: time.Now},
|
||||||
|
args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(time.Now().Add(10 * time.Minute))}},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "claims with invalid type",
|
||||||
|
fields: fields{},
|
||||||
|
args: args{claims: MapClaims{"exp": "string"}},
|
||||||
|
wantErr: ErrInvalidType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := &validator{
|
||||||
|
leeway: tt.fields.leeway,
|
||||||
|
timeFunc: tt.fields.timeFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := v.verifyExpiresAt(tt.args.claims, tt.args.cmp, tt.args.required)
|
||||||
|
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Errorf("validator.verifyExpiresAt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validator_verifyIssuer(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
expectedIss string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
claims Claims
|
||||||
|
cmp string
|
||||||
|
required bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "good claim",
|
||||||
|
fields: fields{expectedIss: "me"},
|
||||||
|
args: args{claims: MapClaims{"iss": "me"}, cmp: "me"},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "claims with invalid type",
|
||||||
|
fields: fields{expectedIss: "me"},
|
||||||
|
args: args{claims: MapClaims{"iss": 1}, cmp: "me"},
|
||||||
|
wantErr: ErrInvalidType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := &validator{
|
||||||
|
expectedIss: tt.fields.expectedIss,
|
||||||
|
}
|
||||||
|
err := v.verifyIssuer(tt.args.claims, tt.args.cmp, tt.args.required)
|
||||||
|
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Errorf("validator.verifyIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validator_verifySubject(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
expectedSub string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
claims Claims
|
||||||
|
cmp string
|
||||||
|
required bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "good claim",
|
||||||
|
fields: fields{expectedSub: "me"},
|
||||||
|
args: args{claims: MapClaims{"sub": "me"}, cmp: "me"},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "claims with invalid type",
|
||||||
|
fields: fields{expectedSub: "me"},
|
||||||
|
args: args{claims: MapClaims{"sub": 1}, cmp: "me"},
|
||||||
|
wantErr: ErrInvalidType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := &validator{
|
||||||
|
expectedSub: tt.fields.expectedSub,
|
||||||
|
}
|
||||||
|
err := v.verifySubject(tt.args.claims, tt.args.cmp, tt.args.required)
|
||||||
|
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Errorf("validator.verifySubject() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validator_verifyIssuedAt(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
leeway time.Duration
|
||||||
|
timeFunc func() time.Time
|
||||||
|
verifyIat bool
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
claims Claims
|
||||||
|
cmp time.Time
|
||||||
|
required bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "good claim without iat",
|
||||||
|
fields: fields{verifyIat: true},
|
||||||
|
args: args{claims: MapClaims{}, required: false},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "good claim with iat",
|
||||||
|
fields: fields{verifyIat: true},
|
||||||
|
args: args{
|
||||||
|
claims: RegisteredClaims{IssuedAt: NewNumericDate(time.Now())},
|
||||||
|
cmp: time.Now().Add(10 * time.Minute),
|
||||||
|
required: false,
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := &validator{
|
||||||
|
leeway: tt.fields.leeway,
|
||||||
|
timeFunc: tt.fields.timeFunc,
|
||||||
|
verifyIat: tt.fields.verifyIat,
|
||||||
|
}
|
||||||
|
if err := v.verifyIssuedAt(tt.args.claims, tt.args.cmp, tt.args.required); (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Errorf("validator.verifyIssuedAt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue