Backwards-compatible implementation of RFC7519's registered claim's structure (#15)

This PR aims at implementing compliance to RFC7519, as documented in #11 without breaking the public API. It creates a new struct `RegisteredClaims` and deprecates (but not removes) the `StandardClaims`. It introduces a new type `NumericDate`, which represents a JSON numeric date value as specified in the RFC. This allows us to handle float as well as int-based time fields in `aud`, `exp` and `nbf`. Additionally, it introduces the type `StringArray`, which is basically a wrapper around `[]string` to deal with the oddities of the JWT `aud` field.
This commit is contained in:
Christian Banse 2021-08-22 19:23:13 +02:00 committed by GitHub
parent c9ab96ba53
commit 80625fb516
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 582 additions and 167 deletions

View File

@ -52,7 +52,7 @@ Here's an example of an extension that integrates with multiple Google Cloud Pla
## Compliance
This library was last reviewed to comply with [RTF 7519](https://datatracker.ietf.org/doc/html/rfc7519) dated May 2015 with a few notable differences:
This library was last reviewed to comply with [RFC 7519](https://datatracker.ietf.org/doc/html/rfc7519) dated May 2015 with a few notable differences:
* In order to protect against accidental use of [Unsecured JWTs](https://datatracker.ietf.org/doc/html/rfc7519#section-6), tokens using `alg=none` will only be accepted if the constant `jwt.UnsafeAllowNoneSignatureType` is provided as the key.

170
claims.go
View File

@ -12,9 +12,116 @@ type Claims interface {
Valid() error
}
// StandardClaims are a structured version of the Claims Section, as referenced at
// https://tools.ietf.org/html/rfc7519#section-4.1
// See examples for how to use this with your own claim types
// RegisteredClaims are a structured version of the JWT Claims Set,
// restricted to Registered Claim Names, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
//
// This type can be used on its own, but then additional private and
// public claims embedded in the JWT will not be parsed. The typical usecase
// therefore is to embedded this in a user-defined claim type.
//
// See examples for how to use this with your own claim types.
type RegisteredClaims struct {
// the `iss` (Issuer) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1
Issuer string `json:"iss,omitempty"`
// the `sub` (Subject) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.2
Subject string `json:"sub,omitempty"`
// the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3
Audience ClaimStrings `json:"aud,omitempty"`
// the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt *NumericDate `json:"exp,omitempty"`
// the `nbf` (Not Before) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.5
NotBefore *NumericDate `json:"nbf,omitempty"`
// the `iat` (Issued At) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.6
IssuedAt *NumericDate `json:"iat,omitempty"`
// the `jti` (JWT ID) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.7
ID string `json:"jti,omitempty"`
}
// Valid validates time based claims "exp, iat, nbf".
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.
func (c RegisteredClaims) Valid() error {
vErr := new(ValidationError)
now := TimeFunc()
// The claims below are optional, by default, so if they are set to the
// default value in Go, let's not fail the verification for them.
if !c.VerifyExpiresAt(now, false) {
delta := now.Sub(c.ExpiresAt.Time)
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
vErr.Errors |= ValidationErrorExpired
}
if !c.VerifyIssuedAt(now, false) {
vErr.Inner = fmt.Errorf("Token used before issued")
vErr.Errors |= ValidationErrorIssuedAt
}
if !c.VerifyNotBefore(now, false) {
vErr.Inner = fmt.Errorf("token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet
}
if vErr.valid() {
return nil
}
return vErr
}
// VerifyAudience compares the aud claim against cmp.
// If required is false, this method will return true if the value matches or is unset
func (c *RegisteredClaims) VerifyAudience(cmp string, req bool) bool {
return verifyAud(c.Audience, cmp, req)
}
// VerifyExpiresAt compares the exp claim against cmp (cmp <= exp).
// If req is false, it will return true, if exp is unset.
func (c *RegisteredClaims) VerifyExpiresAt(cmp time.Time, req bool) bool {
if c.ExpiresAt == nil {
return verifyExp(nil, cmp, req)
}
return verifyExp(&c.ExpiresAt.Time, cmp, req)
}
// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
// If req is false, it will return true, if iat is unset.
func (c *RegisteredClaims) VerifyIssuedAt(cmp time.Time, req bool) bool {
if c.IssuedAt == nil {
return verifyIat(nil, cmp, req)
}
return verifyIat(&c.IssuedAt.Time, cmp, req)
}
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (c *RegisteredClaims) VerifyNotBefore(cmp time.Time, req bool) bool {
if c.NotBefore == nil {
return verifyNbf(nil, cmp, req)
}
return verifyNbf(&c.NotBefore.Time, cmp, req)
}
// StandardClaims are a structured version of the JWT Claims Set, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-4. They do not follow the
// specification exactly, since they were based on an earlier draft of the
// specification and not updated. The main difference is that they only
// support integer-based date fields and singular audiences. This might lead to
// incompatibilities with other JWT implementations. The use of this is discouraged, instead
// the newer RegisteredClaims struct should be used.
//
// Deprecated: Use RegisteredClaims instead for a forward-compatible way to access registered claims in a struct.
type StandardClaims struct {
Audience string `json:"aud,omitempty"`
ExpiresAt int64 `json:"exp,omitempty"`
@ -66,13 +173,34 @@ func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool {
// VerifyExpiresAt compares the exp claim against cmp (cmp <= exp).
// If req is false, it will return true, if exp is unset.
func (c *StandardClaims) VerifyExpiresAt(cmp int64, req bool) bool {
return verifyExp(c.ExpiresAt, cmp, req)
if c.ExpiresAt == 0 {
return verifyExp(nil, time.Unix(cmp, 0), req)
}
t := time.Unix(c.ExpiresAt, 0)
return verifyExp(&t, time.Unix(cmp, 0), req)
}
// VerifyIssuedAt compares the iat claim against cmp (cmp >= iat).
// If req is false, it will return true, if iat is unset.
func (c *StandardClaims) VerifyIssuedAt(cmp int64, req bool) bool {
return verifyIat(c.IssuedAt, cmp, req)
if c.IssuedAt == 0 {
return verifyIat(nil, time.Unix(cmp, 0), req)
}
t := time.Unix(c.IssuedAt, 0)
return verifyIat(&t, time.Unix(cmp, 0), req)
}
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool {
if c.NotBefore == 0 {
return verifyNbf(nil, time.Unix(cmp, 0), req)
}
t := time.Unix(c.NotBefore, 0)
return verifyNbf(&t, time.Unix(cmp, 0), req)
}
// VerifyIssuer compares the iss claim against cmp.
@ -81,12 +209,6 @@ func (c *StandardClaims) VerifyIssuer(cmp string, req bool) bool {
return verifyIss(c.Issuer, cmp, req)
}
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool {
return verifyNbf(c.NotBefore, cmp, req)
}
// ----- helpers
func verifyAud(aud []string, cmp string, required bool) bool {
@ -112,18 +234,25 @@ func verifyAud(aud []string, cmp string, required bool) bool {
return result
}
func verifyExp(exp int64, now int64, required bool) bool {
if exp == 0 {
func verifyExp(exp *time.Time, now time.Time, required bool) bool {
if exp == nil {
return !required
}
return now <= exp
return now.Before(*exp) || now.Equal(*exp)
}
func verifyIat(iat int64, now int64, required bool) bool {
if iat == 0 {
func verifyIat(iat *time.Time, now time.Time, required bool) bool {
if iat == nil {
return !required
}
return now >= iat
return now.After(*iat) || now.Equal(*iat)
}
func verifyNbf(nbf *time.Time, now time.Time, required bool) bool {
if nbf == nil {
return !required
}
return now.After(*nbf) || now.Equal(*nbf)
}
func verifyIss(iss string, cmp string, required bool) bool {
@ -136,10 +265,3 @@ func verifyIss(iss string, cmp string, required bool) bool {
return false
}
}
func verifyNbf(nbf int64, now int64, required bool) bool {
if nbf == 0 {
return !required
}
return now >= nbf
}

View File

@ -7,41 +7,57 @@ import (
"github.com/golang-jwt/jwt/v4"
)
// Example (atypical) using the StandardClaims type by itself to parse a token.
// The StandardClaims type is designed to be embedded into your custom types
// Example (atypical) using the RegisteredClaims type by itself to parse a token.
// The RegisteredClaims type is designed to be embedded into your custom types
// to provide standard validation features. You can use it alone, but there's
// no way to retrieve other fields after parsing.
// See the CustomClaimsType example for intended usage.
func ExampleNewWithClaims_standardClaims() {
func ExampleNewWithClaims_registeredClaims() {
mySigningKey := []byte("AllYourBase")
// Create the Claims
claims := &jwt.StandardClaims{
ExpiresAt: 15000,
claims := &jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Unix(1516239022, 0)),
Issuer: "test",
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, err := token.SignedString(mySigningKey)
fmt.Printf("%v %v", ss, err)
//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.QsODzZu3lUZMVdhbO76u3Jv02iYCvEHcYVUI1kOWEU0 <nil>
//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 <nil>
}
// Example creating a token using a custom claims type. The StandardClaim is embedded
// in the custom type to allow for easy encoding, parsing and validation of standard claims.
// Example creating a token using a custom claims type. The RegisteredClaims is embedded
// in the custom type to allow for easy encoding, parsing and validation of registered claims.
func ExampleNewWithClaims_customClaimsType() {
mySigningKey := []byte("AllYourBase")
type MyCustomClaims struct {
Foo string `json:"foo"`
jwt.StandardClaims
jwt.RegisteredClaims
}
// Create the Claims
// Create the claims
claims := MyCustomClaims{
"bar",
jwt.StandardClaims{
ExpiresAt: 15000,
jwt.RegisteredClaims{
// A usual scenario is to set the expiration time relative to the current time
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "test",
Subject: "somebody",
ID: "1",
Audience: []string{"somebody_else"},
},
}
// Create claims while leaving out some of the optional fields
claims = MyCustomClaims{
"bar",
jwt.RegisteredClaims{
// Also fixed dates can be used for the NumericDate
ExpiresAt: jwt.NewNumericDate(time.Unix(1516239022, 0)),
Issuer: "test",
},
}
@ -49,42 +65,31 @@ func ExampleNewWithClaims_customClaimsType() {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, err := token.SignedString(mySigningKey)
fmt.Printf("%v %v", ss, err)
//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c <nil>
//Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM <nil>
}
// Example creating a token using a custom claims type. The StandardClaim is embedded
// in the custom type to allow for easy encoding, parsing and validation of standard claims.
func ExampleParseWithClaims_customClaimsType() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c"
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"
type MyCustomClaims struct {
Foo string `json:"foo"`
jwt.StandardClaims
jwt.RegisteredClaims
}
// sample token is expired. override time so it parses as valid
at(time.Unix(0, 0), func() {
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
})
if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid {
fmt.Printf("%v %v", claims.Foo, claims.StandardClaims.ExpiresAt)
fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer)
} else {
fmt.Println(err)
}
})
// Output: bar 15000
}
// Override time value for tests. Restore default value after.
func at(t time.Time, f func()) {
jwt.TimeFunc = func() time.Time {
return t
}
f()
jwt.TimeFunc = time.Now
// Output: bar test
}
// An example of parsing the error types using bitfield checks

View File

@ -73,7 +73,7 @@ type CustomerInfo struct {
}
type CustomClaimsExample struct {
*jwt.StandardClaims
*jwt.RegisteredClaims
TokenType string
CustomerInfo
}
@ -142,10 +142,10 @@ func createToken(user string) (string, error) {
// set our claims
t.Claims = &CustomClaimsExample{
&jwt.StandardClaims{
&jwt.RegisteredClaims{
// set the expire time
// see http://tools.ietf.org/html/draft-ietf-oauth-json-web-token-20#section-4.1.4
ExpiresAt: time.Now().Add(time.Minute * 1).Unix(),
// see https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute * 1)),
},
"level1",
CustomerInfo{user, "human"},

View File

@ -3,6 +3,7 @@ package jwt
import (
"encoding/json"
"errors"
"time"
// "fmt"
)
@ -34,34 +35,78 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool {
// VerifyExpiresAt compares the exp claim against cmp (cmp <= exp).
// If req is false, it will return true, if exp is unset.
func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool {
exp, ok := m["exp"]
cmpTime := time.Unix(cmp, 0)
v, ok := m["exp"]
if !ok {
return !req
}
switch expType := exp.(type) {
switch exp := v.(type) {
case float64:
return verifyExp(int64(expType), cmp, req)
case json.Number:
v, _ := expType.Int64()
return verifyExp(v, cmp, req)
if exp == 0 {
return verifyExp(nil, cmpTime, req)
}
return verifyExp(&newNumericDateFromSeconds(exp).Time, cmpTime, req)
case json.Number:
v, _ := exp.Float64()
return verifyExp(&newNumericDateFromSeconds(v).Time, cmpTime, req)
}
return false
}
// VerifyIssuedAt compares the exp claim against cmp (cmp >= iat).
// If req is false, it will return true, if iat is unset.
func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool {
iat, ok := m["iat"]
cmpTime := time.Unix(cmp, 0)
v, ok := m["iat"]
if !ok {
return !req
}
switch iatType := iat.(type) {
switch iat := v.(type) {
case float64:
return verifyIat(int64(iatType), cmp, req)
case json.Number:
v, _ := iatType.Int64()
return verifyIat(v, cmp, req)
if iat == 0 {
return verifyIat(nil, cmpTime, req)
}
return verifyIat(&newNumericDateFromSeconds(iat).Time, cmpTime, req)
case json.Number:
v, _ := iat.Float64()
return verifyIat(&newNumericDateFromSeconds(v).Time, cmpTime, req)
}
return false
}
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
cmpTime := time.Unix(cmp, 0)
v, ok := m["nbf"]
if !ok {
return !req
}
switch nbf := v.(type) {
case float64:
if nbf == 0 {
return verifyNbf(nil, cmpTime, req)
}
return verifyNbf(&newNumericDateFromSeconds(nbf).Time, cmpTime, req)
case json.Number:
v, _ := nbf.Float64()
return verifyNbf(&newNumericDateFromSeconds(v).Time, cmpTime, req)
}
return false
}
@ -72,24 +117,7 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool {
return verifyIss(iss, cmp, req)
}
// VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf).
// If req is false, it will return true, if nbf is unset.
func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool {
nbf, ok := m["nbf"]
if !ok {
return !req
}
switch nbfType := nbf.(type) {
case float64:
return verifyNbf(int64(nbfType), cmp, req)
case json.Number:
v, _ := nbfType.Int64()
return verifyNbf(v, cmp, req)
}
return false
}
// Valid calidates time based claims "exp, iat, nbf".
// Valid validates time based claims "exp, iat, nbf".
// There is no accounting for clock skew.
// As well, if any of the above claims are not in the token, it will still
// be considered a valid claim.

View File

@ -181,6 +181,61 @@ var jwtTestData = []struct {
0,
&jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true},
},
{
"RFC7519 Claims",
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)),
},
true,
0,
&jwt.Parser{UseJSONNumber: true},
},
{
"RFC7519 Claims - single aud",
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test"},
},
true,
0,
&jwt.Parser{UseJSONNumber: true},
},
{
"RFC7519 Claims - multiple aud",
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test", "test"},
},
true,
0,
&jwt.Parser{UseJSONNumber: true},
},
{
"RFC7519 Claims - single aud with wrong type",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 }
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty
},
false,
jwt.ValidationErrorMalformed,
&jwt.Parser{UseJSONNumber: true},
},
{
"RFC7519 Claims - multiple aud with wrong types",
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] }
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty
},
false,
jwt.ValidationErrorMalformed,
&jwt.Parser{UseJSONNumber: true},
},
}
func TestParser_Parse(t *testing.T) {
@ -188,6 +243,7 @@ func TestParser_Parse(t *testing.T) {
// Iterate over test data set and run tests
for _, data := range jwtTestData {
t.Run(data.name, func(t *testing.T) {
// If the token string is blank, use helper function to generate string
if data.tokenString == "" {
data.tokenString = test.MakeSampleToken(data.claims, privateKey)
@ -206,6 +262,8 @@ func TestParser_Parse(t *testing.T) {
token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc)
case *jwt.StandardClaims:
token, err = parser.ParseWithClaims(data.tokenString, &jwt.StandardClaims{}, data.keyfunc)
case *jwt.RegisteredClaims:
token, err = parser.ParseWithClaims(data.tokenString, &jwt.RegisteredClaims{}, data.keyfunc)
}
// Verify result matches expectation
@ -244,6 +302,7 @@ func TestParser_Parse(t *testing.T) {
if data.valid && token.Signature == "" {
t.Errorf("[%v] Signature is left unpopulated after parsing", data.name)
}
})
}
}
@ -252,6 +311,12 @@ func TestParser_ParseUnverified(t *testing.T) {
// Iterate over test data set and run tests
for _, data := range jwtTestData {
// Skip test data, that intentionally contains malformed tokens, as they would lead to an error
if data.errors&jwt.ValidationErrorMalformed != 0 {
continue
}
t.Run(data.name, func(t *testing.T) {
// If the token string is blank, use helper function to generate string
if data.tokenString == "" {
data.tokenString = test.MakeSampleToken(data.claims, privateKey)
@ -270,6 +335,8 @@ func TestParser_ParseUnverified(t *testing.T) {
token, _, err = parser.ParseUnverified(data.tokenString, jwt.MapClaims{})
case *jwt.StandardClaims:
token, _, err = parser.ParseUnverified(data.tokenString, &jwt.StandardClaims{})
case *jwt.RegisteredClaims:
token, _, err = parser.ParseUnverified(data.tokenString, &jwt.RegisteredClaims{})
}
if err != nil {
@ -284,6 +351,7 @@ func TestParser_ParseUnverified(t *testing.T) {
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err)
}
})
}
}

View File

@ -131,9 +131,9 @@ func TestRSAPSSSaltLengthCompatibility(t *testing.T) {
}
func makeToken(method jwt.SigningMethod) string {
token := jwt.NewWithClaims(method, jwt.StandardClaims{
token := jwt.NewWithClaims(method, jwt.RegisteredClaims{
Issuer: "example",
IssuedAt: time.Now().Unix(),
IssuedAt: jwt.NewNumericDate(time.Now()),
})
privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key")
signed, err := token.SignedString(privateKey)

125
types.go Normal file
View File

@ -0,0 +1,125 @@
package jwt
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"time"
)
// TimePrecision sets the precision of times and dates within this library.
// This has an influence on the precision of times when comparing expiry or
// other related time fields. Furthermore, it is also the precision of times
// when serializing.
//
// For backwards compatibility the default precision is set to seconds, so that
// no fractional timestamps are generated.
var TimePrecision = time.Second
// MarshalSingleStringAsArray modifies the behaviour of the ClaimStrings type, especially
// its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the behaviour
// of the underlying []string. If it is set to false, it will serialize to a single
// string, if it contains one element. Otherwise, it will serialize to an array of strings.
var MarshalSingleStringAsArray = true
// NumericDate represents a JSON numeric date value, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-2.
type NumericDate struct {
time.Time
}
// NewNumericDate constructs a new *NumericDate from a standard library time.Time struct.
// It will truncate the timestamp according to the precision specified in TimePrecision.
func NewNumericDate(t time.Time) *NumericDate {
return &NumericDate{t.Truncate(TimePrecision)}
}
// newNumericDateFromSeconds creates a new *NumericDate out of a float64 representing a
// UNIX epoch with the float fraction representing non-integer seconds.
func newNumericDateFromSeconds(f float64) *NumericDate {
return NewNumericDate(time.Unix(0, int64(f*float64(time.Second))))
}
// MarshalJSON is an implementation of the json.RawMessage interface and serializes the UNIX epoch
// represented in NumericDate to a byte array, using the precision specified in TimePrecision.
func (date NumericDate) MarshalJSON() (b []byte, err error) {
f := float64(date.Truncate(TimePrecision).UnixNano()) / float64(time.Second)
return []byte(strconv.FormatFloat(f, 'f', -1, 64)), nil
}
// UnmarshalJSON is an implementation of the json.RawMessage interface and deserializses a
// NumericDate from a JSON representation, i.e. a json.Number. This number represents an UNIX epoch
// with either integer or non-integer seconds.
func (date *NumericDate) UnmarshalJSON(b []byte) (err error) {
var (
number json.Number
f float64
)
if err = json.Unmarshal(b, &number); err != nil {
return fmt.Errorf("could not parse NumericData: %w", err)
}
if f, err = number.Float64(); err != nil {
return fmt.Errorf("could not convert json number value to float: %w", err)
}
n := newNumericDateFromSeconds(f)
*date = *n
return nil
}
// ClaimStrings is basically just a slice of strings, but it can be either serialized from a string array or just a string.
// This type is necessary, since the "aud" claim can either be a single string or an array.
type ClaimStrings []string
func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
var value interface{}
if err = json.Unmarshal(data, &value); err != nil {
return err
}
var aud []string
switch v := value.(type) {
case string:
aud = append(aud, v)
case []string:
aud = ClaimStrings(v)
case []interface{}:
for _, vv := range v {
vs, ok := vv.(string)
if !ok {
return &json.UnsupportedTypeError{Type: reflect.TypeOf(vv)}
}
aud = append(aud, vs)
}
case nil:
return nil
default:
return &json.UnsupportedTypeError{Type: reflect.TypeOf(v)}
}
*s = aud
return
}
func (s ClaimStrings) MarshalJSON() (b []byte, err error) {
// This handles a special case in the JWT RFC. If the string array, e.g. used by the "aud" field,
// only contains one element, it MAY be serialized as a single string. This may or may not be
// desired based on the ecosystem of other JWT library used, so we make it configurable by the
// variable MarshalSingleStringAsArray.
if len(s) == 1 && !MarshalSingleStringAsArray {
return json.Marshal(s[0])
}
return json.Marshal([]string(s))
}

67
types_test.go Normal file
View File

@ -0,0 +1,67 @@
package jwt_test
import (
"encoding/json"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
)
func TestNumericDate(t *testing.T) {
var s struct {
Iat jwt.NumericDate `json:"iat"`
Exp jwt.NumericDate `json:"exp"`
}
oldPrecision := jwt.TimePrecision
jwt.TimePrecision = time.Microsecond
raw := `{"iat":1516239022,"exp":1516239022.12345}`
err := json.Unmarshal([]byte(raw), &s)
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
b, _ := json.Marshal(s)
if raw != string(b) {
t.Errorf("Serialized format of numeric date mismatch. Expecting: %s Got: %s", string(raw), string(b))
}
jwt.TimePrecision = oldPrecision
}
func TestSingleArrayMarshal(t *testing.T) {
jwt.MarshalSingleStringAsArray = false
s := jwt.ClaimStrings{"test"}
expected := `"test"`
b, err := json.Marshal(s)
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
if expected != string(b) {
t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", string(expected), string(b))
}
jwt.MarshalSingleStringAsArray = true
expected = `["test"]`
b, err = json.Marshal(s)
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
if expected != string(b) {
t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", string(expected), string(b))
}
}