supporting strict() for all libraries

This commit is contained in:
Christian Banse 2023-09-16 13:31:27 +02:00 committed by Christian Banse
parent 3ae2a4a3c8
commit f64f4609f3
3 changed files with 33 additions and 14 deletions

View File

@ -9,10 +9,16 @@ type Base64Encoding interface {
DecodeString(s string) ([]byte, error)
}
type StrictFunc[T Base64Encoding] func() T
type Stricter[T Base64Encoding] interface {
Strict() T
}
func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding {
return x.Strict()
}
// JSONMarshalFunc is an function type that allows to implement custom JSON
// encoding algorithms.
type JSONMarshalFunc func(v any) ([]byte, error)

View File

@ -12,23 +12,24 @@ type Parser struct {
// If populated, only these methods will be considered valid.
validMethods []string
// Use JSON Number format in JSON decoder.
useJSONNumber bool
// Skip claims validation during token parsing.
skipClaimsValidation bool
validator *Validator
decoders
decoding
}
type decoders struct {
type decoding struct {
jsonUnmarshal JSONUnmarshalFunc
jsonNewDecoder JSONNewDecoderFunc[JSONDecoder]
rawUrlBase64Encoding Base64Encoding
urlBase64Encoding Base64Encoding
strict StrictFunc[Base64Encoding]
// Use JSON Number format in JSON decoder.
useJSONNumber bool
decodeStrict bool
decodePaddingAllowed bool
@ -246,13 +247,15 @@ func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
}
if p.decodeStrict {
// For now we can only support the standard library here because of the
// current state of the type parameter system
stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok {
return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported)
if p.strict != nil {
encoding = p.strict()
} else {
stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok {
return nil, newError("WithStrictDecoding() was enabled but supplied base64 encoder does not support strict mode", ErrUnsupported)
}
encoding = stricter.Strict()
}
encoding = stricter.Strict()
}
return encoding.DecodeString(seg)

View File

@ -152,7 +152,7 @@ func WithStrictDecoding() ParserOption {
// "github.com/bytedance/sonic"
// )
//
// var parser = NewParser(WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder))
// var parser = jwt.NewParser(jwt.WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder))
func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption {
return func(p *Parser) {
p.jsonUnmarshal = jsonUnmarshal
@ -184,10 +184,20 @@ func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDeco
// asmbase64 "github.com/segmentio/asm/base64"
// )
//
// var parser = NewParser(WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding))
func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption {
// var parser = jwt.NewParser(jwt.WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding))
func WithBase64Decoder[T Base64Encoding](rawURL Base64Encoding, url T) ParserOption {
return func(p *Parser) {
p.rawUrlBase64Encoding = rawURL
p.urlBase64Encoding = url
// Check, whether the library supports the Strict() function
stricter, ok := rawURL.(Stricter[T])
if ok {
// We need to get rid of the type parameter T, so we need to wrap it
// here
p.strict = func() Base64Encoding {
return stricter.Strict()
}
}
}
}