allow multiple audiences

This commit is contained in:
Lukas Gruber 2024-12-17 16:45:55 +01:00
parent bc8bdca5cc
commit 48c6492764
3 changed files with 60 additions and 11 deletions

View File

@ -66,15 +66,15 @@ func WithExpirationRequired() ParserOption {
} }
} }
// WithAudience configures the validator to require the specified audience in // WithAudience configures the validator to require ONE of the specified
// the `aud` claim. Validation will fail if the audience is not listed in the // audiences to be present in the `aud` claim. Validation will fail if none of
// token or the `aud` claim is missing. // the audiences is listed in the token or the `aud` claim is missing.
// //
// NOTE: While the `aud` claim is OPTIONAL in a JWT, the handling of it is // NOTE: While the `aud` claim is OPTIONAL in a JWT, the handling of it is
// application-specific. Since this validation API is helping developers in // application-specific. Since this validation API is helping developers in
// writing secure application, we decided to REQUIRE the existence of the claim, // writing secure application, we decided to REQUIRE the existence of the claim,
// if an audience is expected. // if an audience is expected.
func WithAudience(aud string) ParserOption { func WithAudience(aud ...string) ParserOption {
return func(p *Parser) { return func(p *Parser) {
p.validator.expectedAud = aud p.validator.expectedAud = aud
} }

View File

@ -53,7 +53,7 @@ type Validator struct {
// expectedAud contains the audience this token expects. Supplying an empty // expectedAud contains the audience this token expects. Supplying an empty
// string will disable aud checking. // string will disable aud checking.
expectedAud string expectedAud []string
// expectedIss contains the issuer this token expects. Supplying an empty // expectedIss contains the issuer this token expects. Supplying an empty
// string will disable iss checking. // string will disable iss checking.
@ -120,7 +120,7 @@ func (v *Validator) Validate(claims Claims) error {
} }
// If we have an expected audience, we also require the audience claim // If we have an expected audience, we also require the audience claim
if v.expectedAud != "" { if len(v.expectedAud) > 0 {
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil { if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
@ -226,7 +226,7 @@ func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool)
// //
// Additionally, if any error occurs while retrieving the claim, e.g., when its // Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, an ErrTokenUnverifiable error will be returned. // the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) error { func (v *Validator) verifyAudience(claims Claims, cmp []string, required bool) error {
aud, err := claims.GetAudience() aud, err := claims.GetAudience()
if err != nil { if err != nil {
return err return err
@ -241,11 +241,13 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
var stringClaims string var stringClaims string
for _, a := range aud { for _, a := range aud {
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { for _, c := range cmp {
if subtle.ConstantTimeCompare([]byte(a), []byte(c)) != 0 {
result = true result = true
} }
stringClaims = stringClaims + a stringClaims = stringClaims + a
} }
}
// case where "" is sent in one or many aud claims // case where "" is sent in one or many aud claims
if stringClaims == "" { if stringClaims == "" {

View File

@ -25,7 +25,7 @@ func Test_Validator_Validate(t *testing.T) {
leeway time.Duration leeway time.Duration
timeFunc func() time.Time timeFunc func() time.Time
verifyIat bool verifyIat bool
expectedAud string expectedAud []string
expectedIss string expectedIss string
expectedSub string expectedSub string
} }
@ -259,3 +259,50 @@ func Test_Validator_verifyIssuedAt(t *testing.T) {
}) })
} }
} }
func Test_Validator_verifyAudience(t *testing.T) {
type fields struct {
expectedAud []string
}
type args struct {
claims Claims
cmp []string
required bool
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "single value in aud claim",
fields: fields{expectedAud: []string{"me", "you"}},
args: args{claims: MapClaims{"aud": "me"}, cmp: []string{"me"}},
wantErr: nil,
},
{
name: "multiple values in aud claim",
fields: fields{expectedAud: []string{"me"}},
args: args{claims: MapClaims{"aud": []string{"me", "you"}}, cmp: []string{"me"}},
wantErr: nil,
},
{
name: "claims with invalid audience",
fields: fields{expectedAud: []string{"me"}},
args: args{claims: MapClaims{"aud": "you"}, cmp: []string{"me"}},
wantErr: ErrTokenInvalidAudience,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &Validator{
expectedAud: tt.fields.expectedAud,
}
err := v.verifyAudience(tt.args.claims, tt.args.cmp, tt.args.required)
if (err != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("validator.verifyAudience() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}