diff --git a/README.md b/README.md index b7de5f4..e9a7e5e 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Here's an example of an extension that integrates with multiple Google Cloud Pla ## Compliance -This library was last reviewed to comply with [RTF 7519](https://datatracker.ietf.org/doc/html/rfc7519) dated May 2015 with a few notable differences: +This library was last reviewed to comply with [RFC 7519](https://datatracker.ietf.org/doc/html/rfc7519) dated May 2015 with a few notable differences: * In order to protect against accidental use of [Unsecured JWTs](https://datatracker.ietf.org/doc/html/rfc7519#section-6), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key. diff --git a/claims.go b/claims.go index 019f007..849c2c6 100644 --- a/claims.go +++ b/claims.go @@ -12,9 +12,116 @@ type Claims interface { Valid() error } -// StandardClaims are a structured version of the Claims Section, as referenced at -// https://tools.ietf.org/html/rfc7519#section-4.1 -// See examples for how to use this with your own claim types +// RegisteredClaims are a structured version of the JWT Claims Set, +// restricted to Registered Claim Names, as referenced at +// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1 +// +// This type can be used on its own, but then additional private and +// public claims embedded in the JWT will not be parsed. The typical usecase +// therefore is to embedded this in a user-defined claim type. +// +// See examples for how to use this with your own claim types. +type RegisteredClaims struct { + // the `iss` (Issuer) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1 + Issuer string `json:"iss,omitempty"` + + // the `sub` (Subject) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.2 + Subject string `json:"sub,omitempty"` + + // the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3 + Audience ClaimStrings `json:"aud,omitempty"` + + // the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4 + ExpiresAt *NumericDate `json:"exp,omitempty"` + + // the `nbf` (Not Before) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.5 + NotBefore *NumericDate `json:"nbf,omitempty"` + + // the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6 + IssuedAt *NumericDate `json:"iat,omitempty"` + + // the `jti` (JWT ID) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.7 + ID string `json:"jti,omitempty"` +} + +// Valid validates time based claims "exp, iat, nbf". +// There is no accounting for clock skew. +// As well, if any of the above claims are not in the token, it will still +// be considered a valid claim. +func (c RegisteredClaims) Valid() error { + vErr := new(ValidationError) + now := TimeFunc() + + // The claims below are optional, by default, so if they are set to the + // default value in Go, let's not fail the verification for them. + if !c.VerifyExpiresAt(now, false) { + delta := now.Sub(c.ExpiresAt.Time) + vErr.Inner = fmt.Errorf("token is expired by %v", delta) + vErr.Errors |= ValidationErrorExpired + } + + if !c.VerifyIssuedAt(now, false) { + vErr.Inner = fmt.Errorf("Token used before issued") + vErr.Errors |= ValidationErrorIssuedAt + } + + if !c.VerifyNotBefore(now, false) { + vErr.Inner = fmt.Errorf("token is not valid yet") + vErr.Errors |= ValidationErrorNotValidYet + } + + if vErr.valid() { + return nil + } + + return vErr +} + +// VerifyAudience compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (c *RegisteredClaims) VerifyAudience(cmp string, req bool) bool { + return verifyAud(c.Audience, cmp, req) +} + +// VerifyExpiresAt compares the exp claim against cmp (cmp <= exp). +// If req is false, it will return true, if exp is unset. +func (c *RegisteredClaims) VerifyExpiresAt(cmp time.Time, req bool) bool { + if c.ExpiresAt == nil { + return verifyExp(nil, cmp, req) + } + + return verifyExp(&c.ExpiresAt.Time, cmp, req) +} + +// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). +// If req is false, it will return true, if iat is unset. +func (c *RegisteredClaims) VerifyIssuedAt(cmp time.Time, req bool) bool { + if c.IssuedAt == nil { + return verifyIat(nil, cmp, req) + } + + return verifyIat(&c.IssuedAt.Time, cmp, req) +} + +// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). +// If req is false, it will return true, if nbf is unset. +func (c *RegisteredClaims) VerifyNotBefore(cmp time.Time, req bool) bool { + if c.NotBefore == nil { + return verifyNbf(nil, cmp, req) + } + + return verifyNbf(&c.NotBefore.Time, cmp, req) +} + +// StandardClaims are a structured version of the JWT Claims Set, as referenced at +// https://datatracker.ietf.org/doc/html/rfc7519#section-4. They do not follow the +// specification exactly, since they were based on an earlier draft of the +// specification and not updated. The main difference is that they only +// support integer-based date fields and singular audiences. This might lead to +// incompatibilities with other JWT implementations. The use of this is discouraged, instead +// the newer RegisteredClaims struct should be used. +// +// Deprecated: Use RegisteredClaims instead for a forward-compatible way to access registered claims in a struct. type StandardClaims struct { Audience string `json:"aud,omitempty"` ExpiresAt int64 `json:"exp,omitempty"` @@ -66,13 +173,34 @@ func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool { // VerifyExpiresAt compares the exp claim against cmp (cmp <= exp). // If req is false, it will return true, if exp is unset. func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool { - return verifyExp(c.ExpiresAt, cmp, req) + if c.ExpiresAt == 0 { + return verifyExp(nil, time.Unix(cmp, 0), req) + } + + t := time.Unix(c.ExpiresAt, 0) + return verifyExp(&t, time.Unix(cmp, 0), req) } // VerifyIssuedAt compares the iat claim against cmp (cmp >= iat). // If req is false, it will return true, if iat is unset. func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool { - return verifyIat(c.IssuedAt, cmp, req) + if c.IssuedAt == 0 { + return verifyIat(nil, time.Unix(cmp, 0), req) + } + + t := time.Unix(c.IssuedAt, 0) + return verifyIat(&t, time.Unix(cmp, 0), req) +} + +// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). +// If req is false, it will return true, if nbf is unset. +func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { + if c.NotBefore == 0 { + return verifyNbf(nil, time.Unix(cmp, 0), req) + } + + t := time.Unix(c.NotBefore, 0) + return verifyNbf(&t, time.Unix(cmp, 0), req) } // VerifyIssuer compares the iss claim against cmp. @@ -81,12 +209,6 @@ func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool { return verifyIss(c.Issuer, cmp, req) } -// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). -// If req is false, it will return true, if nbf is unset. -func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { - return verifyNbf(c.NotBefore, cmp, req) -} - // ----- helpers func verifyAud(aud []string, cmp string, required bool) bool { @@ -112,18 +234,25 @@ func verifyAud(aud []string, cmp string, required bool) bool { return result } -func verifyExp(exp int64, now int64, required bool) bool { - if exp == 0 { +func verifyExp(exp *time.Time, now time.Time, required bool) bool { + if exp == nil { return !required } - return now <= exp + return now.Before(*exp) || now.Equal(*exp) } -func verifyIat(iat int64, now int64, required bool) bool { - if iat == 0 { +func verifyIat(iat *time.Time, now time.Time, required bool) bool { + if iat == nil { return !required } - return now >= iat + return now.After(*iat) || now.Equal(*iat) +} + +func verifyNbf(nbf *time.Time, now time.Time, required bool) bool { + if nbf == nil { + return !required + } + return now.After(*nbf) || now.Equal(*nbf) } func verifyIss(iss string, cmp string, required bool) bool { @@ -136,10 +265,3 @@ func verifyIss(iss string, cmp string, required bool) bool { return false } } - -func verifyNbf(nbf int64, now int64, required bool) bool { - if nbf == 0 { - return !required - } - return now >= nbf -} diff --git a/example_test.go b/example_test.go index aae1c55..7815757 100644 --- a/example_test.go +++ b/example_test.go @@ -7,41 +7,57 @@ import ( "github.com/golang-jwt/jwt/v4" ) -// Example (atypical) using the StandardClaims type by itself to parse a token. -// The StandardClaims type is designed to be embedded into your custom types +// Example (atypical) using the RegisteredClaims type by itself to parse a token. +// The RegisteredClaims type is designed to be embedded into your custom types // to provide standard validation features. You can use it alone, but there's // no way to retrieve other fields after parsing. // See the CustomClaimsType example for intended usage. -func ExampleNewWithClaims_standardClaims() { +func ExampleNewWithClaims_registeredClaims() { mySigningKey := []byte("AllYourBase") // Create the Claims - claims := &jwt.StandardClaims{ - ExpiresAt: 15000, + claims := &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Unix(1516239022, 0)), Issuer: "test", } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) ss, err := token.SignedString(mySigningKey) fmt.Printf("%v %v", ss, err) - //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.QsODzZu3lUZMVdhbO76u3Jv02iYCvEHcYVUI1kOWEU0 + //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 } -// Example creating a token using a custom claims type. The StandardClaim is embedded -// in the custom type to allow for easy encoding, parsing and validation of standard claims. +// Example creating a token using a custom claims type. The RegisteredClaims is embedded +// in the custom type to allow for easy encoding, parsing and validation of registered claims. func ExampleNewWithClaims_customClaimsType() { mySigningKey := []byte("AllYourBase") type MyCustomClaims struct { Foo string `json:"foo"` - jwt.StandardClaims + jwt.RegisteredClaims } - // Create the Claims + // Create the claims claims := MyCustomClaims{ "bar", - jwt.StandardClaims{ - ExpiresAt: 15000, + jwt.RegisteredClaims{ + // A usual scenario is to set the expiration time relative to the current time + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "test", + Subject: "somebody", + ID: "1", + Audience: []string{"somebody_else"}, + }, + } + + // Create claims while leaving out some of the optional fields + claims = MyCustomClaims{ + "bar", + jwt.RegisteredClaims{ + // Also fixed dates can be used for the NumericDate + ExpiresAt: jwt.NewNumericDate(time.Unix(1516239022, 0)), Issuer: "test", }, } @@ -49,42 +65,31 @@ func ExampleNewWithClaims_customClaimsType() { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) ss, err := token.SignedString(mySigningKey) fmt.Printf("%v %v", ss, err) - //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c + + //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM } // Example creating a token using a custom claims type. The StandardClaim is embedded // in the custom type to allow for easy encoding, parsing and validation of standard claims. func ExampleParseWithClaims_customClaimsType() { - tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" + tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" type MyCustomClaims struct { Foo string `json:"foo"` - jwt.StandardClaims + jwt.RegisteredClaims } - // sample token is expired. override time so it parses as valid - at(time.Unix(0, 0), func() { - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }) - - if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { - fmt.Printf("%v %v", claims.Foo, claims.StandardClaims.ExpiresAt) - } else { - fmt.Println(err) - } + token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte("AllYourBase"), nil }) - // Output: bar 15000 -} - -// Override time value for tests. Restore default value after. -func at(t time.Time, f func()) { - jwt.TimeFunc = func() time.Time { - return t + if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { + fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) + } else { + fmt.Println(err) } - f() - jwt.TimeFunc = time.Now + + // Output: bar test } // An example of parsing the error types using bitfield checks diff --git a/http_example_test.go b/http_example_test.go index 71a4cc8..c43ab5f 100644 --- a/http_example_test.go +++ b/http_example_test.go @@ -73,7 +73,7 @@ type CustomerInfo struct { } type CustomClaimsExample struct { - *jwt.StandardClaims + *jwt.RegisteredClaims TokenType string CustomerInfo } @@ -142,10 +142,10 @@ func createToken(user string) (string, error) { // set our claims t.Claims = &CustomClaimsExample{ - &jwt.StandardClaims{ + &jwt.RegisteredClaims{ // set the expire time - // see http://tools.ietf.org/html/draft-ietf-oauth-json-web-token-20#section-4.1.4 - ExpiresAt: time.Now().Add(time.Minute * 1).Unix(), + // see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4 + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)), }, "level1", CustomerInfo{user, "human"}, diff --git a/map_claims.go b/map_claims.go index c058d41..e7da633 100644 --- a/map_claims.go +++ b/map_claims.go @@ -3,6 +3,7 @@ package jwt import ( "encoding/json" "errors" + "time" // "fmt" ) @@ -34,34 +35,78 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool { // VerifyExpiresAt compares the exp claim against cmp (cmp <= exp). // If req is false, it will return true, if exp is unset. func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { - exp, ok := m["exp"] + cmpTime := time.Unix(cmp, 0) + + v, ok := m["exp"] if !ok { return !req } - switch expType := exp.(type) { + + switch exp := v.(type) { case float64: - return verifyExp(int64(expType), cmp, req) + if exp == 0 { + return verifyExp(nil, cmpTime, req) + } + + return verifyExp(&newNumericDateFromSeconds(exp).Time, cmpTime, req) case json.Number: - v, _ := expType.Int64() - return verifyExp(v, cmp, req) + v, _ := exp.Float64() + + return verifyExp(&newNumericDateFromSeconds(v).Time, cmpTime, req) } + return false } // VerifyIssuedAt compares the exp claim against cmp (cmp >= iat). // If req is false, it will return true, if iat is unset. func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { - iat, ok := m["iat"] + cmpTime := time.Unix(cmp, 0) + + v, ok := m["iat"] if !ok { return !req } - switch iatType := iat.(type) { + + switch iat := v.(type) { case float64: - return verifyIat(int64(iatType), cmp, req) + if iat == 0 { + return verifyIat(nil, cmpTime, req) + } + + return verifyIat(&newNumericDateFromSeconds(iat).Time, cmpTime, req) case json.Number: - v, _ := iatType.Int64() - return verifyIat(v, cmp, req) + v, _ := iat.Float64() + + return verifyIat(&newNumericDateFromSeconds(v).Time, cmpTime, req) } + + return false +} + +// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). +// If req is false, it will return true, if nbf is unset. +func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { + cmpTime := time.Unix(cmp, 0) + + v, ok := m["nbf"] + if !ok { + return !req + } + + switch nbf := v.(type) { + case float64: + if nbf == 0 { + return verifyNbf(nil, cmpTime, req) + } + + return verifyNbf(&newNumericDateFromSeconds(nbf).Time, cmpTime, req) + case json.Number: + v, _ := nbf.Float64() + + return verifyNbf(&newNumericDateFromSeconds(v).Time, cmpTime, req) + } + return false } @@ -72,24 +117,7 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { return verifyIss(iss, cmp, req) } -// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). -// If req is false, it will return true, if nbf is unset. -func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { - nbf, ok := m["nbf"] - if !ok { - return !req - } - switch nbfType := nbf.(type) { - case float64: - return verifyNbf(int64(nbfType), cmp, req) - case json.Number: - v, _ := nbfType.Int64() - return verifyNbf(v, cmp, req) - } - return false -} - -// Valid calidates time based claims "exp, iat, nbf". +// Valid validates time based claims "exp, iat, nbf". // There is no accounting for clock skew. // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. diff --git a/parser_test.go b/parser_test.go index 2de337a..d997a0e 100644 --- a/parser_test.go +++ b/parser_test.go @@ -181,6 +181,61 @@ var jwtTestData = []struct { 0, &jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true}, }, + { + "RFC7519 Claims", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)), + }, + true, + 0, + &jwt.Parser{UseJSONNumber: true}, + }, + { + "RFC7519 Claims - single aud", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{"test"}, + }, + true, + 0, + &jwt.Parser{UseJSONNumber: true}, + }, + { + "RFC7519 Claims - multiple aud", + "", + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{"test", "test"}, + }, + true, + 0, + &jwt.Parser{UseJSONNumber: true}, + }, + { + "RFC7519 Claims - single aud with wrong type", + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 } + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: nil, // because of the unmarshal error, this will be empty + }, + false, + jwt.ValidationErrorMalformed, + &jwt.Parser{UseJSONNumber: true}, + }, + { + "RFC7519 Claims - multiple aud with wrong types", + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] } + defaultKeyFunc, + &jwt.RegisteredClaims{ + Audience: nil, // because of the unmarshal error, this will be empty + }, + false, + jwt.ValidationErrorMalformed, + &jwt.Parser{UseJSONNumber: true}, + }, } func TestParser_Parse(t *testing.T) { @@ -188,62 +243,66 @@ func TestParser_Parse(t *testing.T) { // Iterate over test data set and run tests for _, data := range jwtTestData { - // If the token string is blank, use helper function to generate string - if data.tokenString == "" { - data.tokenString = test.MakeSampleToken(data.claims, privateKey) - } + t.Run(data.name, func(t *testing.T) { + // If the token string is blank, use helper function to generate string + if data.tokenString == "" { + data.tokenString = test.MakeSampleToken(data.claims, privateKey) + } - // Parse the token - var token *jwt.Token - var err error - var parser = data.parser - if parser == nil { - parser = new(jwt.Parser) - } - // Figure out correct claims type - switch data.claims.(type) { - case jwt.MapClaims: - token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc) - case *jwt.StandardClaims: - token, err = parser.ParseWithClaims(data.tokenString, &jwt.StandardClaims{}, data.keyfunc) - } + // Parse the token + var token *jwt.Token + var err error + var parser = data.parser + if parser == nil { + parser = new(jwt.Parser) + } + // Figure out correct claims type + switch data.claims.(type) { + case jwt.MapClaims: + token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc) + case *jwt.StandardClaims: + token, err = parser.ParseWithClaims(data.tokenString, &jwt.StandardClaims{}, data.keyfunc) + case *jwt.RegisteredClaims: + token, err = parser.ParseWithClaims(data.tokenString, &jwt.RegisteredClaims{}, data.keyfunc) + } - // Verify result matches expectation - if !reflect.DeepEqual(data.claims, token.Claims) { - t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) - } + // Verify result matches expectation + if !reflect.DeepEqual(data.claims, token.Claims) { + t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) + } - if data.valid && err != nil { - t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err) - } + if data.valid && err != nil { + t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err) + } - if !data.valid && err == nil { - t.Errorf("[%v] Invalid token passed validation", data.name) - } + if !data.valid && err == nil { + t.Errorf("[%v] Invalid token passed validation", data.name) + } - if (err == nil && !token.Valid) || (err != nil && token.Valid) { - t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) - } + if (err == nil && !token.Valid) || (err != nil && token.Valid) { + t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) + } - if data.errors != 0 { - if err == nil { - t.Errorf("[%v] Expecting error. Didn't get one.", data.name) - } else { + if data.errors != 0 { + if err == nil { + t.Errorf("[%v] Expecting error. Didn't get one.", data.name) + } else { - ve := err.(*jwt.ValidationError) - // compare the bitfield part of the error - if e := ve.Errors; e != data.errors { - t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) - } + ve := err.(*jwt.ValidationError) + // compare the bitfield part of the error + if e := ve.Errors; e != data.errors { + t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) + } - if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError { - t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError) + if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError { + t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError) + } } } - } - if data.valid && token.Signature == "" { - t.Errorf("[%v] Signature is left unpopulated after parsing", data.name) - } + if data.valid && token.Signature == "" { + t.Errorf("[%v] Signature is left unpopulated after parsing", data.name) + } + }) } } @@ -252,38 +311,47 @@ func TestParser_ParseUnverified(t *testing.T) { // Iterate over test data set and run tests for _, data := range jwtTestData { - // If the token string is blank, use helper function to generate string - if data.tokenString == "" { - data.tokenString = test.MakeSampleToken(data.claims, privateKey) + // Skip test data, that intentionally contains malformed tokens, as they would lead to an error + if data.errors&jwt.ValidationErrorMalformed != 0 { + continue } - // Parse the token - var token *jwt.Token - var err error - var parser = data.parser - if parser == nil { - parser = new(jwt.Parser) - } - // Figure out correct claims type - switch data.claims.(type) { - case jwt.MapClaims: - token, _, err = parser.ParseUnverified(data.tokenString, jwt.MapClaims{}) - case *jwt.StandardClaims: - token, _, err = parser.ParseUnverified(data.tokenString, &jwt.StandardClaims{}) - } + t.Run(data.name, func(t *testing.T) { + // If the token string is blank, use helper function to generate string + if data.tokenString == "" { + data.tokenString = test.MakeSampleToken(data.claims, privateKey) + } - if err != nil { - t.Errorf("[%v] Invalid token", data.name) - } + // Parse the token + var token *jwt.Token + var err error + var parser = data.parser + if parser == nil { + parser = new(jwt.Parser) + } + // Figure out correct claims type + switch data.claims.(type) { + case jwt.MapClaims: + token, _, err = parser.ParseUnverified(data.tokenString, jwt.MapClaims{}) + case *jwt.StandardClaims: + token, _, err = parser.ParseUnverified(data.tokenString, &jwt.StandardClaims{}) + case *jwt.RegisteredClaims: + token, _, err = parser.ParseUnverified(data.tokenString, &jwt.RegisteredClaims{}) + } - // Verify result matches expectation - if !reflect.DeepEqual(data.claims, token.Claims) { - t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) - } + if err != nil { + t.Errorf("[%v] Invalid token", data.name) + } - if data.valid && err != nil { - t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err) - } + // Verify result matches expectation + if !reflect.DeepEqual(data.claims, token.Claims) { + t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) + } + + if data.valid && err != nil { + t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err) + } + }) } } diff --git a/rsa_pss_test.go b/rsa_pss_test.go index 716500e..5b895da 100644 --- a/rsa_pss_test.go +++ b/rsa_pss_test.go @@ -131,9 +131,9 @@ func TestRSAPSSSaltLengthCompatibility(t *testing.T) { } func makeToken(method jwt.SigningMethod) string { - token := jwt.NewWithClaims(method, jwt.StandardClaims{ + token := jwt.NewWithClaims(method, jwt.RegisteredClaims{ Issuer: "example", - IssuedAt: time.Now().Unix(), + IssuedAt: jwt.NewNumericDate(time.Now()), }) privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key") signed, err := token.SignedString(privateKey) diff --git a/types.go b/types.go new file mode 100644 index 0000000..15c39a3 --- /dev/null +++ b/types.go @@ -0,0 +1,125 @@ +package jwt + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "time" +) + +// TimePrecision sets the precision of times and dates within this library. +// This has an influence on the precision of times when comparing expiry or +// other related time fields. Furthermore, it is also the precision of times +// when serializing. +// +// For backwards compatibility the default precision is set to seconds, so that +// no fractional timestamps are generated. +var TimePrecision = time.Second + +// MarshalSingleStringAsArray modifies the behaviour of the ClaimStrings type, especially +// its MarshalJSON function. +// +// If it is set to true (the default), it will always serialize the type as an +// array of strings, even if it just contains one element, defaulting to the behaviour +// of the underlying []string. If it is set to false, it will serialize to a single +// string, if it contains one element. Otherwise, it will serialize to an array of strings. +var MarshalSingleStringAsArray = true + +// NumericDate represents a JSON numeric date value, as referenced at +// https://datatracker.ietf.org/doc/html/rfc7519#section-2. +type NumericDate struct { + time.Time +} + +// NewNumericDate constructs a new *NumericDate from a standard library time.Time struct. +// It will truncate the timestamp according to the precision specified in TimePrecision. +func NewNumericDate(t time.Time) *NumericDate { + return &NumericDate{t.Truncate(TimePrecision)} +} + +// newNumericDateFromSeconds creates a new *NumericDate out of a float64 representing a +// UNIX epoch with the float fraction representing non-integer seconds. +func newNumericDateFromSeconds(f float64) *NumericDate { + return NewNumericDate(time.Unix(0, int64(f*float64(time.Second)))) +} + +// MarshalJSON is an implementation of the json.RawMessage interface and serializes the UNIX epoch +// represented in NumericDate to a byte array, using the precision specified in TimePrecision. +func (date NumericDate) MarshalJSON() (b []byte, err error) { + f := float64(date.Truncate(TimePrecision).UnixNano()) / float64(time.Second) + + return []byte(strconv.FormatFloat(f, 'f', -1, 64)), nil +} + +// UnmarshalJSON is an implementation of the json.RawMessage interface and deserializses a +// NumericDate from a JSON representation, i.e. a json.Number. This number represents an UNIX epoch +// with either integer or non-integer seconds. +func (date *NumericDate) UnmarshalJSON(b []byte) (err error) { + var ( + number json.Number + f float64 + ) + + if err = json.Unmarshal(b, &number); err != nil { + return fmt.Errorf("could not parse NumericData: %w", err) + } + + if f, err = number.Float64(); err != nil { + return fmt.Errorf("could not convert json number value to float: %w", err) + } + + n := newNumericDateFromSeconds(f) + *date = *n + + return nil +} + +// ClaimStrings is basically just a slice of strings, but it can be either serialized from a string array or just a string. +// This type is necessary, since the "aud" claim can either be a single string or an array. +type ClaimStrings []string + +func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { + var value interface{} + + if err = json.Unmarshal(data, &value); err != nil { + return err + } + + var aud []string + + switch v := value.(type) { + case string: + aud = append(aud, v) + case []string: + aud = ClaimStrings(v) + case []interface{}: + for _, vv := range v { + vs, ok := vv.(string) + if !ok { + return &json.UnsupportedTypeError{Type: reflect.TypeOf(vv)} + } + aud = append(aud, vs) + } + case nil: + return nil + default: + return &json.UnsupportedTypeError{Type: reflect.TypeOf(v)} + } + + *s = aud + + return +} + +func (s ClaimStrings) MarshalJSON() (b []byte, err error) { + // This handles a special case in the JWT RFC. If the string array, e.g. used by the "aud" field, + // only contains one element, it MAY be serialized as a single string. This may or may not be + // desired based on the ecosystem of other JWT library used, so we make it configurable by the + // variable MarshalSingleStringAsArray. + if len(s) == 1 && !MarshalSingleStringAsArray { + return json.Marshal(s[0]) + } + + return json.Marshal([]string(s)) +} diff --git a/types_test.go b/types_test.go new file mode 100644 index 0000000..675f953 --- /dev/null +++ b/types_test.go @@ -0,0 +1,67 @@ +package jwt_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +func TestNumericDate(t *testing.T) { + var s struct { + Iat jwt.NumericDate `json:"iat"` + Exp jwt.NumericDate `json:"exp"` + } + + oldPrecision := jwt.TimePrecision + + jwt.TimePrecision = time.Microsecond + + raw := `{"iat":1516239022,"exp":1516239022.12345}` + + err := json.Unmarshal([]byte(raw), &s) + + if err != nil { + t.Errorf("Unexpected error: %s", err) + } + + b, _ := json.Marshal(s) + + if raw != string(b) { + t.Errorf("Serialized format of numeric date mismatch. Expecting: %s Got: %s", string(raw), string(b)) + } + + jwt.TimePrecision = oldPrecision +} + +func TestSingleArrayMarshal(t *testing.T) { + jwt.MarshalSingleStringAsArray = false + + s := jwt.ClaimStrings{"test"} + expected := `"test"` + + b, err := json.Marshal(s) + + if err != nil { + t.Errorf("Unexpected error: %s", err) + } + + if expected != string(b) { + t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", string(expected), string(b)) + } + + jwt.MarshalSingleStringAsArray = true + + expected = `["test"]` + + b, err = json.Marshal(s) + + if err != nil { + t.Errorf("Unexpected error: %s", err) + } + + if expected != string(b) { + t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", string(expected), string(b)) + } +}