diff --git a/claims.go b/claims.go index 4430c84..4ea64a7 100644 --- a/claims.go +++ b/claims.go @@ -7,9 +7,9 @@ package jwt // https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 namely `exp`, // `iat`, `nbf`, `iss` and `aud`. type Claims interface { - GetExpirationTime() *NumericDate - GetIssuedAt() *NumericDate - GetNotBefore() *NumericDate - GetIssuer() string - GetAudience() ClaimStrings + GetExpirationTime() (*NumericDate, error) + GetIssuedAt() (*NumericDate, error) + GetNotBefore() (*NumericDate, error) + GetIssuer() (string, error) + GetAudience() (ClaimStrings, error) } diff --git a/map_claims.go b/map_claims.go index 93d36ba..9e1857f 100644 --- a/map_claims.go +++ b/map_claims.go @@ -2,65 +2,68 @@ package jwt import ( "encoding/json" + "errors" ) // MapClaims is a claims type that uses the map[string]interface{} for JSON decoding. // This is the default claims type if you don't supply one type MapClaims map[string]interface{} +var ErrInvalidType = errors.New("invalid type for claim") + // GetExpirationTime implements the Claims interface. -func (m MapClaims) GetExpirationTime() *NumericDate { +func (m MapClaims) GetExpirationTime() (*NumericDate, error) { return m.ParseNumericDate("exp") } // GetNotBefore implements the Claims interface. -func (m MapClaims) GetNotBefore() *NumericDate { +func (m MapClaims) GetNotBefore() (*NumericDate, error) { return m.ParseNumericDate("nbf") } // GetIssuedAt implements the Claims interface. -func (m MapClaims) GetIssuedAt() *NumericDate { +func (m MapClaims) GetIssuedAt() (*NumericDate, error) { return m.ParseNumericDate("iat") } // GetAudience implements the Claims interface. -func (m MapClaims) GetAudience() ClaimStrings { +func (m MapClaims) GetAudience() (ClaimStrings, error) { return m.ParseClaimsString("aud") } // GetIssuer implements the Claims interface. -func (m MapClaims) GetIssuer() string { +func (m MapClaims) GetIssuer() (string, error) { return m.ParseString("iss") } // ParseNumericDate tries to parse a key in the map claims type as a number // date. This will succeed, if the underlying type is either a [float64] or a // [json.Number]. Otherwise, nil will be returned. -func (m MapClaims) ParseNumericDate(key string) *NumericDate { +func (m MapClaims) ParseNumericDate(key string) (*NumericDate, error) { v, ok := m[key] if !ok { - return nil + return nil, nil } switch exp := v.(type) { case float64: if exp == 0 { - return nil + return nil, nil } - return newNumericDateFromSeconds(exp) + return newNumericDateFromSeconds(exp), nil case json.Number: v, _ := exp.Float64() - return newNumericDateFromSeconds(v) + return newNumericDateFromSeconds(v), nil } - return nil + return nil, ErrInvalidType } // ParseClaimsString tries to parse a key in the map claims type as a // [ClaimsStrings] type, which can either be a string or an array of string. -func (m MapClaims) ParseClaimsString(key string) ClaimStrings { +func (m MapClaims) ParseClaimsString(key string) (ClaimStrings, error) { var cs []string switch v := m[key].(type) { case string: @@ -71,19 +74,33 @@ func (m MapClaims) ParseClaimsString(key string) ClaimStrings { for _, a := range v { vs, ok := a.(string) if !ok { - return nil + return nil, ErrInvalidType } cs = append(cs, vs) } } - return cs + return cs, nil } -// ParseString tries to parse a key in the map claims type as a -// [string] type. Otherwise, an empty string is returned. -func (m MapClaims) ParseString(key string) string { - iss, _ := m[key].(string) +// ParseString tries to parse a key in the map claims type as a [string] type. +// If the key does not exist, an empty string is returned. If the key has the +// wrong type, an error is returned. +func (m MapClaims) ParseString(key string) (string, error) { + var ( + ok bool + raw interface{} + iss string + ) + raw, ok = m[key] + if !ok { + return "", nil + } - return iss + iss, ok = raw.(string) + if !ok { + return "", ErrInvalidType + } + + return iss, nil } diff --git a/map_claims_test.go b/map_claims_test.go index fb1d362..0aba922 100644 --- a/map_claims_test.go +++ b/map_claims_test.go @@ -1,7 +1,10 @@ package jwt -/* -TODO(oxisto): Re-enable tests with validation API +import ( + "testing" + "time" +) + func TestVerifyAud(t *testing.T) { var nilInterface interface{} var nilListInterface []interface{} @@ -39,7 +42,7 @@ func TestVerifyAud(t *testing.T) { {Name: "[]String Aud without match not required", MapClaims: MapClaims{"aud": []string{"not.example.com", "example.example.com"}}, Expected: false, Required: true, Comparison: "example.com"}, // Required = false - {Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: false, Required: true, Comparison: "example.com"}, + {Name: "Empty []String Aud without match required", MapClaims: MapClaims{"aud": []string{""}}, Expected: true, Required: false, Comparison: "example.com"}, // []interface{} {Name: "Empty []interface{} Aud without match required", MapClaims: MapClaims{"aud": nilListInterface}, Expected: true, Required: false, Comparison: "example.com"}, @@ -53,10 +56,17 @@ func TestVerifyAud(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - got := test.MapClaims.VerifyAudience(test.Comparison, test.Required) + var opts []ValidatorOption - if got != test.Expected { - t.Errorf("Expected %v, got %v", test.Expected, got) + if test.Required { + opts = append(opts, WithAudience(test.Comparison)) + } + + validator := NewValidator(opts...) + got := validator.Validate(test.MapClaims) + + if (got == nil) != test.Expected { + t.Errorf("Expected %v, got %v", test.Expected, (got == nil)) } }) } @@ -67,9 +77,9 @@ func TestMapclaimsVerifyIssuedAtInvalidTypeString(t *testing.T) { "iat": "foo", } want := false - got := mapClaims.VerifyIssuedAt(0, false) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got := NewValidator(WithIssuedAt()).Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } } @@ -78,9 +88,9 @@ func TestMapclaimsVerifyNotBeforeInvalidTypeString(t *testing.T) { "nbf": "foo", } want := false - got := mapClaims.VerifyNotBefore(0, false) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got := NewValidator().Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } } @@ -89,33 +99,38 @@ func TestMapclaimsVerifyExpiresAtInvalidTypeString(t *testing.T) { "exp": "foo", } want := false - got := mapClaims.VerifyExpiresAt(0, false) + got := NewValidator().Validate(mapClaims) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } } func TestMapClaimsVerifyExpiresAtExpire(t *testing.T) { - exp := time.Now().Unix() + exp := time.Now() mapClaims := MapClaims{ - "exp": float64(exp), + "exp": float64(exp.Unix()), } want := false - got := mapClaims.VerifyExpiresAt(exp, true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got := NewValidator(WithTimeFunc(func() time.Time { + return exp + })).Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } - got = mapClaims.VerifyExpiresAt(exp+1, true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got = NewValidator(WithTimeFunc(func() time.Time { + return exp.Add(1 * time.Second) + })).Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } want = true - got = mapClaims.VerifyExpiresAt(exp-1, true) - if want != got { - t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + got = NewValidator(WithTimeFunc(func() time.Time { + return exp.Add(-1 * time.Second) + })).Validate(mapClaims) + if want != (got == nil) { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, (got == nil)) } } -*/ diff --git a/registered_claims.go b/registered_claims.go index 0023790..ccdd46a 100644 --- a/registered_claims.go +++ b/registered_claims.go @@ -33,26 +33,26 @@ type RegisteredClaims struct { } // GetExpirationTime implements the Claims interface. -func (c RegisteredClaims) GetExpirationTime() *NumericDate { - return c.ExpiresAt +func (c RegisteredClaims) GetExpirationTime() (*NumericDate, error) { + return c.ExpiresAt, nil } // GetNotBefore implements the Claims interface. -func (c RegisteredClaims) GetNotBefore() *NumericDate { - return c.NotBefore +func (c RegisteredClaims) GetNotBefore() (*NumericDate, error) { + return c.NotBefore, nil } // GetIssuedAt implements the Claims interface. -func (c RegisteredClaims) GetIssuedAt() *NumericDate { - return c.IssuedAt +func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) { + return c.IssuedAt, nil } // GetAudience implements the Claims interface. -func (c RegisteredClaims) GetAudience() ClaimStrings { - return c.Audience +func (c RegisteredClaims) GetAudience() (ClaimStrings, error) { + return c.Audience, nil } // GetIssuer implements the Claims interface. -func (c RegisteredClaims) GetIssuer() string { - return c.Issuer +func (c RegisteredClaims) GetIssuer() (string, error) { + return c.Issuer, nil } diff --git a/validator.go b/validator.go index 2e41133..3fc37ab 100644 --- a/validator.go +++ b/validator.go @@ -2,7 +2,6 @@ package jwt import ( "crypto/subtle" - "fmt" "time" ) @@ -62,9 +61,7 @@ func (v *Validator) Validate(claims Claims) error { } if !v.VerifyExpiresAt(claims, now, false) { - exp := claims.GetExpirationTime() - delta := now.Sub(exp.Time) - vErr.Inner = fmt.Errorf("%s by %s", ErrTokenExpired, delta) + vErr.Inner = ErrTokenExpired vErr.Errors |= ValidationErrorExpired } @@ -79,9 +76,10 @@ func (v *Validator) Validate(claims Claims) error { vErr.Errors |= ValidationErrorNotValidYet } - if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, false) { - vErr.Inner = ErrTokenNotValidYet - vErr.Errors |= ValidationErrorNotValidYet + // If we have an expected audience, we also require the audience claim + if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) { + vErr.Inner = ErrTokenInvalidAudience + vErr.Errors |= ValidationErrorAudience } // Finally, we want to give the claim itself some possibility to do some @@ -104,46 +102,68 @@ func (v *Validator) Validate(claims Claims) error { // VerifyAudience compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (v *Validator) VerifyAudience(claims Claims, cmp string, req bool) bool { - return verifyAud(claims.GetAudience(), cmp, req) + aud, err := claims.GetAudience() + if err != nil { + return false + } + + return verifyAud(aud, cmp, req) } // VerifyExpiresAt compares the exp claim against cmp (cmp < exp). // If req is false, it will return true, if exp is unset. func (v *Validator) VerifyExpiresAt(claims Claims, cmp time.Time, req bool) bool { - exp := claims.GetExpirationTime() - if exp == nil { - return verifyExp(nil, cmp, req, v.leeway) + var time *time.Time = nil + + exp, err := claims.GetExpirationTime() + if err != nil { + return false + } else if exp != nil { + time = &exp.Time } - return verifyExp(&exp.Time, cmp, req, v.leeway) + return verifyExp(time, cmp, req, v.leeway) } // VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). // If req is false, it will return true, if iat is unset. func (v *Validator) VerifyIssuedAt(claims Claims, cmp time.Time, req bool) bool { - iat := claims.GetIssuedAt() - if iat == nil { - return verifyIat(nil, cmp, req, v.leeway) + var time *time.Time = nil + + iat, err := claims.GetIssuedAt() + if err != nil { + return false + } else if iat != nil { + time = &iat.Time } - return verifyIat(&iat.Time, cmp, req, v.leeway) + return verifyIat(time, cmp, req, v.leeway) } // VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). // If req is false, it will return true, if nbf is unset. func (v *Validator) VerifyNotBefore(claims Claims, cmp time.Time, req bool) bool { - nbf := claims.GetNotBefore() - if nbf == nil { - return verifyNbf(nil, cmp, req, v.leeway) + var time *time.Time = nil + + nbf, err := claims.GetNotBefore() + if err != nil { + return false + } else if nbf != nil { + time = &nbf.Time } - return verifyNbf(&nbf.Time, cmp, req, v.leeway) + return verifyNbf(time, cmp, req, v.leeway) } // VerifyIssuer compares the iss claim against cmp. // If required is false, this method will return true if the value matches or is unset func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool { - return verifyIss(claims.GetIssuer(), cmp, req) + iss, err := claims.GetIssuer() + if err != nil { + return false + } + + return verifyIss(iss, cmp, req) } // ----- helpers