Support almost all base64 options

This commit is contained in:
Christian Banse 2023-09-14 21:25:59 +02:00 committed by Christian Banse
parent f0fa303116
commit 7684d3e29a
8 changed files with 49 additions and 34 deletions

View File

@ -2,20 +2,21 @@ package jwt
import "io" import "io"
// Base64Encoder is an interface that allows to implement custom Base64 encoding type Base64Encoding interface {
// algorithms. EncodeToString(src []byte) string
type Base64EncodeFunc func(src []byte) string DecodeString(s string) ([]byte, error)
}
// Base64Decoder is an interface that allows to implement custom Base64 decoding type Stricter[T Base64Encoding] interface {
// algorithms. Strict() T
type Base64DecodeFunc func(s string) ([]byte, error) }
// JSONEncoder is an interface that allows to implement custom JSON encoding // JSONMarshalFunc is an function type that allows to implement custom JSON
// algorithms. // encoding algorithms.
type JSONMarshalFunc func(v any) ([]byte, error) type JSONMarshalFunc func(v any) ([]byte, error)
// JSONUnmarshal is an interface that allows to implement custom JSON unmarshal // JSONUnmarshalFunc is an function type that allows to implement custom JSON
// algorithms. // unmarshal algorithms.
type JSONUnmarshalFunc func(data []byte, v any) error type JSONUnmarshalFunc func(data []byte, v any) error
type JSONDecoder interface { type JSONDecoder interface {

View File

@ -22,6 +22,7 @@ var (
ErrTokenInvalidId = errors.New("token has invalid id") ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims") ErrTokenInvalidClaims = errors.New("token has invalid claims")
ErrInvalidType = errors.New("invalid type for claim") ErrInvalidType = errors.New("invalid type for claim")
ErrUnsupported = errors.New("operation is unsupported")
) )
// joinedError is an error type that works similar to what [errors.Join] // joinedError is an error type that works similar to what [errors.Join]

View File

@ -26,12 +26,11 @@ type Parser struct {
type decoders struct { type decoders struct {
jsonUnmarshal JSONUnmarshalFunc jsonUnmarshal JSONUnmarshalFunc
jsonNewDecoder JSONNewDecoderFunc[JSONDecoder] jsonNewDecoder JSONNewDecoderFunc[JSONDecoder]
base64Decode Base64DecodeFunc
// This field is disabled when using a custom base64 encoder. rawUrlBase64Encoding Base64Encoding
decodeStrict bool urlBase64Encoding Base64Encoding
// This field is disabled when using a custom base64 encoder. decodeStrict bool
decodePaddingAllowed bool decodePaddingAllowed bool
} }
@ -227,22 +226,35 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
// take into account whether the [Parser] is configured with additional options, // take into account whether the [Parser] is configured with additional options,
// such as [WithStrictDecoding] or [WithPaddingAllowed]. // such as [WithStrictDecoding] or [WithPaddingAllowed].
func (p *Parser) DecodeSegment(seg string) ([]byte, error) { func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
if p.base64Decode != nil { var encoding Base64Encoding
return p.base64Decode(seg) if p.rawUrlBase64Encoding != nil {
encoding = p.rawUrlBase64Encoding
} else {
encoding = base64.RawURLEncoding
} }
encoding := base64.RawURLEncoding
if p.decodePaddingAllowed { if p.decodePaddingAllowed {
if l := len(seg) % 4; l > 0 { if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l) seg += strings.Repeat("=", 4-l)
} }
encoding = base64.URLEncoding
if p.urlBase64Encoding != nil {
encoding = p.urlBase64Encoding
} else {
encoding = base64.URLEncoding
}
} }
if p.decodeStrict { if p.decodeStrict {
encoding = encoding.Strict() // For now we can only support the standard library here because of the
// current state of the type parameter system
stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok {
return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported)
}
encoding = stricter.Strict()
} }
return encoding.DecodeString(seg) return encoding.DecodeString(seg)
} }

View File

@ -142,9 +142,10 @@ func WithJSONDecoder[T JSONDecoder](f JSONUnmarshalFunc, f2 JSONNewDecoderFunc[T
} }
} }
// WithBase64Decoder supports a custom [Base64Decoder] to use in parsing the JWT. // WithBase64Decoder supports a custom [Base64Encoding] to use in parsing the JWT.
func WithBase64Decoder(f Base64DecodeFunc) ParserOption { func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption {
return func(p *Parser) { return func(p *Parser) {
p.base64Decode = f p.rawUrlBase64Encoding = rawURL
p.urlBase64Encoding = url
} }
} }

View File

@ -454,7 +454,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
true, true,
nil, nil,
jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding.DecodeString)), jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding, base64.URLEncoding)),
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
{ {

View File

@ -39,8 +39,8 @@ type Token struct {
} }
type encoders struct { type encoders struct {
jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder
base64Encode Base64EncodeFunc // base64Encoder is the custom base64 encoder/decoder base64Encoding Base64Encoding // base64Encoder is the custom base64 encoding
} }
// New creates a new [Token] with the specified signing method and an empty map // New creates a new [Token] with the specified signing method and an empty map
@ -114,12 +114,12 @@ func (t *Token) SigningString() (string, error) {
// [TokenOption]. Therefore, this function exists as a method of [Token], rather // [TokenOption]. Therefore, this function exists as a method of [Token], rather
// than a global function. // than a global function.
func (t *Token) EncodeSegment(seg []byte) string { func (t *Token) EncodeSegment(seg []byte) string {
var enc Base64EncodeFunc var enc Base64Encoding
if t.base64Encode != nil { if t.base64Encoding != nil {
enc = t.base64Encode enc = t.base64Encoding
} else { } else {
enc = base64.RawURLEncoding.EncodeToString enc = base64.RawURLEncoding
} }
return enc(seg) return enc.EncodeToString(seg)
} }

View File

@ -10,8 +10,8 @@ func WithJSONEncoder(f JSONMarshalFunc) TokenOption {
} }
} }
func WithBase64Encoder(f Base64EncodeFunc) TokenOption { func WithBase64Encoder(enc Base64Encoding) TokenOption {
return func(token *Token) { return func(token *Token) {
token.base64Encode = f token.base64Encoding = enc
} }
} }

View File

@ -53,7 +53,7 @@ func TestToken_SigningString(t1 *testing.T) {
Valid: false, Valid: false,
Options: []jwt.TokenOption{ Options: []jwt.TokenOption{
jwt.WithJSONEncoder(json.Marshal), jwt.WithJSONEncoder(json.Marshal),
jwt.WithBase64Encoder(base64.StdEncoding.EncodeToString), jwt.WithBase64Encoder(base64.StdEncoding),
}, },
}, },
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",