diff --git a/encoder.go b/encoder.go index 1a98906..8b13411 100644 --- a/encoder.go +++ b/encoder.go @@ -9,10 +9,16 @@ type Base64Encoding interface { DecodeString(s string) ([]byte, error) } +type StrictFunc[T Base64Encoding] func() T + type Stricter[T Base64Encoding] interface { Strict() T } +func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding { + return x.Strict() +} + // JSONMarshalFunc is an function type that allows to implement custom JSON // encoding algorithms. type JSONMarshalFunc func(v any) ([]byte, error) diff --git a/parser.go b/parser.go index 65570b4..5b774e1 100644 --- a/parser.go +++ b/parser.go @@ -12,23 +12,24 @@ type Parser struct { // If populated, only these methods will be considered valid. validMethods []string - // Use JSON Number format in JSON decoder. - useJSONNumber bool - // Skip claims validation during token parsing. skipClaimsValidation bool validator *Validator - decoders + decoding } -type decoders struct { +type decoding struct { jsonUnmarshal JSONUnmarshalFunc jsonNewDecoder JSONNewDecoderFunc[JSONDecoder] rawUrlBase64Encoding Base64Encoding urlBase64Encoding Base64Encoding + strict StrictFunc[Base64Encoding] + + // Use JSON Number format in JSON decoder. + useJSONNumber bool decodeStrict bool decodePaddingAllowed bool @@ -246,13 +247,15 @@ func (p *Parser) DecodeSegment(seg string) ([]byte, error) { } if p.decodeStrict { - // For now we can only support the standard library here because of the - // current state of the type parameter system - stricter, ok := encoding.(Stricter[*base64.Encoding]) - if !ok { - return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported) + if p.strict != nil { + encoding = p.strict() + } else { + stricter, ok := encoding.(Stricter[*base64.Encoding]) + if !ok { + return nil, newError("WithStrictDecoding() was enabled but supplied base64 encoder does not support strict mode", ErrUnsupported) + } + encoding = stricter.Strict() } - encoding = stricter.Strict() } return encoding.DecodeString(seg) diff --git a/parser_option.go b/parser_option.go index 9513387..5cf5cd9 100644 --- a/parser_option.go +++ b/parser_option.go @@ -152,7 +152,7 @@ func WithStrictDecoding() ParserOption { // "github.com/bytedance/sonic" // ) // -// var parser = NewParser(WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder)) +// var parser = jwt.NewParser(jwt.WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder)) func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption { return func(p *Parser) { p.jsonUnmarshal = jsonUnmarshal @@ -184,10 +184,20 @@ func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDeco // asmbase64 "github.com/segmentio/asm/base64" // ) // -// var parser = NewParser(WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding)) -func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption { +// var parser = jwt.NewParser(jwt.WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding)) +func WithBase64Decoder[T Base64Encoding](rawURL Base64Encoding, url T) ParserOption { return func(p *Parser) { p.rawUrlBase64Encoding = rawURL p.urlBase64Encoding = url + + // Check, whether the library supports the Strict() function + stricter, ok := rawURL.(Stricter[T]) + if ok { + // We need to get rid of the type parameter T, so we need to wrap it + // here + p.strict = func() Base64Encoding { + return stricter.Strict() + } + } } }