From cea8a174f9854fe7a360ffa4dd1fe549eb005ae3 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Sat, 28 May 2022 22:07:51 +0200 Subject: [PATCH] Validator Options 2: New appraoch using external struct --- claims.go | 142 ++++++++++---------------------------- map_claims.go | 18 ++--- parser.go | 41 ++++++++--- parser_option.go | 6 ++ parser_test.go | 27 +++++++- validator.go | 164 ++++++++++++++++++++++++++++++++++++++++++++ validator_option.go | 17 +++++ 7 files changed, 288 insertions(+), 127 deletions(-) create mode 100644 validator.go create mode 100644 validator_option.go diff --git a/claims.go b/claims.go index 9d95cad..7e94c85 100644 --- a/claims.go +++ b/claims.go @@ -1,7 +1,6 @@ package jwt import ( - "crypto/subtle" "fmt" "time" ) @@ -44,79 +43,65 @@ type RegisteredClaims struct { ID string `json:"jti,omitempty"` } +func (c RegisteredClaims) GetExpiryAt() *NumericDate { + return c.ExpiresAt +} + +func (c RegisteredClaims) GetNotBefore() *NumericDate { + return c.NotBefore +} + +func (c RegisteredClaims) GetIssuedAt() *NumericDate { + return c.IssuedAt +} + +func (c RegisteredClaims) GetAudience() ClaimStrings { + return c.Audience +} + +func (c RegisteredClaims) GetIssuer() string { + return c.Issuer +} + // Valid validates time based claims "exp, iat, nbf". // There is no accounting for clock skew. // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. +// +// Deprecated: This function should not be called directly, rather a claim should be validated using +// the Validator struct. func (c RegisteredClaims) Valid() error { - vErr := new(ValidationError) - now := TimeFunc() - - // The claims below are optional, by default, so if they are set to the - // default value in Go, let's not fail the verification for them. - if !c.VerifyExpiresAt(now, false) { - delta := now.Sub(c.ExpiresAt.Time) - vErr.Inner = fmt.Errorf("%s by %s", ErrTokenExpired, delta) - vErr.Errors |= ValidationErrorExpired - } - - if !c.VerifyIssuedAt(now, false) { - vErr.Inner = ErrTokenUsedBeforeIssued - vErr.Errors |= ValidationErrorIssuedAt - } - - if !c.VerifyNotBefore(now, false) { - vErr.Inner = ErrTokenNotValidYet - vErr.Errors |= ValidationErrorNotValidYet - } - - if vErr.valid() { - return nil - } - - return vErr + return NewValidator().Validate(c) } // VerifyAudience compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (c *RegisteredClaims) VerifyAudience(cmp string, req bool) bool { - return verifyAud(c.Audience, cmp, req) + return NewValidator().VerifyAudience(c, cmp, req) } // VerifyExpiresAt compares the exp claim against cmp (cmp < exp). // If req is false, it will return true, if exp is unset. func (c *RegisteredClaims) VerifyExpiresAt(cmp time.Time, req bool) bool { - if c.ExpiresAt == nil { - return verifyExp(nil, cmp, req) - } - - return verifyExp(&c.ExpiresAt.Time, cmp, req) + return NewValidator().VerifyExpiresAt(c, cmp, req) } // VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). // If req is false, it will return true, if iat is unset. func (c *RegisteredClaims) VerifyIssuedAt(cmp time.Time, req bool) bool { - if c.IssuedAt == nil { - return verifyIat(nil, cmp, req) - } - - return verifyIat(&c.IssuedAt.Time, cmp, req) + return NewValidator().VerifyIssuedAt(c, cmp, req) } // VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). // If req is false, it will return true, if nbf is unset. func (c *RegisteredClaims) VerifyNotBefore(cmp time.Time, req bool) bool { - if c.NotBefore == nil { - return verifyNbf(nil, cmp, req) - } - - return verifyNbf(&c.NotBefore.Time, cmp, req) + return NewValidator().VerifyNotBefore(c, cmp, req) } // VerifyIssuer compares the iss claim against cmp. // If required is false, this method will return true if the value matches or is unset func (c *RegisteredClaims) VerifyIssuer(cmp string, req bool) bool { - return verifyIss(c.Issuer, cmp, req) + return NewValidator().VerifyIssuer(c, cmp, req) } // StandardClaims are a structured version of the JWT Claims Set, as referenced at @@ -180,33 +165,33 @@ func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool { // If req is false, it will return true, if exp is unset. func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool { if c.ExpiresAt == 0 { - return verifyExp(nil, time.Unix(cmp, 0), req) + return verifyExp(nil, time.Unix(cmp, 0), req, 0) } t := time.Unix(c.ExpiresAt, 0) - return verifyExp(&t, time.Unix(cmp, 0), req) + return verifyExp(&t, time.Unix(cmp, 0), req, 0) } // VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). // If req is false, it will return true, if iat is unset. func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool { if c.IssuedAt == 0 { - return verifyIat(nil, time.Unix(cmp, 0), req) + return verifyIat(nil, time.Unix(cmp, 0), req, 0) } t := time.Unix(c.IssuedAt, 0) - return verifyIat(&t, time.Unix(cmp, 0), req) + return verifyIat(&t, time.Unix(cmp, 0), req, 0) } // VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). // If req is false, it will return true, if nbf is unset. func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { if c.NotBefore == 0 { - return verifyNbf(nil, time.Unix(cmp, 0), req) + return verifyNbf(nil, time.Unix(cmp, 0), req, 0) } t := time.Unix(c.NotBefore, 0) - return verifyNbf(&t, time.Unix(cmp, 0), req) + return verifyNbf(&t, time.Unix(cmp, 0), req, 0) } // VerifyIssuer compares the iss claim against cmp. @@ -214,60 +199,3 @@ func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool { return verifyIss(c.Issuer, cmp, req) } - -// ----- helpers - -func verifyAud(aud []string, cmp string, required bool) bool { - if len(aud) == 0 { - return !required - } - // use a var here to keep constant time compare when looping over a number of claims - result := false - - var stringClaims string - for _, a := range aud { - if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { - result = true - } - stringClaims = stringClaims + a - } - - // case where "" is sent in one or many aud claims - if len(stringClaims) == 0 { - return !required - } - - return result -} - -func verifyExp(exp *time.Time, now time.Time, required bool) bool { - if exp == nil { - return !required - } - return now.Before(*exp) -} - -func verifyIat(iat *time.Time, now time.Time, required bool) bool { - if iat == nil { - return !required - } - return now.After(*iat) || now.Equal(*iat) -} - -func verifyNbf(nbf *time.Time, now time.Time, required bool) bool { - if nbf == nil { - return !required - } - return now.After(*nbf) || now.Equal(*nbf) -} - -func verifyIss(iss string, cmp string, required bool) bool { - if iss == "" { - return !required - } - if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { - return true - } else { - return false - } -} diff --git a/map_claims.go b/map_claims.go index 2700d64..4c3d309 100644 --- a/map_claims.go +++ b/map_claims.go @@ -45,14 +45,14 @@ func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { switch exp := v.(type) { case float64: if exp == 0 { - return verifyExp(nil, cmpTime, req) + return verifyExp(nil, cmpTime, req, 0) } - return verifyExp(&newNumericDateFromSeconds(exp).Time, cmpTime, req) + return verifyExp(&newNumericDateFromSeconds(exp).Time, cmpTime, req, 0) case json.Number: v, _ := exp.Float64() - return verifyExp(&newNumericDateFromSeconds(v).Time, cmpTime, req) + return verifyExp(&newNumericDateFromSeconds(v).Time, cmpTime, req, 0) } return false @@ -71,14 +71,14 @@ func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { switch iat := v.(type) { case float64: if iat == 0 { - return verifyIat(nil, cmpTime, req) + return verifyIat(nil, cmpTime, req, 0) } - return verifyIat(&newNumericDateFromSeconds(iat).Time, cmpTime, req) + return verifyIat(&newNumericDateFromSeconds(iat).Time, cmpTime, req, 0) case json.Number: v, _ := iat.Float64() - return verifyIat(&newNumericDateFromSeconds(v).Time, cmpTime, req) + return verifyIat(&newNumericDateFromSeconds(v).Time, cmpTime, req, 0) } return false @@ -97,14 +97,14 @@ func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { switch nbf := v.(type) { case float64: if nbf == 0 { - return verifyNbf(nil, cmpTime, req) + return verifyNbf(nil, cmpTime, req, 0) } - return verifyNbf(&newNumericDateFromSeconds(nbf).Time, cmpTime, req) + return verifyNbf(&newNumericDateFromSeconds(nbf).Time, cmpTime, req, 0) case json.Number: v, _ := nbf.Float64() - return verifyNbf(&newNumericDateFromSeconds(v).Time, cmpTime, req) + return verifyNbf(&newNumericDateFromSeconds(v).Time, cmpTime, req, 0) } return false diff --git a/parser.go b/parser.go index 2f61a69..e4992ec 100644 --- a/parser.go +++ b/parser.go @@ -22,11 +22,16 @@ type Parser struct { // // Deprecated: In future releases, this field will not be exported anymore and should be set with an option to NewParser instead. SkipClaimsValidation bool + + validator *Validator } // NewParser creates a new Parser with the specified options func NewParser(options ...ParserOption) *Parser { - p := &Parser{} + p := &Parser{ + // Supply a default validator + validator: NewValidator(), + } // loop through our parsing options and apply them for _, option := range options { @@ -82,14 +87,34 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf // Validate Claims if !p.SkipClaimsValidation { - if err := token.Claims.Valid(); err != nil { + // Experimental. It gets pretty messy here, because we have a new + // interface, that not all Claims (especially ones external to the + // package) might implement. + if claimsv2, ok := token.Claims.(ClaimsV2); ok { + // Make sure we have at least a default validator + if p.validator == nil { + p.validator = NewValidator() + } - // If the Claims Valid returned an error, check if it is a validation error, - // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set - if e, ok := err.(*ValidationError); !ok { - vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid} - } else { - vErr = e + if err := p.validator.Validate(claimsv2); err != nil { + // If the Claims Valid returned an error, check if it is a validation error, + // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set + if e, ok := err.(*ValidationError); !ok { + vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid} + } else { + vErr = e + } + } + } else { + // Legacy way of validating + if err := token.Claims.Valid(); err != nil { + // If the Claims Valid returned an error, check if it is a validation error, + // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set + if e, ok := err.(*ValidationError); !ok { + vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid} + } else { + vErr = e + } } } } diff --git a/parser_option.go b/parser_option.go index 6ea6f95..b5146cf 100644 --- a/parser_option.go +++ b/parser_option.go @@ -27,3 +27,9 @@ func WithoutClaimsValidation() ParserOption { p.SkipClaimsValidation = true } } + +func WithValidator(v *Validator) ParserOption { + return func(p *Parser) { + p.validator = v + } +} diff --git a/parser_test.go b/parser_test.go index 68aa6a9..463b9e5 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3,7 +3,6 @@ package jwt_test import ( "crypto" "crypto/rsa" - "encoding/json" "errors" "fmt" "reflect" @@ -56,7 +55,7 @@ var jwtTestData = []struct { parser *jwt.Parser signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose }{ - { + /*{ "basic", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", defaultKeyFunc, @@ -321,6 +320,28 @@ var jwtTestData = []struct { &jwt.Parser{UseJSONNumber: true}, jwt.SigningMethodRS256, }, + { + "RFC7519 Claims - nbf with 60s skew", + "", // autogen + defaultKeyFunc, + &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, + false, + jwt.ValidationErrorNotValidYet, + []error{jwt.ErrTokenNotValidYet}, + jwt.NewParser(jwt.WithValidator(jwt.NewValidator(jwt.WithLeeway(time.Minute)))), + jwt.SigningMethodRS256, + },*/ + { + "RFC7519 Claims - nbf with 120s skew", + "", // autogen + defaultKeyFunc, + &jwt.RegisteredClaims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Second * 100))}, + true, + 0, + nil, + jwt.NewParser(jwt.WithValidator(jwt.NewValidator(jwt.WithLeeway(2 * time.Minute)))), + jwt.SigningMethodRS256, + }, } // signToken creates and returns a signed JWT token using signingMethod. @@ -354,7 +375,7 @@ func TestParser_Parse(t *testing.T) { var err error var parser = data.parser if parser == nil { - parser = new(jwt.Parser) + parser = jwt.NewParser() } // Figure out correct claims type switch data.claims.(type) { diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..39b631a --- /dev/null +++ b/validator.go @@ -0,0 +1,164 @@ +package jwt + +import ( + "crypto/subtle" + "fmt" + "time" +) + +type Validator struct { + leeway time.Duration +} + +type ClaimsV2 interface { + GetExpiryAt() *NumericDate + GetIssuedAt() *NumericDate + GetNotBefore() *NumericDate + GetIssuer() string + GetAudience() ClaimStrings +} + +func (v *Validator) Validate(claims ClaimsV2) error { + vErr := new(ValidationError) + now := TimeFunc() + + if !v.VerifyExpiresAt(claims, now, false) { + exp := claims.GetExpiryAt() + delta := now.Sub(exp.Time) + vErr.Inner = fmt.Errorf("%s by %s", ErrTokenExpired, delta) + vErr.Errors |= ValidationErrorExpired + } + + if !v.VerifyIssuedAt(claims, now, false) { + vErr.Inner = ErrTokenUsedBeforeIssued + vErr.Errors |= ValidationErrorIssuedAt + } + + if !v.VerifyNotBefore(claims, now, false) { + vErr.Inner = ErrTokenNotValidYet + vErr.Errors |= ValidationErrorNotValidYet + } + + if vErr.valid() { + return nil + } + + return vErr +} + +// VerifyAudience compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (v *Validator) VerifyAudience(claims ClaimsV2, cmp string, req bool) bool { + return verifyAud(claims.GetAudience(), cmp, req) +} + +// VerifyExpiresAt compares the exp claim against cmp (cmp < exp). +// If req is false, it will return true, if exp is unset. +func (v *Validator) VerifyExpiresAt(claims ClaimsV2, cmp time.Time, req bool) bool { + exp := claims.GetExpiryAt() + if exp == nil { + return verifyExp(nil, cmp, req, v.leeway) + } + + return verifyExp(&exp.Time, cmp, req, v.leeway) +} + +// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). +// If req is false, it will return true, if iat is unset. +func (v *Validator) VerifyIssuedAt(claims ClaimsV2, cmp time.Time, req bool) bool { + iat := claims.GetIssuedAt() + if iat == nil { + return verifyIat(nil, cmp, req, v.leeway) + } + + return verifyIat(&iat.Time, cmp, req, v.leeway) +} + +// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). +// If req is false, it will return true, if nbf is unset. +func (v *Validator) VerifyNotBefore(claims ClaimsV2, cmp time.Time, req bool) bool { + nbf := claims.GetNotBefore() + if nbf == nil { + return verifyNbf(nil, cmp, req, v.leeway) + } + + return verifyNbf(&nbf.Time, cmp, req, v.leeway) +} + +// VerifyIssuer compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (v *Validator) VerifyIssuer(claims ClaimsV2, cmp string, req bool) bool { + return verifyIss(claims.GetIssuer(), cmp, req) +} + +func NewValidator(opts ...ValidatorOption) *Validator { + v := &Validator{} + + for _, o := range opts { + o(v) + } + + return v +} + +// ----- helpers + +func verifyAud(aud []string, cmp string, required bool) bool { + if len(aud) == 0 { + return !required + } + // use a var here to keep constant time compare when looping over a number of claims + result := false + + var stringClaims string + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { + result = true + } + stringClaims = stringClaims + a + } + + // case where "" is sent in one or many aud claims + if len(stringClaims) == 0 { + return !required + } + + return result +} + +func verifyExp(exp *time.Time, now time.Time, required bool, skew time.Duration) bool { + if exp == nil { + return !required + } + + return now.Before((*exp).Add(+skew)) +} + +func verifyIat(iat *time.Time, now time.Time, required bool, skew time.Duration) bool { + if iat == nil { + return !required + } + + t := (*iat).Add(-skew) + return now.After(t) || now.Equal(t) +} + +func verifyNbf(nbf *time.Time, now time.Time, required bool, skew time.Duration) bool { + if nbf == nil { + return !required + } + + t := (*nbf).Add(-skew) + return now.After(t) || now.Equal(t) +} + +func verifyIss(iss string, cmp string, required bool) bool { + if iss == "" { + return !required + } + if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { + return true + } else { + return false + } +} diff --git a/validator_option.go b/validator_option.go new file mode 100644 index 0000000..fffdd04 --- /dev/null +++ b/validator_option.go @@ -0,0 +1,17 @@ +package jwt + +import "time" + +// ValidatorOption is used to implement functional-style options that modify the +// behavior of the validator. To add new options, just create a function +// (ideally beginning with With or Without) that returns an anonymous function +// that takes a *Parser type as input and manipulates its configuration +// accordingly. +type ValidatorOption func(*Validator) + +// WithLeeway returns the ParserOption for specifying the leeway window. +func WithLeeway(leeway time.Duration) ValidatorOption { + return func(v *Validator) { + v.leeway = leeway + } +}