Remove global state for audience marshalling

This commit is contained in:
John Maguire 2024-07-30 11:54:03 +01:00
parent 62e504c281
commit d7e0970ea1
8 changed files with 147 additions and 44 deletions

View File

@ -12,5 +12,5 @@ type Claims interface {
GetNotBefore() (*NumericDate, error)
GetIssuer() (string, error)
GetSubject() (string, error)
GetAudience() (ClaimStrings, error)
GetAudience() (*ClaimStrings, error)
}

View File

@ -50,7 +50,7 @@ func ExampleNewWithClaims_customClaimsType() {
Issuer: "test",
Subject: "somebody",
ID: "1",
Audience: []string{"somebody_else"},
Audience: jwt.NewClaimStrings([]string{"somebody_else"}),
},
}

View File

@ -25,7 +25,7 @@ func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
}
// GetAudience implements the Claims interface.
func (m MapClaims) GetAudience() (ClaimStrings, error) {
func (m MapClaims) GetAudience() (*ClaimStrings, error) {
return m.parseClaimsString("aud")
}
@ -66,7 +66,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
// parseClaimsString tries to parse a key in the map claims type as a
// [ClaimsStrings] type, which can either be a string or an array of string.
func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
func (m MapClaims) parseClaimsString(key string) (*ClaimStrings, error) {
var cs []string
switch v := m[key].(type) {
case string:
@ -83,7 +83,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
}
}
return cs, nil
return NewClaimStrings(cs), nil
}
// parseString tries to parse a key in the map claims type as a [string] type.

