Pattern matching on subject

This commit is contained in:
Christian Banse 2022-10-26 22:15:55 +02:00
parent 5a65c47732
commit e8c10437a0
3 changed files with 81 additions and 9 deletions

View File

@ -32,9 +32,7 @@ type Validator struct {
// string will disable iss checking. // string will disable iss checking.
expectedIss string expectedIss string
// expectedSub contains the subject this token expects. Supplying an empty expectedSubPattern PatternFunc
// string will disable sub checking.
expectedSub string
} }
// CustomClaims represents a custom claims interface, which can be built upon the integrated // CustomClaims represents a custom claims interface, which can be built upon the integrated
@ -99,7 +97,7 @@ func (v *Validator) Validate(claims Claims) error {
} }
// If we have an expected subject, we also require the subject claim // If we have an expected subject, we also require the subject claim
if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) { if v.expectedSubPattern != nil && !v.VerifySubject(claims, v.expectedSubPattern, true) {
vErr.Inner = ErrTokenInvalidSubject vErr.Inner = ErrTokenInvalidSubject
vErr.Errors |= ValidationErrorSubject vErr.Errors |= ValidationErrorSubject
} }
@ -190,7 +188,7 @@ func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool {
// VerifySubject compares the sub claim against cmp. // VerifySubject compares the sub claim against cmp.
// If required is false, this method will return true if the value matches or is unset // If required is false, this method will return true if the value matches or is unset
func (v *Validator) VerifySubject(claims Claims, cmp string, req bool) bool { func (v *Validator) VerifySubject(claims Claims, cmp PatternFunc, req bool) bool {
iss, err := claims.GetSubject() iss, err := claims.GetSubject()
if err != nil { if err != nil {
return false return false
@ -258,10 +256,10 @@ func verifyIss(iss string, cmp string, required bool) bool {
return iss == cmp return iss == cmp
} }
func verifySub(sub string, cmp string, required bool) bool { func verifySub(sub string, cmp PatternFunc, required bool) bool {
if sub == "" { if sub == "" {
return !required return !required
} }
return sub == cmp return cmp(sub)
} }

View File

@ -1,6 +1,9 @@
package jwt package jwt
import "time" import (
"strings"
"time"
)
// ValidatorOption is used to implement functional-style options that modify the // ValidatorOption is used to implement functional-style options that modify the
// behavior of the validator. To add new options, just create a function // behavior of the validator. To add new options, just create a function
@ -9,6 +12,20 @@ import "time"
// accordingly. // accordingly.
type ValidatorOption func(*Validator) type ValidatorOption func(*Validator)
type PatternFunc func(s string) bool
func HasPrefix(prefix string) PatternFunc {
return func(s string) bool {
return strings.HasPrefix(s, prefix)
}
}
func Equals(cmp string) PatternFunc {
return func(s string) bool {
return cmp == s
}
}
// WithLeeway returns the ValidatorOption for specifying the leeway window. // WithLeeway returns the ValidatorOption for specifying the leeway window.
func WithLeeway(leeway time.Duration) ValidatorOption { func WithLeeway(leeway time.Duration) ValidatorOption {
return func(v *Validator) { return func(v *Validator) {
@ -68,6 +85,12 @@ func WithIssuer(iss string) ValidatorOption {
// writing secure application, we decided to REQUIRE the existence of the claim. // writing secure application, we decided to REQUIRE the existence of the claim.
func WithSubject(sub string) ValidatorOption { func WithSubject(sub string) ValidatorOption {
return func(v *Validator) { return func(v *Validator) {
v.expectedSub = sub v.expectedSubPattern = Equals(sub)
}
}
func WithSubjectPattern(pattern PatternFunc) ValidatorOption {
return func(v *Validator) {
v.expectedSubPattern = pattern
} }
} }

51
validator_test.go Normal file
View File

@ -0,0 +1,51 @@
package jwt
import (
"testing"
"time"
)
func TestValidator_Validate(t *testing.T) {
type fields struct {
leeway time.Duration
timeFunc func() time.Time
verifyIat bool
expectedAud string
expectedIss string
expectedSubPattern PatternFunc
}
type args struct {
claims Claims
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
name: "with subject pattern",
fields: fields{
expectedSubPattern: HasPrefix("My"),
},
args: args{
claims: RegisteredClaims{Subject: "MyUser"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &Validator{
leeway: tt.fields.leeway,
timeFunc: tt.fields.timeFunc,
verifyIat: tt.fields.verifyIat,
expectedAud: tt.fields.expectedAud,
expectedIss: tt.fields.expectedIss,
expectedSubPattern: tt.fields.expectedSubPattern,
}
if err := v.Validate(tt.args.claims); (err != nil) != tt.wantErr {
t.Errorf("Validator.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}