mirror of https://github.com/golang-jwt/jwt.git
allow multiple audiences
This commit is contained in:
parent
bc8bdca5cc
commit
48c6492764
|
@ -66,15 +66,15 @@ func WithExpirationRequired() ParserOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithAudience configures the validator to require the specified audience in
|
// WithAudience configures the validator to require ONE of the specified
|
||||||
// the `aud` claim. Validation will fail if the audience is not listed in the
|
// audiences to be present in the `aud` claim. Validation will fail if none of
|
||||||
// token or the `aud` claim is missing.
|
// the audiences is listed in the token or the `aud` claim is missing.
|
||||||
//
|
//
|
||||||
// NOTE: While the `aud` claim is OPTIONAL in a JWT, the handling of it is
|
// NOTE: While the `aud` claim is OPTIONAL in a JWT, the handling of it is
|
||||||
// application-specific. Since this validation API is helping developers in
|
// application-specific. Since this validation API is helping developers in
|
||||||
// writing secure application, we decided to REQUIRE the existence of the claim,
|
// writing secure application, we decided to REQUIRE the existence of the claim,
|
||||||
// if an audience is expected.
|
// if an audience is expected.
|
||||||
func WithAudience(aud string) ParserOption {
|
func WithAudience(aud ...string) ParserOption {
|
||||||
return func(p *Parser) {
|
return func(p *Parser) {
|
||||||
p.validator.expectedAud = aud
|
p.validator.expectedAud = aud
|
||||||
}
|
}
|
||||||
|
|
10
validator.go
10
validator.go
|
@ -53,7 +53,7 @@ type Validator struct {
|
||||||
|
|
||||||
// expectedAud contains the audience this token expects. Supplying an empty
|
// expectedAud contains the audience this token expects. Supplying an empty
|
||||||
// string will disable aud checking.
|
// string will disable aud checking.
|
||||||
expectedAud string
|
expectedAud []string
|
||||||
|
|
||||||
// expectedIss contains the issuer this token expects. Supplying an empty
|
// expectedIss contains the issuer this token expects. Supplying an empty
|
||||||
// string will disable iss checking.
|
// string will disable iss checking.
|
||||||
|
@ -120,7 +120,7 @@ func (v *Validator) Validate(claims Claims) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have an expected audience, we also require the audience claim
|
// If we have an expected audience, we also require the audience claim
|
||||||
if v.expectedAud != "" {
|
if len(v.expectedAud) > 0 {
|
||||||
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
|
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
|
@ -226,7 +226,7 @@ func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool)
|
||||||
//
|
//
|
||||||
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
// Additionally, if any error occurs while retrieving the claim, e.g., when its
|
||||||
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
// the wrong type, an ErrTokenUnverifiable error will be returned.
|
||||||
func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) error {
|
func (v *Validator) verifyAudience(claims Claims, cmp []string, required bool) error {
|
||||||
aud, err := claims.GetAudience()
|
aud, err := claims.GetAudience()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -241,11 +241,13 @@ func (v *Validator) verifyAudience(claims Claims, cmp string, required bool) err
|
||||||
|
|
||||||
var stringClaims string
|
var stringClaims string
|
||||||
for _, a := range aud {
|
for _, a := range aud {
|
||||||
if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 {
|
for _, c := range cmp {
|
||||||
|
if subtle.ConstantTimeCompare([]byte(a), []byte(c)) != 0 {
|
||||||
result = true
|
result = true
|
||||||
}
|
}
|
||||||
stringClaims = stringClaims + a
|
stringClaims = stringClaims + a
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// case where "" is sent in one or many aud claims
|
// case where "" is sent in one or many aud claims
|
||||||
if stringClaims == "" {
|
if stringClaims == "" {
|
||||||
|
|
|
@ -25,7 +25,7 @@ func Test_Validator_Validate(t *testing.T) {
|
||||||
leeway time.Duration
|
leeway time.Duration
|
||||||
timeFunc func() time.Time
|
timeFunc func() time.Time
|
||||||
verifyIat bool
|
verifyIat bool
|
||||||
expectedAud string
|
expectedAud []string
|
||||||
expectedIss string
|
expectedIss string
|
||||||
expectedSub string
|
expectedSub string
|
||||||
}
|
}
|
||||||
|
@ -259,3 +259,50 @@ func Test_Validator_verifyIssuedAt(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_Validator_verifyAudience(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
expectedAud []string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
claims Claims
|
||||||
|
cmp []string
|
||||||
|
required bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single value in aud claim",
|
||||||
|
fields: fields{expectedAud: []string{"me", "you"}},
|
||||||
|
args: args{claims: MapClaims{"aud": "me"}, cmp: []string{"me"}},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple values in aud claim",
|
||||||
|
fields: fields{expectedAud: []string{"me"}},
|
||||||
|
args: args{claims: MapClaims{"aud": []string{"me", "you"}}, cmp: []string{"me"}},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "claims with invalid audience",
|
||||||
|
fields: fields{expectedAud: []string{"me"}},
|
||||||
|
args: args{claims: MapClaims{"aud": "you"}, cmp: []string{"me"}},
|
||||||
|
wantErr: ErrTokenInvalidAudience,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := &Validator{
|
||||||
|
expectedAud: tt.fields.expectedAud,
|
||||||
|
}
|
||||||
|
err := v.verifyAudience(tt.args.claims, tt.args.cmp, tt.args.required)
|
||||||
|
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Errorf("validator.verifyAudience() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue