mirror of https://github.com/golang-jwt/jwt.git
Remove global state for audience marshalling
This commit is contained in:
parent
62e504c281
commit
d7e0970ea1
|
@ -12,5 +12,5 @@ type Claims interface {
|
|||
GetNotBefore() (*NumericDate, error)
|
||||
GetIssuer() (string, error)
|
||||
GetSubject() (string, error)
|
||||
GetAudience() (ClaimStrings, error)
|
||||
GetAudience() (*ClaimStrings, error)
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ func ExampleNewWithClaims_customClaimsType() {
|
|||
Issuer: "test",
|
||||
Subject: "somebody",
|
||||
ID: "1",
|
||||
Audience: []string{"somebody_else"},
|
||||
Audience: jwt.NewClaimStrings([]string{"somebody_else"}),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ func (m MapClaims) GetIssuedAt() (*NumericDate, error) {
|
|||
}
|
||||
|
||||
// GetAudience implements the Claims interface.
|
||||
func (m MapClaims) GetAudience() (ClaimStrings, error) {
|
||||
func (m MapClaims) GetAudience() (*ClaimStrings, error) {
|
||||
return m.parseClaimsString("aud")
|
||||
}
|
||||
|
||||
|
@ -66,7 +66,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
|
|||
|
||||
// parseClaimsString tries to parse a key in the map claims type as a
|
||||
// [ClaimsStrings] type, which can either be a string or an array of string.
|
||||
func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
|
||||
func (m MapClaims) parseClaimsString(key string) (*ClaimStrings, error) {
|
||||
var cs []string
|
||||
switch v := m[key].(type) {
|
||||
case string:
|
||||
|
@ -83,7 +83,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return cs, nil
|
||||
return NewClaimStrings(cs), nil
|
||||
}
|
||||
|
||||
// parseString tries to parse a key in the map claims type as a [string] type.
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -360,7 +361,7 @@ var jwtTestData = []struct {
|
|||
"",
|
||||
defaultKeyFunc,
|
||||
&jwt.RegisteredClaims{
|
||||
Audience: jwt.ClaimStrings{"test"},
|
||||
Audience: jwt.NewClaimStrings([]string{"test"}),
|
||||
},
|
||||
true,
|
||||
nil,
|
||||
|
@ -372,7 +373,7 @@ var jwtTestData = []struct {
|
|||
"",
|
||||
defaultKeyFunc,
|
||||
&jwt.RegisteredClaims{
|
||||
Audience: jwt.ClaimStrings{"test", "test"},
|
||||
Audience: jwt.NewClaimStrings([]string{"test", "test"}),
|
||||
},
|
||||
true,
|
||||
nil,
|
||||
|
@ -384,7 +385,7 @@ var jwtTestData = []struct {
|
|||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOjF9.8mAIDUfZNQT3TGm1QFIQp91OCpJpQpbB1-m9pA2mkHc", // { "aud": 1 }
|
||||
defaultKeyFunc,
|
||||
&jwt.RegisteredClaims{
|
||||
Audience: nil, // because of the unmarshal error, this will be empty
|
||||
Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
|
||||
},
|
||||
false,
|
||||
[]error{jwt.ErrTokenMalformed},
|
||||
|
@ -396,7 +397,7 @@ var jwtTestData = []struct {
|
|||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsidGVzdCIsMV19.htEBUf7BVbfSmVoTFjXf3y6DLmDUuLy1vTJ14_EX7Ws", // { "aud": ["test", 1] }
|
||||
defaultKeyFunc,
|
||||
&jwt.RegisteredClaims{
|
||||
Audience: nil, // because of the unmarshal error, this will be empty
|
||||
Audience: jwt.NewClaimStrings([]string{}), // because of the unmarshal error, this will be empty
|
||||
},
|
||||
false,
|
||||
[]error{jwt.ErrTokenMalformed},
|
||||
|
@ -449,6 +450,50 @@ func signToken(claims jwt.Claims, signingMethod jwt.SigningMethod) string {
|
|||
return test.MakeSampleToken(claims, signingMethod, privateKey)
|
||||
}
|
||||
|
||||
func claimsEqual(a, b jwt.Claims) error {
|
||||
aExp, aErr := a.GetExpirationTime()
|
||||
bExp, bErr := b.GetExpirationTime()
|
||||
if !reflect.DeepEqual(aExp, bExp) || !reflect.DeepEqual(aErr, bErr) {
|
||||
return fmt.Errorf("mismatched `exp`: expected %v vs. %v", aExp, bExp)
|
||||
}
|
||||
|
||||
aIat, aErr := a.GetIssuedAt()
|
||||
bIat, bErr := b.GetIssuedAt()
|
||||
if !reflect.DeepEqual(aIat, bIat) || !reflect.DeepEqual(aErr, bErr) {
|
||||
return fmt.Errorf("mismatched `iat`: expected %v vs. %v", aIat, bIat)
|
||||
}
|
||||
|
||||
aNbf, aErr := a.GetNotBefore()
|
||||
bNbf, bErr := b.GetNotBefore()
|
||||
if !reflect.DeepEqual(aNbf, bNbf) || !reflect.DeepEqual(aErr, bErr) {
|
||||
return fmt.Errorf("mismatched `nbf`: expected %v vs. %v", aNbf, bNbf)
|
||||
}
|
||||
|
||||
aIss, aErr := a.GetIssuer()
|
||||
bIss, bErr := b.GetIssuer()
|
||||
if aIss != bIss || !reflect.DeepEqual(aErr, bErr) {
|
||||
return fmt.Errorf("mismatched `iss`: expected %v vs. %v", aIss, bIss)
|
||||
}
|
||||
|
||||
aSub, aErr := a.GetSubject()
|
||||
bSub, bErr := b.GetSubject()
|
||||
if aSub != bSub || !reflect.DeepEqual(aErr, bErr) {
|
||||
return fmt.Errorf("mismatched `sub`: expected %v vs. %v", aSub, bSub)
|
||||
}
|
||||
|
||||
aAud, aErr := a.GetAudience()
|
||||
bAud, bErr := b.GetAudience()
|
||||
if aAud != bAud {
|
||||
if aAud == nil || bAud == nil {
|
||||
return fmt.Errorf("mismatched `aud`: expected %v vs. %v", aAud, bAud)
|
||||
}
|
||||
if !slices.Equal(aAud.Claims(), bAud.Claims()) || !reflect.DeepEqual(aErr, bErr) {
|
||||
return fmt.Errorf("mismatched `aud`: expected %v vs %v", aAud, bAud)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestParser_Parse(t *testing.T) {
|
||||
// Iterate over test data set and run tests
|
||||
for _, data := range jwtTestData {
|
||||
|
@ -476,8 +521,10 @@ func TestParser_Parse(t *testing.T) {
|
|||
}
|
||||
|
||||
// Verify result matches expectation
|
||||
if data.claims != nil && !reflect.DeepEqual(data.claims, token.Claims) {
|
||||
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
|
||||
if data.claims != nil {
|
||||
if err := claimsEqual(data.claims, token.Claims); err != nil {
|
||||
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v: %v", data.name, data.claims, token.Claims, err)
|
||||
}
|
||||
}
|
||||
|
||||
if data.valid && err != nil {
|
||||
|
@ -557,8 +604,8 @@ func TestParser_ParseUnverified(t *testing.T) {
|
|||
}
|
||||
|
||||
// Verify result matches expectation
|
||||
if !reflect.DeepEqual(data.claims, token.Claims) {
|
||||
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
|
||||
if err := claimsEqual(data.claims, token.Claims); err != nil {
|
||||
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v %v", data.name, data.claims, token.Claims, err)
|
||||
}
|
||||
|
||||
if data.valid && err != nil {
|
||||
|
|
|
@ -6,7 +6,7 @@ package jwt
|
|||
//
|
||||
// This type can be used on its own, but then additional private and
|
||||
// public claims embedded in the JWT will not be parsed. The typical use-case
|
||||
// therefore is to embedded this in a user-defined claim type.
|
||||
// therefore is to embed this in a user-defined claim type.
|
||||
//
|
||||
// See examples for how to use this with your own claim types.
|
||||
type RegisteredClaims struct {
|
||||
|
@ -17,7 +17,7 @@ type RegisteredClaims struct {
|
|||
Subject string `json:"sub,omitempty"`
|
||||
|
||||
// the `aud` (Audience) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3
|
||||
Audience ClaimStrings `json:"aud,omitempty"`
|
||||
Audience *ClaimStrings `json:"aud,omitempty"`
|
||||
|
||||
// the `exp` (Expiration Time) claim. See https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.4
|
||||
ExpiresAt *NumericDate `json:"exp,omitempty"`
|
||||
|
@ -48,7 +48,7 @@ func (c RegisteredClaims) GetIssuedAt() (*NumericDate, error) {
|
|||
}
|
||||
|
||||
// GetAudience implements the Claims interface.
|
||||
func (c RegisteredClaims) GetAudience() (ClaimStrings, error) {
|
||||
func (c RegisteredClaims) GetAudience() (*ClaimStrings, error) {
|
||||
return c.Audience, nil
|
||||
}
|
||||
|
||||
|
|
68
types.go
68
types.go
|
@ -17,16 +17,6 @@ import (
|
|||
// no fractional timestamps are generated.
|
||||
var TimePrecision = time.Second
|
||||
|
||||
// MarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
|
||||
// especially its MarshalJSON function.
|
||||
//
|
||||
// If it is set to true (the default), it will always serialize the type as an
|
||||
// array of strings, even if it just contains one element, defaulting to the
|
||||
// behavior of the underlying []string. If it is set to false, it will serialize
|
||||
// to a single string, if it contains one element. Otherwise, it will serialize
|
||||
// to an array of strings.
|
||||
var MarshalSingleStringAsArray = true
|
||||
|
||||
// NumericDate represents a JSON numeric date value, as referenced at
|
||||
// https://datatracker.ietf.org/doc/html/rfc7519#section-2.
|
||||
type NumericDate struct {
|
||||
|
@ -100,10 +90,52 @@ func (date *NumericDate) UnmarshalJSON(b []byte) (err error) {
|
|||
// ClaimStrings is basically just a slice of strings, but it can be either
|
||||
// serialized from a string array or just a string. This type is necessary,
|
||||
// since the "aud" claim can either be a single string or an array.
|
||||
type ClaimStrings []string
|
||||
type ClaimStrings struct {
|
||||
claims []string
|
||||
marshalSingleStringAsArray bool
|
||||
}
|
||||
|
||||
type ClaimStringOption func(*ClaimStrings)
|
||||
|
||||
func NewClaimStrings(claims []string, opts ...ClaimStringOption) *ClaimStrings {
|
||||
ret := ClaimStrings{
|
||||
claims: claims,
|
||||
marshalSingleStringAsArray: true,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&ret)
|
||||
}
|
||||
return &ret
|
||||
}
|
||||
|
||||
// WithMarshalSingleStringAsArray modifies the behavior of the ClaimStrings type,
|
||||
// especially its MarshalJSON function.
|
||||
//
|
||||
// If it is set to true (the default), it will always serialize the type as an
|
||||
// array of strings, even if it just contains one element, defaulting to the
|
||||
// behavior of the underlying []string. If it is set to false, it will serialize
|
||||
// to a single string, if it contains one element. Otherwise, it will serialize
|
||||
// to an array of strings.
|
||||
func WithMarshalSingleStringAsArray(marshalSingleStringAsArray bool) func(claims *ClaimStrings) {
|
||||
return func(claims *ClaimStrings) {
|
||||
claims.marshalSingleStringAsArray = marshalSingleStringAsArray
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaimStrings) Len() int {
|
||||
return len(s.claims)
|
||||
}
|
||||
|
||||
func (s *ClaimStrings) Claims() []string {
|
||||
return s.claims
|
||||
}
|
||||
|
||||
func (s *ClaimStrings) String() string {
|
||||
return fmt.Sprintf("%v", s.claims)
|
||||
}
|
||||
|
||||
func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
|
||||
var value interface{}
|
||||
var value any
|
||||
|
||||
if err = json.Unmarshal(data, &value); err != nil {
|
||||
return err
|
||||
|
@ -115,7 +147,7 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
|
|||
case string:
|
||||
aud = append(aud, v)
|
||||
case []string:
|
||||
aud = ClaimStrings(v)
|
||||
aud = v
|
||||
case []interface{}:
|
||||
for _, vv := range v {
|
||||
vs, ok := vv.(string)
|
||||
|
@ -130,20 +162,20 @@ func (s *ClaimStrings) UnmarshalJSON(data []byte) (err error) {
|
|||
return ErrInvalidType
|
||||
}
|
||||
|
||||
*s = aud
|
||||
s.claims = aud
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s ClaimStrings) MarshalJSON() (b []byte, err error) {
|
||||
func (s *ClaimStrings) MarshalJSON() (b []byte, err error) {
|
||||
// This handles a special case in the JWT RFC. If the string array, e.g.
|
||||
// used by the "aud" field, only contains one element, it MAY be serialized
|
||||
// as a single string. This may or may not be desired based on the ecosystem
|
||||
// of other JWT library used, so we make it configurable by the variable
|
||||
// MarshalSingleStringAsArray.
|
||||
if len(s) == 1 && !MarshalSingleStringAsArray {
|
||||
return json.Marshal(s[0])
|
||||
if len(s.claims) == 1 && !s.marshalSingleStringAsArray {
|
||||
return json.Marshal(s.claims[0])
|
||||
}
|
||||
|
||||
return json.Marshal([]string(s))
|
||||
return json.Marshal(s.claims)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package jwt_test
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -34,13 +35,12 @@ func TestNumericDate(t *testing.T) {
|
|||
jwt.TimePrecision = oldPrecision
|
||||
}
|
||||
|
||||
func TestSingleArrayMarshal(t *testing.T) {
|
||||
jwt.MarshalSingleStringAsArray = false
|
||||
|
||||
s := jwt.ClaimStrings{"test"}
|
||||
expected := `"test"`
|
||||
func TestClaimStrings(t *testing.T) {
|
||||
s := jwt.NewClaimStrings([]string{"test"})
|
||||
expected := `["test"]`
|
||||
|
||||
b, err := json.Marshal(s)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %s", err)
|
||||
}
|
||||
|
@ -48,13 +48,37 @@ func TestSingleArrayMarshal(t *testing.T) {
|
|||
if expected != string(b) {
|
||||
t.Errorf("Serialized format of string array mismatch. Expecting: %s Got: %s", expected, string(b))
|
||||
}
|
||||
}
|
||||
|
||||
jwt.MarshalSingleStringAsArray = true
|
||||
func TestClaimStringsInvalidType(t *testing.T) {
|
||||
j := `1`
|
||||
var s jwt.ClaimStrings
|
||||
err := json.Unmarshal([]byte(j), &s)
|
||||
if !errors.Is(err, jwt.ErrInvalidType) {
|
||||
t.Errorf("expected `ErrInvalidType` but was: %v", err)
|
||||
}
|
||||
if s.Claims() != nil {
|
||||
t.Errorf("expected claims to be nil but was: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
expected = `["test"]`
|
||||
func TestClaimStringsMismatchedTypes(t *testing.T) {
|
||||
j := `["test", 1]`
|
||||
var s jwt.ClaimStrings
|
||||
err := json.Unmarshal([]byte(j), &s)
|
||||
if !errors.Is(err, jwt.ErrInvalidType) {
|
||||
t.Errorf("expected `ErrInvalidType` but was: %v", err)
|
||||
}
|
||||
if s.Claims() != nil {
|
||||
t.Errorf("expected claims to be nil but was: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b, err = json.Marshal(s)
|
||||
func TestSingleArrayMarshal(t *testing.T) {
|
||||
s := jwt.NewClaimStrings([]string{"test"}, jwt.WithMarshalSingleStringAsArray(false))
|
||||
expected := `"test"`
|
||||
|
||||
b, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %s", err)
|
||||
}
|
||||
|
|
|
@ -232,7 +232,7 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
|
|||
return err
|
||||
}
|
||||
|
||||
if len(aud) == 0 {
|
||||
if aud.Len() == 0 {
|
||||
return errorIfRequired(required, "aud")
|
||||
}
|
||||
|
||||
|
@ -240,7 +240,7 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
|
|||
result := false
|
||||
|
||||
var stringClaims string
|
||||
for _, a := range aud {
|
||||
for _, a := range aud.Claims() {
|
||||
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 {
|
||||
result = true
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue