supporting strict() for all libraries

This commit is contained in:
Christian Banse 2023-09-16 13:31:27 +02:00 committed by Christian Banse
parent 3ae2a4a3c8
commit f64f4609f3
3 changed files with 33 additions and 14 deletions

View File

@ -9,10 +9,16 @@ type Base64Encoding interface {
DecodeString(s string) ([]byte, error) DecodeString(s string) ([]byte, error)
} }
type StrictFunc[T Base64Encoding] func() T
type Stricter[T Base64Encoding] interface { type Stricter[T Base64Encoding] interface {
Strict() T Strict() T
} }
func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding {
return x.Strict()
}
// JSONMarshalFunc is an function type that allows to implement custom JSON // JSONMarshalFunc is an function type that allows to implement custom JSON
// encoding algorithms. // encoding algorithms.
type JSONMarshalFunc func(v any) ([]byte, error) type JSONMarshalFunc func(v any) ([]byte, error)

View File

@ -12,23 +12,24 @@ type Parser struct {
// If populated, only these methods will be considered valid. // If populated, only these methods will be considered valid.
validMethods []string validMethods []string
// Use JSON Number format in JSON decoder.
useJSONNumber bool
// Skip claims validation during token parsing. // Skip claims validation during token parsing.
skipClaimsValidation bool skipClaimsValidation bool
validator *Validator validator *Validator
decoders decoding
} }
type decoders struct { type decoding struct {
jsonUnmarshal JSONUnmarshalFunc jsonUnmarshal JSONUnmarshalFunc
jsonNewDecoder JSONNewDecoderFunc[JSONDecoder] jsonNewDecoder JSONNewDecoderFunc[JSONDecoder]
rawUrlBase64Encoding Base64Encoding rawUrlBase64Encoding Base64Encoding
urlBase64Encoding Base64Encoding urlBase64Encoding Base64Encoding
strict StrictFunc[Base64Encoding]
// Use JSON Number format in JSON decoder.
useJSONNumber bool
decodeStrict bool decodeStrict bool
decodePaddingAllowed bool decodePaddingAllowed bool
@ -246,14 +247,16 @@ func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
} }
if p.decodeStrict { if p.decodeStrict {
// For now we can only support the standard library here because of the if p.strict != nil {
// current state of the type parameter system encoding = p.strict()
} else {
stricter, ok := encoding.(Stricter[*base64.Encoding]) stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok { if !ok {
return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported) return nil, newError("WithStrictDecoding() was enabled but supplied base64 encoder does not support strict mode", ErrUnsupported)
} }
encoding = stricter.Strict() encoding = stricter.Strict()
} }
}
return encoding.DecodeString(seg) return encoding.DecodeString(seg)
} }

View File

@ -152,7 +152,7 @@ func WithStrictDecoding() ParserOption {
// "github.com/bytedance/sonic" // "github.com/bytedance/sonic"
// ) // )
// //
// var parser = NewParser(WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder)) // var parser = jwt.NewParser(jwt.WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder))
func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption { func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption {
return func(p *Parser) { return func(p *Parser) {
p.jsonUnmarshal = jsonUnmarshal p.jsonUnmarshal = jsonUnmarshal
@ -184,10 +184,20 @@ func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDeco
// asmbase64 "github.com/segmentio/asm/base64" // asmbase64 "github.com/segmentio/asm/base64"
// ) // )
// //
// var parser = NewParser(WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding)) // var parser = jwt.NewParser(jwt.WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding))
func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption { func WithBase64Decoder[T Base64Encoding](rawURL Base64Encoding, url T) ParserOption {
return func(p *Parser) { return func(p *Parser) {
p.rawUrlBase64Encoding = rawURL p.rawUrlBase64Encoding = rawURL
p.urlBase64Encoding = url p.urlBase64Encoding = url
// Check, whether the library supports the Strict() function
stricter, ok := rawURL.(Stricter[T])
if ok {
// We need to get rid of the type parameter T, so we need to wrap it
// here
p.strict = func() Base64Encoding {
return stricter.Strict()
}
}
} }
} }