View File

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"reflect"
"slices"
"testing"
"time"
@ -360,7 +361,7 @@ var jwtTestData = []struct {
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test"},
Audience: jwt.NewClaimStrings([]string{"test"}),
},
true,
nil,
@ -372,7 +373,7 @@ var jwtTestData = []struct {
"",
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"test", "test"},
Audience: jwt.NewClaimStrings([]string{"test", "test"}),
},
true,
nil,
@ -384,7 +385,7 @@ var jwtTestData = []struct {
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 }
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty
Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
},
false,
[]error{jwt.ErrTokenMalformed},
@ -396,7 +397,7 @@ var jwtTestData = []struct {
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] }
defaultKeyFunc,
&jwt.RegisteredClaims{
Audience: nil, // because of the unmarshal error, this will be empty
Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
},
false,
[]error{jwt.ErrTokenMalformed},
@ -449,6 +450,50 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string {
return test.MakeSampleToken(claims, signingMethod, privateKey)
}
func claimsEqual(a, b jwt.Claims) error {
aExp, aErr := a.GetExpirationTime()
bExp, bErr := b.GetExpirationTime()
if !reflect.DeepEqual(aExp, bExp) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `exp`: expected %v vs. %v", aExp, bExp)
}
aIat, aErr := a.GetIssuedAt()
bIat, bErr := b.GetIssuedAt()
if !reflect.DeepEqual(aIat, bIat) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `iat`: expected %v vs. %v", aIat, bIat)
}
aNbf, aErr := a.GetNotBefore()
bNbf, bErr := b.GetNotBefore()
if !reflect.DeepEqual(aNbf, bNbf) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `nbf`: expected %v vs. %v", aNbf, bNbf)
}
aIss, aErr := a.GetIssuer()
bIss, bErr := b.GetIssuer()
if aIss != bIss || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `iss`: expected %v vs. %v", aIss, bIss)
}
aSub, aErr := a.GetSubject()
bSub, bErr := b.GetSubject()
if aSub != bSub || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `sub`: expected %v vs. %v", aSub, bSub)
}
aAud, aErr := a.GetAudience()
bAud, bErr := b.GetAudience()
if aAud != bAud {
if aAud == nil || bAud == nil {
return fmt.Errorf("mismatched `aud`: expected %v vs. %v", aAud, bAud)
}
if !slices.Equal(aAud.Claims(), bAud.Claims()) || !reflect.DeepEqual(aErr, bErr) {
return fmt.Errorf("mismatched `aud`: expected %v vs %v", aAud, bAud)
}
}
return nil
}
func TestParser_Parse(t *testing.T) {
// Iterate over test data set and run tests
for _, data := range jwtTestData {
@ -476,8 +521,10 @@ func TestParser_Parse(t *testing.T) {
}
// Verify result matches expectation
if data.claims != nil && !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
if data.claims != nil {
if err := claimsEqual(data.claims, token.Claims); err != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v: %v", data.name, data.claims, token.Claims, err)
}
}
if data.valid && err != nil {
@ -557,8 +604,8 @@ func TestParser_ParseUnverified(t *testing.T) {
}
// Verify result matches expectation
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
if err := claimsEqual(data.claims, token.Claims); err != nil {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v %v", data.name, data.claims, token.Claims, err)
}
if data.valid && err != nil {

View File

@ -6,7 +6,7 @@ package jwt
//
// This type can be used on its own, but then additional private and
// public claims embedded in the JWT will not be parsed. The typical use-case
// therefore is to embedded this in a user-defined claim type.
// therefore is to embed this in a user-defined claim type.
//
// See examples for how to use this with your own claim types.
type RegisteredClaims struct {
@ -17,7 +17,7 @@ type RegisteredClaims struct {
Subject string `json:"sub,omitempty"`
// the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3
Audience ClaimStrings `json:"aud,omitempty"`
Audience *ClaimStrings `json:"aud,omitempty"`
// the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
ExpiresAt *NumericDate `json:"exp,omitempty"`
@ -48,7 +48,7 @@ func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
}
// GetAudience implements the Claims interface.
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
func (c RegisteredClaims) GetAudience() (*ClaimStrings, error) {
return c.Audience, nil
}

View File

@ -17,16 +17,6 @@ import (
// no fractional timestamps are generated.
var TimePrecision = time.Second
// MarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains one element. Otherwise, it will serialize
// to an array of strings.
var MarshalSingleStringAsArray = true
// NumericDate represents a JSON numeric date value, as referenced at
// https://datatracker.ietf.org/doc/html/rfc7519#section-2.
type NumericDate struct {
@ -100,10 +90,52 @@ func (date *NumericDate) UnmarshalJSON(b []byte) (err error) {
// ClaimStrings is basically just a slice of strings, but it can be either
// serialized from a string array or just a string. This type is necessary,
// since the "aud" claim can either be a single string or an array.
type ClaimStrings []string
type ClaimStrings struct {
claims []string
marshalSingleStringAsArray bool
}
type ClaimStringOption func(*ClaimStrings)
func NewClaimStrings(claims []string, opts ...ClaimStringOption) *ClaimStrings {
ret := ClaimStrings{
claims: claims,
marshalSingleStringAsArray: true,
}
for _, opt := range opts {
opt(&ret)
}
return &ret
}
// WithMarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
// especially its MarshalJSON function.
//
// If it is set to true (the default), it will always serialize the type as an
// array of strings, even if it just contains one element, defaulting to the
// behavior of the underlying []string. If it is set to false, it will serialize
// to a single string, if it contains one element. Otherwise, it will serialize
// to an array of strings.
func WithMarshalSingleStringAsArray(marshalSingleStringAsArray bool) func(claims *ClaimStrings) {
return func(claims *ClaimStrings) {
claims.marshalSingleStringAsArray = marshalSingleStringAsArray
}
}
func (s *ClaimStrings) Len() int {
return len(s.claims)
}
func (s *ClaimStrings) Claims() []string {
return s.claims
}
func (s *ClaimStrings) String() string {
return fmt.Sprintf("%v", s.claims)
}
func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
var value interface{}
var value any
if err = json.Unmarshal(data, &value); err != nil {
return err
@ -115,7 +147,7 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
case string:
aud = append(aud, v)
case []string:
aud = ClaimStrings(v)
aud = v
case []interface{}:
for _, vv := range v {
vs, ok := vv.(string)
@ -130,20 +162,20 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
return ErrInvalidType
}
*s = aud
s.claims = aud
return
}
func (s ClaimStrings) MarshalJSON() (b []byte, err error) {
func (s *ClaimStrings) MarshalJSON() (b []byte, err error) {
// This handles a special case in the JWT RFC. If the string array, e.g.
// used by the "aud" field, only contains one element, it MAY be serialized
// as a single string. This may or may not be desired based on the ecosystem
// of other JWT library used, so we make it configurable by the variable
// MarshalSingleStringAsArray.
if len(s) == 1 && !MarshalSingleStringAsArray {
return json.Marshal(s[0])
if len(s.claims) == 1 && !s.marshalSingleStringAsArray {
return json.Marshal(s.claims[0])
}
return json.Marshal([]string(s))
return json.Marshal(s.claims)
}

View File

@ -2,6 +2,7 @@ package jwt_test
import (
"encoding/json"
"errors"
"math"
"testing"
"time"
@ -34,13 +35,12 @@ func TestNumericDate(t *testing.T) {
jwt.TimePrecision = oldPrecision
}
func TestSingleArrayMarshal(t *testing.T) {
jwt.MarshalSingleStringAsArray = false
s := jwt.ClaimStrings{"test"}
expected := `"test"`
func TestClaimStrings(t *testing.T) {
s := jwt.NewClaimStrings([]string{"test"})
expected := `["test"]`
b, err := json.Marshal(s)
if err != nil {
t.Errorf("Unexpected error: %s", err)
}
@ -48,13 +48,37 @@ func TestSingleArrayMarshal(t *testing.T) {
if expected != string(b) {
t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b))
}
}
jwt.MarshalSingleStringAsArray = true
func TestClaimStringsInvalidType(t *testing.T) {
j := `1`
var s jwt.ClaimStrings
err := json.Unmarshal([]byte(j), &s)
if !errors.Is(err, jwt.ErrInvalidType) {
t.Errorf("expected `ErrInvalidType` but was: %v", err)
}
if s.Claims() != nil {
t.Errorf("expected claims to be nil but was: %v", err)
}
}
expected = `["test"]`
func TestClaimStringsMismatchedTypes(t *testing.T) {
j := `["test", 1]`
var s jwt.ClaimStrings
err := json.Unmarshal([]byte(j), &s)
if !errors.Is(err, jwt.ErrInvalidType) {
t.Errorf("expected `ErrInvalidType` but was: %v", err)
}
if s.Claims() != nil {
t.Errorf("expected claims to be nil but was: %v", err)
}
}
b, err = json.Marshal(s)
func TestSingleArrayMarshal(t *testing.T) {
s := jwt.NewClaimStrings([]string{"test"}, jwt.WithMarshalSingleStringAsArray(false))
expected := `"test"`
b, err := json.Marshal(s)
if err != nil {
t.Errorf("Unexpected error: %s", err)
}

View File

@ -232,7 +232,7 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
return err
}
if len(aud) == 0 {
if aud.Len() == 0 {
return errorIfRequired(required, "aud")
}
@ -240,7 +240,7 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
result := false
var stringClaims string
for _, a := range aud {
for _, a := range aud.Claims() {
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 {
result = true
}