diff --git a/errors.go b/errors.go index 6acea3f..8e956a8 100644 --- a/errors.go +++ b/errors.go @@ -18,6 +18,7 @@ const ( ValidationErrorSignatureInvalid // Signature validation failed ValidationErrorExpired // Exp validation failed ValidationErrorNotValidYet // NBF validation failed + ValidationErrorIssuedAt // IAT validation failed ValidationErrorClaimsInvalid // Generic claims validation error ) diff --git a/jwt.go b/jwt.go index 19ef53b..1ecc5a2 100644 --- a/jwt.go +++ b/jwt.go @@ -21,26 +21,45 @@ type Keyfunc func(*Token) (interface{}, error) // For a type to be a Claims object, it must just have a Valid method that determines // if the token is invalid for any supported reason -type Claims interface { +type Claimer interface { Valid() error } -type MapClaim map[string]interface{} +// Structured version of Claims Section, as referenced at https://tools.ietf.org/html/rfc7519#section-4.1 +type Claims struct { + Audience string `json:"aud,omitempty"` + ExpiresAt int64 `json:"exp,omitempty"` + Id string `json:"jti,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + Issuer string `json:"iss,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` +} -func (m MapClaim) Valid() error { +func (c Claims) Valid() error { vErr := new(ValidationError) now := TimeFunc().Unix() - if exp, ok := m["exp"].(float64); ok { - if now > int64(exp) { - vErr.err = "token is expired" + // The claims below are optional, so if they are set to the default value in Go, let's not + // verify them. + + if c.ExpiresAt != 0 { + if c.VerifyExpiresAt(now) == false { + vErr.err = "Token is expired" vErr.Errors |= ValidationErrorExpired } } - if nbf, ok := m["nbf"].(float64); ok { - if now < int64(nbf) { - vErr.err = "token is not valid yet" + if c.IssuedAt != 0 { + if c.VerifyIssuedAt(now) == false { + vErr.err = "Token used before issued, clock skew issue?" + vErr.Errors |= ValidationErrorIssuedAt + } + } + + if c.NotBefore != 0 { + if c.VerifyNotBefore(now) == false { + vErr.err = "Token is not valid yet" vErr.Errors |= ValidationErrorNotValidYet } } @@ -52,13 +71,141 @@ func (m MapClaim) Valid() error { return vErr } +func (c *Claims) VerifyAudience(cmp string) bool { + return verifyAud(c.Audience, cmp) +} + +func (c *Claims) VerifyExpiresAt(cmp int64) bool { + return verifyExp(c.ExpiresAt, cmp) +} + +func (c *Claims) VerifyIssuedAt(cmp int64) bool { + return verifyIat(c.IssuedAt, cmp) +} + +func (c *Claims) VerifyIssuer(cmp string) bool { + return verifyIss(c.Issuer, cmp) +} + +func (c *Claims) VerifyNotBefore(cmp int64) bool { + return verifyNbf(c.NotBefore, cmp) +} + +type MapClaim map[string]interface{} + +func (m MapClaim) VerifyAudience(cmp string) bool { + val, exists := m["aud"] + if !exists { + return true // Don't fail validation if claim doesn't exist + } + + if aud, ok := val.(string); ok { + return verifyAud(aud, cmp) + } + return false +} + +func (m MapClaim) VerifyExpiresAt(cmp int64) bool { + val, exists := m["exp"] + if !exists { + return true + } + + if exp, ok := val.(float64); ok { + return verifyExp(int64(exp), cmp) + } + return false +} + +func (m MapClaim) VerifyIssuedAt(cmp int64) bool { + val, exists := m["iat"] + if !exists { + return true + } + + if iat, ok := val.(float64); ok { + return verifyIat(int64(iat), cmp) + } + return false +} + +func (m MapClaim) VerifyIssuer(cmp string) bool { + val, exists := m["iss"] + if !exists { + return true + } + + if iss, ok := val.(string); ok { + return verifyIss(iss, cmp) + } + return false +} + +func (m MapClaim) VerifyNotBefore(cmp int64) bool { + val, exists := m["nbf"] + if !exists { + return true + } + + if nbf, ok := val.(float64); ok { + return verifyNbf(int64(nbf), cmp) + } + return false +} + +func (m MapClaim) Valid() error { + vErr := new(ValidationError) + now := TimeFunc().Unix() + + if m.VerifyExpiresAt(now) == false { + vErr.err = "Token is expired" + vErr.Errors |= ValidationErrorExpired + } + + if m.VerifyIssuedAt(now) == false { + vErr.err = "Token used before issued, clock skew issue?" + vErr.Errors |= ValidationErrorIssuedAt + } + + if m.VerifyNotBefore(now) == false { + vErr.err = "Token is not valid yet" + vErr.Errors |= ValidationErrorNotValidYet + } + + if vErr.valid() { + return nil + } + + return vErr +} + +func verifyAud(aud string, cmp string) bool { + return aud == cmp +} + +func verifyExp(exp int64, now int64) bool { + return now <= exp +} + +func verifyIat(iat int64, now int64) bool { + return now >= iat +} + +func verifyIss(iss string, cmp string) bool { + return iss == cmp +} + +func verifyNbf(nbf int64, now int64) bool { + return now >= nbf +} + // A JWT Token. Different fields will be used depending on whether you're // creating or parsing/verifying a token. type Token struct { Raw string // The raw token. Populated when you Parse a token Method SigningMethod // The signing method used or to be used Header map[string]interface{} // The first segment of the token - Claims Claims // The second segment of the token + Claims Claimer // The second segment of the token Signature string // The third segment of the token. Populated when you Parse a token Valid bool // Is the token valid? Populated when you Parse/Verify a token } @@ -70,12 +217,12 @@ func New(method SigningMethod) *Token { "typ": "JWT", "alg": method.Alg(), }, - Claims: make(MapClaim), + Claims: Claims{}, Method: method, } } -func NewWithClaims(method SigningMethod, claims Claims) *Token { +func NewWithClaims(method SigningMethod, claims Claimer) *Token { return &Token{ Header: map[string]interface{}{ "typ": "JWT", @@ -127,10 +274,10 @@ func (t *Token) SigningString() (string, error) { // keyFunc will receive the parsed token and should return the key for validating. // If everything is kosher, err will be nil func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { - return ParseWithClaims(tokenString, keyFunc, make(MapClaim)) + return ParseWithClaims(tokenString, keyFunc, &Claims{}) } -func ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claims) (*Token, error) { +func ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claimer) (*Token, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { return nil, &ValidationError{err: "token contains an invalid number of segments", Errors: ValidationErrorMalformed} @@ -182,13 +329,13 @@ func ParseWithClaims(tokenString string, keyFunc Keyfunc, claims Claims) (*Token return token, &ValidationError{err: err.Error(), Errors: ValidationErrorUnverifiable} } - // Check expiration times - err = token.Claims.Valid() var vErr *ValidationError - // If the Claims Valid returned an error, check if it is a validation error, - // if not, convert it into one with a generic ClaimsInvalid flag set - if err != nil { + // Validate Claims + if err := token.Claims.Valid(); err != nil { + + // If the Claims Valid returned an error, check if it is a validation error, + // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set if e, ok := err.(*ValidationError); !ok { vErr = &ValidationError{err: err.Error(), Errors: ValidationErrorClaimsInvalid} } else {