Remove global state for audience marshalling

This commit is contained in:
John Maguire 2024-07-30 11:54:03 +01:00
parent 62e504c281
commit d7e0970ea1
8 changed files with 147 additions and 44 deletions

View File

@ -12,5 +12,5 @@ type Claims interface {
GetNotBefore() (*NumericDate, error) GetNotBefore() (*NumericDate, error)
GetIssuer() (string, error) GetIssuer() (string, error)
GetSubject() (string, error) GetSubject() (string, error)
GetAudience() (ClaimStrings, error) GetAudience() (*ClaimStrings, error)
} }

View File

@ -50,7 +50,7 @@ func ExampleNewWithClaims_customClaimsType() {
Issuer: "test", Issuer: "test",
Subject: "somebody", Subject: "somebody",
ID: "1", ID: "1",
Audience: []string{"somebody_else"}, Audience: jwt.NewClaimStrings([]string{"somebody_else"}),
}, },
} }

View File

@ -25,7 +25,7 @@ func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
} }
// GetAudience implements the Claims interface. // GetAudience implements the Claims interface.
func (m MapClaims) GetAudience() (ClaimStrings, error) { func (m MapClaims) GetAudience() (*ClaimStrings, error) {
return m.parseClaimsString("aud") return m.parseClaimsString("aud")
} }
@ -66,7 +66,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
// parseClaimsString tries to parse a key in the map claims type as a // parseClaimsString tries to parse a key in the map claims type as a
// [ClaimsStrings] type, which can either be a string or an array of string. // [ClaimsStrings] type, which can either be a string or an array of string.
func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) { func (m MapClaims) parseClaimsString(key string) (*ClaimStrings, error) {
var cs []string var cs []string
switch v := m[key].(type) { switch v := m[key].(type) {
case string: case string:
@ -83,7 +83,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
} }
} }
return cs, nil return NewClaimStrings(cs), nil
} }
// parseString tries to parse a key in the map claims type as a [string] type. // parseString tries to parse a key in the map claims type as a [string] type.

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"slices"
"testing" "testing"
"time" "time"
@ -360,7 +361,7 @@ var jwtTestData = []struct {
"", "",
defaultKeyFunc, defaultKeyFunc,
&jwt.RegisteredClaims{ &jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test"}, Audience: jwt.NewClaimStrings([]string{"test"}),
}, },
true, true,
nil, nil,
@ -372,7 +373,7 @@ var jwtTestData = []struct {
"", "",
defaultKeyFunc, defaultKeyFunc,
&jwt.RegisteredClaims{ &jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test", "test"}, Audience: jwt.NewClaimStrings([]string{"test", "test"}),
}, },
true, true,
nil, nil,
@ -384,7 +385,7 @@ var jwtTestData = []struct {
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 } "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 }
defaultKeyFunc, defaultKeyFunc,
&jwt.RegisteredClaims{ &jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
}, },
false, false,
[]error{jwt.ErrTokenMalformed}, []error{jwt.ErrTokenMalformed},
@ -396,7 +397,7 @@ var jwtTestData = []struct {
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] } "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] }
defaultKeyFunc, defaultKeyFunc,
&jwt.RegisteredClaims{ &jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
}, },
false, false,
[]error{jwt.ErrTokenMalformed}, []error{jwt.ErrTokenMalformed},
@ -449,6 +450,50 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string {
return test.MakeSampleToken(claims, signingMethod, privateKey) return test.MakeSampleToken(claims, signingMethod, privateKey)
} }
func claimsEqual(a, b jwt.Claims) error {
aExp, aErr := a.GetExpirationTime()
bExp, bErr := b.GetExpirationTime()
if !reflect.DeepEqual(aExp, bExp) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `exp`: expected %v vs. %v", aExp, bExp)
}
aIat, aErr := a.GetIssuedAt()
bIat, bErr := b.GetIssuedAt()
if !reflect.DeepEqual(aIat, bIat) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `iat`: expected %v vs. %v", aIat, bIat)
}
aNbf, aErr := a.GetNotBefore()
bNbf, bErr := b.GetNotBefore()
if !reflect.DeepEqual(aNbf, bNbf) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `nbf`: expected %v vs. %v", aNbf, bNbf)
}
aIss, aErr := a.GetIssuer()
bIss, bErr := b.GetIssuer()
if aIss != bIss || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `iss`: expected %v vs. %v", aIss, bIss)
}
aSub, aErr := a.GetSubject()
bSub, bErr := b.GetSubject()
if aSub != bSub || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `sub`: expected %v vs. %v", aSub, bSub)
}
aAud, aErr := a.GetAudience()
bAud, bErr := b.GetAudience()
if aAud != bAud {
if aAud == nil || bAud == nil {
return fmt.Errorf("mismatched `aud`: expected %v vs. %v", aAud, bAud)
}
if !slices.Equal(aAud.Claims(), bAud.Claims()) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `aud`: expected %v vs %v", aAud, bAud)
}
}
return nil
}
func TestParser_Parse(t *testing.T) { func TestParser_Parse(t *testing.T) {
// Iterate over test data set and run tests // Iterate over test data set and run tests
for _, data := range jwtTestData { for _, data := range jwtTestData {
@ -476,8 +521,10 @@ func TestParser_Parse(t *testing.T) {
} }
// Verify result matches expectation // Verify result matches expectation
if data.claims != nil && !reflect.DeepEqual(data.claims, token.Claims) { if data.claims != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) if err := claimsEqual(data.claims, token.Claims); err != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v: %v", data.name, data.claims, token.Claims, err)
}
} }
if data.valid && err != nil { if data.valid && err != nil {
@ -557,8 +604,8 @@ func TestParser_ParseUnverified(t *testing.T) {
} }
// Verify result matches expectation // Verify result matches expectation
if !reflect.DeepEqual(data.claims, token.Claims) { if err := claimsEqual(data.claims, token.Claims); err != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v %v", data.name, data.claims, token.Claims, err)
} }
if data.valid && err != nil { if data.valid && err != nil {

View File

@ -6,7 +6,7 @@ package jwt
// //
// This type can be used on its own, but then additional private and // This type can be used on its own, but then additional private and
// public claims embedded in the JWT will not be parsed. The typical use-case // public claims embedded in the JWT will not be parsed. The typical use-case
// therefore is to embedded this in a user-defined claim type. // therefore is to embed this in a user-defined claim type.
// //
// See examples for how to use this with your own claim types. // See examples for how to use this with your own claim types.
type RegisteredClaims struct { type RegisteredClaims struct {
@ -17,7 +17,7 @@ type RegisteredClaims struct {
Subject string `json:"sub,omitempty"` Subject string `json:"sub,omitempty"`
// the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3 // the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3
Audience ClaimStrings `json:"aud,omitempty"` Audience *ClaimStrings `json:"aud,omitempty"`
// the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4 // the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt *NumericDate `json:"exp,omitempty"` ExpiresAt *NumericDate `json:"exp,omitempty"`
@ -48,7 +48,7 @@ func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
} }
// GetAudience implements the Claims interface. // GetAudience implements the Claims interface.
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) { func (c RegisteredClaims) GetAudience() (*ClaimStrings, error) {
return c.Audience, nil return c.Audience, nil
} }

View File

@ -17,16 +17,6 @@ import (
// no fractional timestamps are generated. // no fractional timestamps are generated.
var TimePrecision = time.Second var TimePrecision = time.Second
// MarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains one element. Otherwise, it will serialize
// to an array of strings.
var MarshalSingleStringAsArray = true
// NumericDate represents a JSON numeric date value, as referenced at // NumericDate represents a JSON numeric date value, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-2. // https://datatracker.ietf.org/doc/html/rfc7519#section-2.
type NumericDate struct { type NumericDate struct {
@ -100,10 +90,52 @@ func (date *NumericDate) UnmarshalJSON(b []byte) (err error) {
// ClaimStrings is basically just a slice of strings, but it can be either // ClaimStrings is basically just a slice of strings, but it can be either
// serialized from a string array or just a string. This type is necessary, // serialized from a string array or just a string. This type is necessary,
// since the "aud" claim can either be a single string or an array. // since the "aud" claim can either be a single string or an array.
type ClaimStrings []string type ClaimStrings struct {
claims []string
marshalSingleStringAsArray bool
}
type ClaimStringOption func(*ClaimStrings)
func NewClaimStrings(claims []string, opts ...ClaimStringOption) *ClaimStrings {
ret := ClaimStrings{
claims: claims,
marshalSingleStringAsArray: true,
}
for _, opt := range opts {
opt(&ret)
}
return &ret
}
// WithMarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains one element. Otherwise, it will serialize
// to an array of strings.
func WithMarshalSingleStringAsArray(marshalSingleStringAsArray bool) func(claims *ClaimStrings) {
return func(claims *ClaimStrings) {
claims.marshalSingleStringAsArray = marshalSingleStringAsArray
}
}
func (s *ClaimStrings) Len() int {
return len(s.claims)
}
func (s *ClaimStrings) Claims() []string {
return s.claims
}
func (s *ClaimStrings) String() string {
return fmt.Sprintf("%v", s.claims)
}
func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) { func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
var value interface{} var value any
if err = json.Unmarshal(data, &value); err != nil { if err = json.Unmarshal(data, &value); err != nil {
return err return err
@ -115,7 +147,7 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
case string: case string:
aud = append(aud, v) aud = append(aud, v)
case []string: case []string:
aud = ClaimStrings(v) aud = v
case []interface{}: case []interface{}:
for _, vv := range v { for _, vv := range v {
vs, ok := vv.(string) vs, ok := vv.(string)
@ -130,20 +162,20 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
return ErrInvalidType return ErrInvalidType
} }
*s = aud s.claims = aud
return return
} }
func (s ClaimStrings) MarshalJSON() (b []byte, err error) { func (s *ClaimStrings) MarshalJSON() (b []byte, err error) {
// This handles a special case in the JWT RFC. If the string array, e.g. // This handles a special case in the JWT RFC. If the string array, e.g.
// used by the "aud" field, only contains one element, it MAY be serialized // used by the "aud" field, only contains one element, it MAY be serialized
// as a single string. This may or may not be desired based on the ecosystem // as a single string. This may or may not be desired based on the ecosystem
// of other JWT library used, so we make it configurable by the variable // of other JWT library used, so we make it configurable by the variable
// MarshalSingleStringAsArray. // MarshalSingleStringAsArray.
if len(s) == 1 && !MarshalSingleStringAsArray { if len(s.claims) == 1 && !s.marshalSingleStringAsArray {
return json.Marshal(s[0]) return json.Marshal(s.claims[0])
} }
return json.Marshal([]string(s)) return json.Marshal(s.claims)
} }

View File

@ -2,6 +2,7 @@ package jwt_test
import ( import (
"encoding/json" "encoding/json"
"errors"
"math" "math"
"testing" "testing"
"time" "time"
@ -34,13 +35,12 @@ func TestNumericDate(t *testing.T) {
jwt.TimePrecision = oldPrecision jwt.TimePrecision = oldPrecision
} }
func TestSingleArrayMarshal(t *testing.T) { func TestClaimStrings(t *testing.T) {
jwt.MarshalSingleStringAsArray = false s := jwt.NewClaimStrings([]string{"test"})
expected := `["test"]`
s := jwt.ClaimStrings{"test"}
expected := `"test"`
b, err := json.Marshal(s) b, err := json.Marshal(s)
if err != nil { if err != nil {
t.Errorf("Unexpected error: %s", err) t.Errorf("Unexpected error: %s", err)
} }
@ -48,13 +48,37 @@ func TestSingleArrayMarshal(t *testing.T) {
if expected != string(b) { if expected != string(b) {
t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b)) t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b))
} }
}
jwt.MarshalSingleStringAsArray = true func TestClaimStringsInvalidType(t *testing.T) {
j := `1`
var s jwt.ClaimStrings
err := json.Unmarshal([]byte(j), &s)
if !errors.Is(err, jwt.ErrInvalidType) {
t.Errorf("expected `ErrInvalidType` but was: %v", err)
}
if s.Claims() != nil {
t.Errorf("expected claims to be nil but was: %v", err)
}
}
expected = `["test"]` func TestClaimStringsMismatchedTypes(t *testing.T) {
j := `["test", 1]`
var s jwt.ClaimStrings
err := json.Unmarshal([]byte(j), &s)
if !errors.Is(err, jwt.ErrInvalidType) {
t.Errorf("expected `ErrInvalidType` but was: %v", err)
}
if s.Claims() != nil {
t.Errorf("expected claims to be nil but was: %v", err)
}
}
b, err = json.Marshal(s) func TestSingleArrayMarshal(t *testing.T) {
s := jwt.NewClaimStrings([]string{"test"}, jwt.WithMarshalSingleStringAsArray(false))
expected := `"test"`
b, err := json.Marshal(s)
if err != nil { if err != nil {
t.Errorf("Unexpected error: %s", err) t.Errorf("Unexpected error: %s", err)
} }

View File

@ -232,7 +232,7 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
return err return err
} }
if len(aud) == 0 { if aud.Len() == 0 {
return errorIfRequired(required, "aud") return errorIfRequired(required, "aud")
} }
@ -240,7 +240,7 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
result := false result := false
var stringClaims string var stringClaims string
for _, a := range aud { for _, a := range aud.Claims() {
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 {
result = true result = true
} }