diff --git a/jwt.go b/jwt.go index 06995aa..36e58f4 100644 --- a/jwt.go +++ b/jwt.go @@ -19,13 +19,46 @@ var TimeFunc = time.Now // Header of the token (such as `kid`) to identify which key to use. type Keyfunc func(*Token) (interface{}, error) +// For a type to be a Claims object, it must just have a Valid method that determines +// if the token is invalid for any supported reason +type Claims interface { + Valid() error +} + +type MapClaim map[string]interface{} + +func (m MapClaim) Valid() error { + vErr := new(ValidationError) + now := TimeFunc().Unix() + + if exp, ok := m["exp"].(float64); ok { + if now > int64(exp) { + vErr.err = "token is expired" + vErr.Errors |= ValidationErrorExpired + } + } + + if nbf, ok := m["nbf"].(float64); ok { + if now < int64(nbf) { + vErr.err = "token is not valid yet" + vErr.Errors |= ValidationErrorNotValidYet + } + } + + if vErr.valid() { + return nil + } + + return vErr +} + // A JWT Token. Different fields will be used depending on whether you're // creating or parsing/verifying a token. type Token struct { Raw string // The raw token. Populated when you Parse a token Method SigningMethod // The signing method used or to be used Header map[string]interface{} // The first segment of the token - Claims map[string]interface{} // The second segment of the token + Claims Claims // The second segment of the token Signature string // The third segment of the token. Populated when you Parse a token Valid bool // Is the token valid? Populated when you Parse/Verify a token } @@ -37,7 +70,7 @@ func New(method SigningMethod) *Token { "typ": "JWT", "alg": method.Alg(), }, - Claims: make(map[string]interface{}), + Claims: make(MapClaim), Method: method, } } @@ -63,16 +96,15 @@ func (t *Token) SigningString() (string, error) { var err error parts := make([]string, 2) for i, _ := range parts { - var source map[string]interface{} - if i == 0 { - source = t.Header - } else { - source = t.Claims - } - var jsonValue []byte - if jsonValue, err = json.Marshal(source); err != nil { - return "", err + if i == 0 { + if jsonValue, err = json.Marshal(t.Header); err != nil { + return "", err + } + } else { + if jsonValue, err = json.Marshal(t.Claims); err != nil { + return "", err + } } parts[i] = EncodeSegment(jsonValue) @@ -130,20 +162,8 @@ func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { } // Check expiration times - vErr := &ValidationError{} - now := TimeFunc().Unix() - if exp, ok := token.Claims["exp"].(float64); ok { - if now > int64(exp) { - vErr.err = "token is expired" - vErr.Errors |= ValidationErrorExpired - } - } - if nbf, ok := token.Claims["nbf"].(float64); ok { - if now < int64(nbf) { - vErr.err = "token is not valid yet" - vErr.Errors |= ValidationErrorNotValidYet - } - } + err = token.Claims.Valid() + vErr := err.(ValidationError) // Perform validation if err = token.Method.Verify(strings.Join(parts[0:2], "."), parts[2], key); err != nil {