diff --git a/encoder.go b/encoder.go index 375cb69..5f21e80 100644 --- a/encoder.go +++ b/encoder.go @@ -2,20 +2,21 @@ package jwt import "io" -// Base64Encoder is an interface that allows to implement custom Base64 encoding -// algorithms. -type Base64EncodeFunc func(src []byte) string +type Base64Encoding interface { + EncodeToString(src []byte) string + DecodeString(s string) ([]byte, error) +} -// Base64Decoder is an interface that allows to implement custom Base64 decoding -// algorithms. -type Base64DecodeFunc func(s string) ([]byte, error) +type Stricter[T Base64Encoding] interface { + Strict() T +} -// JSONEncoder is an interface that allows to implement custom JSON encoding -// algorithms. +// JSONMarshalFunc is an function type that allows to implement custom JSON +// encoding algorithms. type JSONMarshalFunc func(v any) ([]byte, error) -// JSONUnmarshal is an interface that allows to implement custom JSON unmarshal -// algorithms. +// JSONUnmarshalFunc is an function type that allows to implement custom JSON +// unmarshal algorithms. type JSONUnmarshalFunc func(data []byte, v any) error type JSONDecoder interface { diff --git a/errors.go b/errors.go index 23bb616..a8fe9be 100644 --- a/errors.go +++ b/errors.go @@ -22,6 +22,7 @@ var ( ErrTokenInvalidId = errors.New("token has invalid id") ErrTokenInvalidClaims = errors.New("token has invalid claims") ErrInvalidType = errors.New("invalid type for claim") + ErrUnsupported = errors.New("operation is unsupported") ) // joinedError is an error type that works similar to what [errors.Join] diff --git a/parser.go b/parser.go index 57fbe70..65570b4 100644 --- a/parser.go +++ b/parser.go @@ -26,12 +26,11 @@ type Parser struct { type decoders struct { jsonUnmarshal JSONUnmarshalFunc jsonNewDecoder JSONNewDecoderFunc[JSONDecoder] - base64Decode Base64DecodeFunc - // This field is disabled when using a custom base64 encoder. - decodeStrict bool + rawUrlBase64Encoding Base64Encoding + urlBase64Encoding Base64Encoding - // This field is disabled when using a custom base64 encoder. + decodeStrict bool decodePaddingAllowed bool } @@ -227,22 +226,35 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // take into account whether the [Parser] is configured with additional options, // such as [WithStrictDecoding] or [WithPaddingAllowed]. func (p *Parser) DecodeSegment(seg string) ([]byte, error) { - if p.base64Decode != nil { - return p.base64Decode(seg) + var encoding Base64Encoding + if p.rawUrlBase64Encoding != nil { + encoding = p.rawUrlBase64Encoding + } else { + encoding = base64.RawURLEncoding } - encoding := base64.RawURLEncoding - if p.decodePaddingAllowed { if l := len(seg) % 4; l > 0 { seg += strings.Repeat("=", 4-l) } - encoding = base64.URLEncoding + + if p.urlBase64Encoding != nil { + encoding = p.urlBase64Encoding + } else { + encoding = base64.URLEncoding + } } if p.decodeStrict { - encoding = encoding.Strict() + // For now we can only support the standard library here because of the + // current state of the type parameter system + stricter, ok := encoding.(Stricter[*base64.Encoding]) + if !ok { + return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported) + } + encoding = stricter.Strict() } + return encoding.DecodeString(seg) } diff --git a/parser_option.go b/parser_option.go index 9e30b43..8d701a0 100644 --- a/parser_option.go +++ b/parser_option.go @@ -142,9 +142,10 @@ func WithJSONDecoder[T JSONDecoder](f JSONUnmarshalFunc, f2 JSONNewDecoderFunc[T } } -// WithBase64Decoder supports a custom [Base64Decoder] to use in parsing the JWT. -func WithBase64Decoder(f Base64DecodeFunc) ParserOption { +// WithBase64Decoder supports a custom [Base64Encoding] to use in parsing the JWT. +func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption { return func(p *Parser) { - p.base64Decode = f + p.rawUrlBase64Encoding = rawURL + p.urlBase64Encoding = url } } diff --git a/parser_test.go b/parser_test.go index 0e7b32f..3319c97 100644 --- a/parser_test.go +++ b/parser_test.go @@ -454,7 +454,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, true, nil, - jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding.DecodeString)), + jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding, base64.URLEncoding)), jwt.SigningMethodRS256, }, { diff --git a/token.go b/token.go index 585770f..93b87a3 100644 --- a/token.go +++ b/token.go @@ -39,8 +39,8 @@ type Token struct { } type encoders struct { - jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder - base64Encode Base64EncodeFunc // base64Encoder is the custom base64 encoder/decoder + jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder + base64Encoding Base64Encoding // base64Encoder is the custom base64 encoding } // New creates a new [Token] with the specified signing method and an empty map @@ -114,12 +114,12 @@ func (t *Token) SigningString() (string, error) { // [TokenOption]. Therefore, this function exists as a method of [Token], rather // than a global function. func (t *Token) EncodeSegment(seg []byte) string { - var enc Base64EncodeFunc - if t.base64Encode != nil { - enc = t.base64Encode + var enc Base64Encoding + if t.base64Encoding != nil { + enc = t.base64Encoding } else { - enc = base64.RawURLEncoding.EncodeToString + enc = base64.RawURLEncoding } - return enc(seg) + return enc.EncodeToString(seg) } diff --git a/token_option.go b/token_option.go index 3a9ca8d..0fab6a3 100644 --- a/token_option.go +++ b/token_option.go @@ -10,8 +10,8 @@ func WithJSONEncoder(f JSONMarshalFunc) TokenOption { } } -func WithBase64Encoder(f Base64EncodeFunc) TokenOption { +func WithBase64Encoder(enc Base64Encoding) TokenOption { return func(token *Token) { - token.base64Encode = f + token.base64Encoding = enc } } diff --git a/token_test.go b/token_test.go index 7c76fad..d572339 100644 --- a/token_test.go +++ b/token_test.go @@ -53,7 +53,7 @@ func TestToken_SigningString(t1 *testing.T) { Valid: false, Options: []jwt.TokenOption{ jwt.WithJSONEncoder(json.Marshal), - jwt.WithBase64Encoder(base64.StdEncoding.EncodeToString), + jwt.WithBase64Encoder(base64.StdEncoding), }, }, want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",