forked from mirror/jwt
Pattern matching on subject
This commit is contained in:
parent
5a65c47732
commit
e8c10437a0
12
validator.go
12
validator.go
|
@ -32,9 +32,7 @@ type Validator struct {
|
||||||
// string will disable iss checking.
|
// string will disable iss checking.
|
||||||
expectedIss string
|
expectedIss string
|
||||||
|
|
||||||
// expectedSub contains the subject this token expects. Supplying an empty
|
expectedSubPattern PatternFunc
|
||||||
// string will disable sub checking.
|
|
||||||
expectedSub string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CustomClaims represents a custom claims interface, which can be built upon the integrated
|
// CustomClaims represents a custom claims interface, which can be built upon the integrated
|
||||||
|
@ -99,7 +97,7 @@ func (v *Validator) Validate(claims Claims) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have an expected subject, we also require the subject claim
|
// If we have an expected subject, we also require the subject claim
|
||||||
if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) {
|
if v.expectedSubPattern != nil && !v.VerifySubject(claims, v.expectedSubPattern, true) {
|
||||||
vErr.Inner = ErrTokenInvalidSubject
|
vErr.Inner = ErrTokenInvalidSubject
|
||||||
vErr.Errors |= ValidationErrorSubject
|
vErr.Errors |= ValidationErrorSubject
|
||||||
}
|
}
|
||||||
|
@ -190,7 +188,7 @@ func (v *Validator) VerifyIssuer(claims Claims, cmp string, req bool) bool {
|
||||||
|
|
||||||
// VerifySubject compares the sub claim against cmp.
|
// VerifySubject compares the sub claim against cmp.
|
||||||
// If required is false, this method will return true if the value matches or is unset
|
// If required is false, this method will return true if the value matches or is unset
|
||||||
func (v *Validator) VerifySubject(claims Claims, cmp string, req bool) bool {
|
func (v *Validator) VerifySubject(claims Claims, cmp PatternFunc, req bool) bool {
|
||||||
iss, err := claims.GetSubject()
|
iss, err := claims.GetSubject()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
|
@ -258,10 +256,10 @@ func verifyIss(iss string, cmp string, required bool) bool {
|
||||||
return iss == cmp
|
return iss == cmp
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifySub(sub string, cmp string, required bool) bool {
|
func verifySub(sub string, cmp PatternFunc, required bool) bool {
|
||||||
if sub == "" {
|
if sub == "" {
|
||||||
return !required
|
return !required
|
||||||
}
|
}
|
||||||
|
|
||||||
return sub == cmp
|
return cmp(sub)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
package jwt
|
package jwt
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// ValidatorOption is used to implement functional-style options that modify the
|
// ValidatorOption is used to implement functional-style options that modify the
|
||||||
// behavior of the validator. To add new options, just create a function
|
// behavior of the validator. To add new options, just create a function
|
||||||
|
@ -9,6 +12,20 @@ import "time"
|
||||||
// accordingly.
|
// accordingly.
|
||||||
type ValidatorOption func(*Validator)
|
type ValidatorOption func(*Validator)
|
||||||
|
|
||||||
|
type PatternFunc func(s string) bool
|
||||||
|
|
||||||
|
func HasPrefix(prefix string) PatternFunc {
|
||||||
|
return func(s string) bool {
|
||||||
|
return strings.HasPrefix(s, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Equals(cmp string) PatternFunc {
|
||||||
|
return func(s string) bool {
|
||||||
|
return cmp == s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithLeeway returns the ValidatorOption for specifying the leeway window.
|
// WithLeeway returns the ValidatorOption for specifying the leeway window.
|
||||||
func WithLeeway(leeway time.Duration) ValidatorOption {
|
func WithLeeway(leeway time.Duration) ValidatorOption {
|
||||||
return func(v *Validator) {
|
return func(v *Validator) {
|
||||||
|
@ -68,6 +85,12 @@ func WithIssuer(iss string) ValidatorOption {
|
||||||
// writing secure application, we decided to REQUIRE the existence of the claim.
|
// writing secure application, we decided to REQUIRE the existence of the claim.
|
||||||
func WithSubject(sub string) ValidatorOption {
|
func WithSubject(sub string) ValidatorOption {
|
||||||
return func(v *Validator) {
|
return func(v *Validator) {
|
||||||
v.expectedSub = sub
|
v.expectedSubPattern = Equals(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithSubjectPattern(pattern PatternFunc) ValidatorOption {
|
||||||
|
return func(v *Validator) {
|
||||||
|
v.expectedSubPattern = pattern
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidator_Validate(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
leeway time.Duration
|
||||||
|
timeFunc func() time.Time
|
||||||
|
verifyIat bool
|
||||||
|
expectedAud string
|
||||||
|
expectedIss string
|
||||||
|
expectedSubPattern PatternFunc
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
claims Claims
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with subject pattern",
|
||||||
|
fields: fields{
|
||||||
|
expectedSubPattern: HasPrefix("My"),
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
claims: RegisteredClaims{Subject: "MyUser"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v := &Validator{
|
||||||
|
leeway: tt.fields.leeway,
|
||||||
|
timeFunc: tt.fields.timeFunc,
|
||||||
|
verifyIat: tt.fields.verifyIat,
|
||||||
|
expectedAud: tt.fields.expectedAud,
|
||||||
|
expectedIss: tt.fields.expectedIss,
|
||||||
|
expectedSubPattern: tt.fields.expectedSubPattern,
|
||||||
|
}
|
||||||
|
if err := v.Validate(tt.args.claims); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Validator.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue