From b357385d3ee5c53c16db81addf9c6c8059c4cf15 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Fri, 24 Mar 2023 19:13:09 +0100 Subject: [PATCH] Moving `DecodeSegement` to `Parser` (#278) * Moving `DecodeSegement` to `Parser` This would allow us to remove some global variables and move them to parser options as well as potentially introduce interfaces for json and b64 encoding/decoding to replace the std lib, if someone wanted to do that for performance reasons. We keep the functions exported because of explicit user demand. * Sign/Verify does take the decoded form now --- ecdsa.go | 22 ++++------- ecdsa_test.go | 26 ++++++++++--- ed25519.go | 25 +++++-------- ed25519_test.go | 8 ++-- hmac.go | 16 +++----- hmac_test.go | 5 ++- none.go | 11 +++--- none_test.go | 5 ++- parser.go | 57 +++++++++++++++++++++++++++-- parser_option.go | 26 +++++++++++++ parser_test.go | 20 ++++++---- rsa.go | 20 +++------- rsa_pss.go | 20 +++------- rsa_pss_test.go | 22 ++++++----- rsa_test.go | 9 +++-- signing_method.go | 11 ++++-- token.go | 93 ++++++++--------------------------------------- token_option.go | 5 +++ token_test.go | 7 ++-- 19 files changed, 212 insertions(+), 196 deletions(-) create mode 100644 token_option.go diff --git a/ecdsa.go b/ecdsa.go index eac023f..4ccae2a 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -55,15 +55,7 @@ func (m *SigningMethodECDSA) Alg() string { // Verify implements token verification for the SigningMethod. // For this verify method, key must be an ecdsa.PublicKey struct -func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error { - var err error - - // Decode the signature - var sig []byte - if sig, err = DecodeSegment(signature); err != nil { - return err - } - +func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interface{}) error { // Get the key var ecdsaKey *ecdsa.PublicKey switch k := key.(type) { @@ -97,19 +89,19 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa // Sign implements token signing for the SigningMethod. // For this signing method, key must be an ecdsa.PrivateKey struct -func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) { +func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte, error) { // Get the key var ecdsaKey *ecdsa.PrivateKey switch k := key.(type) { case *ecdsa.PrivateKey: ecdsaKey = k default: - return "", ErrInvalidKeyType + return nil, ErrInvalidKeyType } // Create the hasher if !m.Hash.Available() { - return "", ErrHashUnavailable + return nil, ErrHashUnavailable } hasher := m.Hash.New() @@ -120,7 +112,7 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string curveBits := ecdsaKey.Curve.Params().BitSize if m.CurveBits != curveBits { - return "", ErrInvalidKey + return nil, ErrInvalidKey } keyBytes := curveBits / 8 @@ -135,8 +127,8 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string r.FillBytes(out[0:keyBytes]) // r is assigned to the first half of output. s.FillBytes(out[keyBytes:]) // s is assigned to the second half of output. - return EncodeSegment(out), nil + return out, nil } else { - return "", err + return nil, err } } diff --git a/ecdsa_test.go b/ecdsa_test.go index 7c6d482..3caf0a8 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -3,6 +3,7 @@ package jwt_test import ( "crypto/ecdsa" "os" + "reflect" "strings" "testing" @@ -65,7 +66,7 @@ func TestECDSAVerify(t *testing.T) { parts := strings.Split(data.tokenString, ".") method := jwt.GetSigningMethod(data.alg) - err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ecdsaKey) + err = method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), ecdsaKey) if data.valid && err != nil { t.Errorf("[%v] Error while verifying key: %v", data.name, err) } @@ -90,12 +91,13 @@ func TestECDSASign(t *testing.T) { toSign := strings.Join(parts[0:2], ".") method := jwt.GetSigningMethod(data.alg) sig, err := method.Sign(toSign, ecdsaKey) - if err != nil { t.Errorf("[%v] Error signing token: %v", data.name, err) } - if sig == parts[2] { - t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig) + + ssig := encodeSegment(sig) + if ssig == parts[2] { + t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig) } err = method.Verify(toSign, sig, ecdsaKey.Public()) @@ -155,10 +157,24 @@ func BenchmarkECDSASigning(b *testing.B) { if err != nil { b.Fatalf("[%v] Error signing token: %v", data.name, err) } - if sig == parts[2] { + if reflect.DeepEqual(sig, decodeSegment(b, parts[2])) { b.Fatalf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig) } } }) } } + +func decodeSegment(t interface{ Fatalf(string, ...any) }, signature string) (sig []byte) { + var err error + sig, err = jwt.NewParser().DecodeSegment(signature) + if err != nil { + t.Fatalf("could not decode segment: %v", err) + } + + return +} + +func encodeSegment(sig []byte) string { + return (&jwt.Token{}).EncodeSegment(sig) +} diff --git a/ed25519.go b/ed25519.go index 07d3aac..3db00e4 100644 --- a/ed25519.go +++ b/ed25519.go @@ -34,8 +34,7 @@ func (m *SigningMethodEd25519) Alg() string { // Verify implements token verification for the SigningMethod. // For this verify method, key must be an ed25519.PublicKey -func (m *SigningMethodEd25519) Verify(signingString, signature string, key interface{}) error { - var err error +func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key interface{}) error { var ed25519Key ed25519.PublicKey var ok bool @@ -47,12 +46,6 @@ func (m *SigningMethodEd25519) Verify(signingString, signature string, key inter return ErrInvalidKey } - // Decode the signature - var sig []byte - if sig, err = DecodeSegment(signature); err != nil { - return err - } - // Verify the signature if !ed25519.Verify(ed25519Key, []byte(signingString), sig) { return ErrEd25519Verification @@ -63,23 +56,25 @@ func (m *SigningMethodEd25519) Verify(signingString, signature string, key inter // Sign implements token signing for the SigningMethod. // For this signing method, key must be an ed25519.PrivateKey -func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) (string, error) { +func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]byte, error) { var ed25519Key crypto.Signer var ok bool if ed25519Key, ok = key.(crypto.Signer); !ok { - return "", ErrInvalidKeyType + return nil, ErrInvalidKeyType } if _, ok := ed25519Key.Public().(ed25519.PublicKey); !ok { - return "", ErrInvalidKey + return nil, ErrInvalidKey } - // Sign the string and return the encoded result - // ed25519 performs a two-pass hash as part of its algorithm. Therefore, we need to pass a non-prehashed message into the Sign function, as indicated by crypto.Hash(0) + // Sign the string and return the result. ed25519 performs a two-pass hash + // as part of its algorithm. Therefore, we need to pass a non-prehashed + // message into the Sign function, as indicated by crypto.Hash(0) sig, err := ed25519Key.Sign(rand.Reader, []byte(signingString), crypto.Hash(0)) if err != nil { - return "", err + return nil, err } - return EncodeSegment(sig), nil + + return sig, nil } diff --git a/ed25519_test.go b/ed25519_test.go index cd05818..e9c7432 100644 --- a/ed25519_test.go +++ b/ed25519_test.go @@ -49,7 +49,7 @@ func TestEd25519Verify(t *testing.T) { method := jwt.GetSigningMethod(data.alg) - err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ed25519Key) + err = method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), ed25519Key) if data.valid && err != nil { t.Errorf("[%v] Error while verifying key: %v", data.name, err) } @@ -77,8 +77,10 @@ func TestEd25519Sign(t *testing.T) { if err != nil { t.Errorf("[%v] Error signing token: %v", data.name, err) } - if sig == parts[2] && !data.valid { - t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig) + + ssig := encodeSegment(sig) + if ssig == parts[2] && !data.valid { + t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig) } } } diff --git a/hmac.go b/hmac.go index 011f68a..8609f4a 100644 --- a/hmac.go +++ b/hmac.go @@ -46,19 +46,13 @@ func (m *SigningMethodHMAC) Alg() string { } // Verify implements token verification for the SigningMethod. Returns nil if the signature is valid. -func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error { +func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interface{}) error { // Verify the key is the right type keyBytes, ok := key.([]byte) if !ok { return ErrInvalidKeyType } - // Decode signature, for comparison - sig, err := DecodeSegment(signature) - if err != nil { - return err - } - // Can we use the specified hashing method? if !m.Hash.Available() { return ErrHashUnavailable @@ -79,17 +73,17 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac // Sign implements token signing for the SigningMethod. // Key must be []byte -func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) { +func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, error) { if keyBytes, ok := key.([]byte); ok { if !m.Hash.Available() { - return "", ErrHashUnavailable + return nil, ErrHashUnavailable } hasher := hmac.New(m.Hash.New, keyBytes) hasher.Write([]byte(signingString)) - return EncodeSegment(hasher.Sum(nil)), nil + return hasher.Sum(nil), nil } - return "", ErrInvalidKeyType + return nil, ErrInvalidKeyType } diff --git a/hmac_test.go b/hmac_test.go index 83d2c3e..264a2a4 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -2,6 +2,7 @@ package jwt_test import ( "os" + "reflect" "strings" "testing" @@ -53,7 +54,7 @@ func TestHMACVerify(t *testing.T) { parts := strings.Split(data.tokenString, ".") method := jwt.GetSigningMethod(data.alg) - err := method.Verify(strings.Join(parts[0:2], "."), parts[2], hmacTestKey) + err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), hmacTestKey) if data.valid && err != nil { t.Errorf("[%v] Error while verifying key: %v", data.name, err) } @@ -72,7 +73,7 @@ func TestHMACSign(t *testing.T) { if err != nil { t.Errorf("[%v] Error signing token: %v", data.name, err) } - if sig != parts[2] { + if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) } } diff --git a/none.go b/none.go index a16495a..c93daa5 100644 --- a/none.go +++ b/none.go @@ -25,14 +25,14 @@ func (m *signingMethodNone) Alg() string { } // Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key -func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) { +func (m *signingMethodNone) Verify(signingString string, sig []byte, key interface{}) (err error) { // Key must be UnsafeAllowNoneSignatureType to prevent accidentally // accepting 'none' signing method if _, ok := key.(unsafeNoneMagicConstant); !ok { return NoneSignatureTypeDisallowedError } // If signing method is none, signature must be an empty string - if signature != "" { + if string(sig) != "" { return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable) } @@ -41,9 +41,10 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac } // Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key -func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) { +func (m *signingMethodNone) Sign(signingString string, key interface{}) ([]byte, error) { if _, ok := key.(unsafeNoneMagicConstant); ok { - return "", nil + return []byte{}, nil } - return "", NoneSignatureTypeDisallowedError + + return nil, NoneSignatureTypeDisallowedError } diff --git a/none_test.go b/none_test.go index 35ff13a..d370cf8 100644 --- a/none_test.go +++ b/none_test.go @@ -1,6 +1,7 @@ package jwt_test import ( + "reflect" "strings" "testing" @@ -46,7 +47,7 @@ func TestNoneVerify(t *testing.T) { parts := strings.Split(data.tokenString, ".") method := jwt.GetSigningMethod(data.alg) - err := method.Verify(strings.Join(parts[0:2], "."), parts[2], data.key) + err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), data.key) if data.valid && err != nil { t.Errorf("[%v] Error while verifying key: %v", data.name, err) } @@ -65,7 +66,7 @@ func TestNoneSign(t *testing.T) { if err != nil { t.Errorf("[%v] Error signing token: %v", data.name, err) } - if sig != parts[2] { + if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) } } diff --git a/parser.go b/parser.go index 46b6793..f4386fb 100644 --- a/parser.go +++ b/parser.go @@ -2,6 +2,7 @@ package jwt import ( "bytes" + "encoding/base64" "encoding/json" "fmt" "strings" @@ -18,6 +19,10 @@ type Parser struct { skipClaimsValidation bool validator *validator + + decodeStrict bool + + decodePaddingAllowed bool } // NewParser creates a new Parser with the specified options @@ -79,8 +84,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf return token, newError("error while executing keyfunc", ErrTokenUnverifiable, err) } + // Decode signature + token.Signature, err = p.DecodeSegment(parts[2]) + if err != nil { + return token, newError("could not base64 decode signature", ErrTokenMalformed, err) + } + // Perform signature validation - token.Signature = parts[2] if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { return token, newError("", ErrTokenSignatureInvalid, err) } @@ -119,7 +129,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // parse Header var headerBytes []byte - if headerBytes, err = DecodeSegment(parts[0]); err != nil { + if headerBytes, err = p.DecodeSegment(parts[0]); err != nil { if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { return token, parts, newError("tokenstring should not contain 'bearer '", ErrTokenMalformed) } @@ -133,7 +143,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke var claimBytes []byte token.Claims = claims - if claimBytes, err = DecodeSegment(parts[1]); err != nil { + if claimBytes, err = p.DecodeSegment(parts[1]); err != nil { return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err) } dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) @@ -162,3 +172,44 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke return token, parts, nil } + +// DecodeSegment decodes a JWT specific base64url encoding. This function will +// take into account whether the [Parser] is configured with additional options, +// such as [WithStrictDecoding] or [WithPaddingAllowed]. +func (p *Parser) DecodeSegment(seg string) ([]byte, error) { + encoding := base64.RawURLEncoding + + if p.decodePaddingAllowed { + if l := len(seg) % 4; l > 0 { + seg += strings.Repeat("=", 4-l) + } + encoding = base64.URLEncoding + } + + if p.decodeStrict { + encoding = encoding.Strict() + } + return encoding.DecodeString(seg) +} + +// Parse parses, validates, verifies the signature and returns the parsed token. +// keyFunc will receive the parsed token and should return the cryptographic key +// for verifying the signature. The caller is strongly encouraged to set the +// WithValidMethods option to validate the 'alg' claim in the token matches the +// expected algorithm. For more details about the importance of validating the +// 'alg' claim, see +// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ +func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) { + return NewParser(options...).Parse(tokenString, keyFunc) +} + +// ParseWithClaims is a shortcut for NewParser().ParseWithClaims(). +// +// Note: If you provide a custom claim implementation that embeds one of the +// standard claims (such as RegisteredClaims), make sure that a) you either +// embed a non-pointer version of the claims or b) if you are using a pointer, +// allocate the proper memory for it before passing in the overall claims, +// otherwise you might run into a panic. +func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) { + return NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc) +} diff --git a/parser_option.go b/parser_option.go index 8d5917e..3ad17bc 100644 --- a/parser_option.go +++ b/parser_option.go @@ -99,3 +99,29 @@ func WithSubject(sub string) ParserOption { p.validator.expectedSub = sub } } + +// WithPaddingAllowed will enable the codec used for decoding JWTs to allow +// padding. Note that the JWS RFC7515 states that the tokens will utilize a +// Base64url encoding with no padding. Unfortunately, some implementations of +// JWT are producing non-standard tokens, and thus require support for decoding. +// Note that this is a global variable, and updating it will change the behavior +// on a package level, and is also NOT go-routine safe. To use the +// non-recommended decoding, set this boolean to `true` prior to using this +// package. +func WithPaddingAllowed() ParserOption { + return func(p *Parser) { + p.decodePaddingAllowed = true + } +} + +// WithStrictDecoding will switch the codec used for decoding JWTs into strict +// mode. In this mode, the decoder requires that trailing padding bits are zero, +// as described in RFC 4648 section 3.5. Note that this is a global variable, +// and updating it will change the behavior on a package level, and is also NOT +// go-routine safe. To use strict decoding, set this boolean to `true` prior to +// using this package. +func WithStrictDecoding() ParserOption { + return func(p *Parser) { + p.decodeStrict = true + } +} diff --git a/parser_test.go b/parser_test.go index fdb5eef..5b912b1 100644 --- a/parser_test.go +++ b/parser_test.go @@ -415,7 +415,7 @@ func TestParser_Parse(t *testing.T) { } if data.valid { - if token.Signature == "" { + if len(token.Signature) == 0 { t.Errorf("[%v] Signature is left unpopulated after parsing", data.name) } if !token.Valid { @@ -473,7 +473,7 @@ func TestParser_ParseUnverified(t *testing.T) { // The 'Valid' field should not be set to true when invoking ParseUnverified() t.Errorf("[%v] Token.Valid field mismatch. Expecting false, got %v", data.name, token.Valid) } - if token.Signature != "" { + if len(token.Signature) != 0 { // The signature was not validated, hence the 'Signature' field is not populated. t.Errorf("[%v] Token.Signature field mismatch. Expecting '', got %v", data.name, token.Signature) } @@ -641,9 +641,6 @@ var setPaddingTestData = []struct { func TestSetPadding(t *testing.T) { for _, data := range setPaddingTestData { t.Run(data.name, func(t *testing.T) { - jwt.DecodePaddingAllowed = data.paddedDecode - jwt.DecodeStrict = data.strictDecode - // If the token string is blank, use helper function to generate string if data.tokenString == "" { data.tokenString = signToken(data.claims, data.signingMethod) @@ -652,7 +649,16 @@ func TestSetPadding(t *testing.T) { // Parse the token var token *jwt.Token var err error - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + var opts []jwt.ParserOption = []jwt.ParserOption{jwt.WithoutClaimsValidation()} + + if data.paddedDecode { + opts = append(opts, jwt.WithPaddingAllowed()) + } + if data.strictDecode { + opts = append(opts, jwt.WithStrictDecoding()) + } + + parser := jwt.NewParser(opts...) // Figure out correct claims type token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc) @@ -666,8 +672,6 @@ func TestSetPadding(t *testing.T) { } }) - jwt.DecodePaddingAllowed = false - jwt.DecodeStrict = false } } diff --git a/rsa.go b/rsa.go index b910b19..daff094 100644 --- a/rsa.go +++ b/rsa.go @@ -46,15 +46,7 @@ func (m *SigningMethodRSA) Alg() string { // Verify implements token verification for the SigningMethod // For this signing method, must be an *rsa.PublicKey structure. -func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error { - var err error - - // Decode the signature - var sig []byte - if sig, err = DecodeSegment(signature); err != nil { - return err - } - +func (m *SigningMethodRSA) Verify(signingString string, sig []byte, key interface{}) error { var rsaKey *rsa.PublicKey var ok bool @@ -75,18 +67,18 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface // Sign implements token signing for the SigningMethod // For this signing method, must be an *rsa.PrivateKey structure. -func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) { +func (m *SigningMethodRSA) Sign(signingString string, key interface{}) ([]byte, error) { var rsaKey *rsa.PrivateKey var ok bool // Validate type of key if rsaKey, ok = key.(*rsa.PrivateKey); !ok { - return "", ErrInvalidKey + return nil, ErrInvalidKey } // Create the hasher if !m.Hash.Available() { - return "", ErrHashUnavailable + return nil, ErrHashUnavailable } hasher := m.Hash.New() @@ -94,8 +86,8 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, // Sign the string and return the encoded bytes if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil { - return EncodeSegment(sigBytes), nil + return sigBytes, nil } else { - return "", err + return nil, err } } diff --git a/rsa_pss.go b/rsa_pss.go index 4fd6f9e..9599f0a 100644 --- a/rsa_pss.go +++ b/rsa_pss.go @@ -82,15 +82,7 @@ func init() { // Verify implements token verification for the SigningMethod. // For this verify method, key must be an rsa.PublicKey struct -func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error { - var err error - - // Decode the signature - var sig []byte - if sig, err = DecodeSegment(signature); err != nil { - return err - } - +func (m *SigningMethodRSAPSS) Verify(signingString string, sig []byte, key interface{}) error { var rsaKey *rsa.PublicKey switch k := key.(type) { case *rsa.PublicKey: @@ -116,19 +108,19 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf // Sign implements token signing for the SigningMethod. // For this signing method, key must be an rsa.PrivateKey struct -func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) (string, error) { +func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) ([]byte, error) { var rsaKey *rsa.PrivateKey switch k := key.(type) { case *rsa.PrivateKey: rsaKey = k default: - return "", ErrInvalidKeyType + return nil, ErrInvalidKeyType } // Create the hasher if !m.Hash.Available() { - return "", ErrHashUnavailable + return nil, ErrHashUnavailable } hasher := m.Hash.New() @@ -136,8 +128,8 @@ func (m *SigningMethodRSAPSS) Sign(signingString string, key interface{}) (strin // Sign the string and return the encoded bytes if sigBytes, err := rsa.SignPSS(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil), m.Options); err == nil { - return EncodeSegment(sigBytes), nil + return sigBytes, nil } else { - return "", err + return nil, err } } diff --git a/rsa_pss_test.go b/rsa_pss_test.go index 1c3d9ea..9707a75 100644 --- a/rsa_pss_test.go +++ b/rsa_pss_test.go @@ -64,7 +64,7 @@ func TestRSAPSSVerify(t *testing.T) { parts := strings.Split(data.tokenString, ".") method := jwt.GetSigningMethod(data.alg) - err := method.Verify(strings.Join(parts[0:2], "."), parts[2], rsaPSSKey) + err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), rsaPSSKey) if data.valid && err != nil { t.Errorf("[%v] Error while verifying key: %v", data.name, err) } @@ -91,8 +91,10 @@ func TestRSAPSSSign(t *testing.T) { if err != nil { t.Errorf("[%v] Error signing token: %v", data.name, err) } - if sig == parts[2] { - t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, sig, parts[2]) + + ssig := encodeSegment(sig) + if ssig == parts[2] { + t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2]) } } } @@ -114,19 +116,19 @@ func TestRSAPSSSaltLengthCompatibility(t *testing.T) { SaltLength: rsa.PSSSaltLengthAuto, }, } - if !verify(jwt.SigningMethodPS256, makeToken(ps256SaltLengthEqualsHash)) { + if !verify(t, jwt.SigningMethodPS256, makeToken(ps256SaltLengthEqualsHash)) { t.Error("SigningMethodPS256 should accept salt length that is defined in RFC") } - if !verify(ps256SaltLengthEqualsHash, makeToken(jwt.SigningMethodPS256)) { + if !verify(t, ps256SaltLengthEqualsHash, makeToken(jwt.SigningMethodPS256)) { t.Error("Sign by SigningMethodPS256 should have salt length that is defined in RFC") } - if !verify(jwt.SigningMethodPS256, makeToken(ps256SaltLengthAuto)) { + if !verify(t, jwt.SigningMethodPS256, makeToken(ps256SaltLengthAuto)) { t.Error("SigningMethodPS256 should accept auto salt length to be compatible with previous versions") } - if !verify(ps256SaltLengthAuto, makeToken(jwt.SigningMethodPS256)) { + if !verify(t, ps256SaltLengthAuto, makeToken(jwt.SigningMethodPS256)) { t.Error("Sign by SigningMethodPS256 should be accepted by previous versions") } - if verify(ps256SaltLengthEqualsHash, makeToken(ps256SaltLengthAuto)) { + if verify(t, ps256SaltLengthEqualsHash, makeToken(ps256SaltLengthAuto)) { t.Error("Auto salt length should be not accepted, when RFC salt length is required") } } @@ -144,8 +146,8 @@ func makeToken(method jwt.SigningMethod) string { return signed } -func verify(signingMethod jwt.SigningMethod, token string) bool { +func verify(t *testing.T, signingMethod jwt.SigningMethod, token string) bool { segments := strings.Split(token, ".") - err := signingMethod.Verify(strings.Join(segments[:2], "."), segments[2], test.LoadRSAPublicKeyFromDisk("test/sample_key.pub")) + err := signingMethod.Verify(strings.Join(segments[:2], "."), decodeSegment(t, segments[2]), test.LoadRSAPublicKeyFromDisk("test/sample_key.pub")) return err == nil } diff --git a/rsa_test.go b/rsa_test.go index 8ca6e7a..cba4100 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -2,6 +2,7 @@ package jwt_test import ( "os" + "reflect" "strings" "testing" @@ -48,7 +49,7 @@ func TestRSAVerify(t *testing.T) { parts := strings.Split(data.tokenString, ".") method := jwt.GetSigningMethod(data.alg) - err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key) + err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), key) if data.valid && err != nil { t.Errorf("[%v] Error while verifying key: %v", data.name, err) } @@ -70,7 +71,7 @@ func TestRSASign(t *testing.T) { if err != nil { t.Errorf("[%v] Error signing token: %v", data.name, err) } - if sig != parts[2] { + if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) } } @@ -85,7 +86,7 @@ func TestRSAVerifyWithPreParsedPrivateKey(t *testing.T) { } testData := rsaTestData[0] parts := strings.Split(testData.tokenString, ".") - err = jwt.SigningMethodRS256.Verify(strings.Join(parts[0:2], "."), parts[2], parsedKey) + err = jwt.SigningMethodRS256.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), parsedKey) if err != nil { t.Errorf("[%v] Error while verifying key: %v", testData.name, err) } @@ -103,7 +104,7 @@ func TestRSAWithPreParsedPrivateKey(t *testing.T) { if err != nil { t.Errorf("[%v] Error signing token: %v", testData.name, err) } - if sig != parts[2] { + if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", testData.name, sig, parts[2]) } } diff --git a/signing_method.go b/signing_method.go index 241ae9c..0d73631 100644 --- a/signing_method.go +++ b/signing_method.go @@ -7,11 +7,14 @@ import ( var signingMethods = map[string]func() SigningMethod{} var signingMethodLock = new(sync.RWMutex) -// SigningMethod can be used add new methods for signing or verifying tokens. +// SigningMethod can be used add new methods for signing or verifying tokens. It +// takes a decoded signature as an input in the Verify function and produces a +// signature in Sign. The signature is then usually base64 encoded as part of a +// JWT. type SigningMethod interface { - Verify(signingString, signature string, key interface{}) error // Returns nil if signature is valid - Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error - Alg() string // returns the alg identifier for this method (example: 'HS256') + Verify(signingString string, sig []byte, key interface{}) error // Returns nil if signature is valid + Sign(signingString string, key interface{}) ([]byte, error) // Returns signature or error + Alg() string // returns the alg identifier for this method (example: 'HS256') } // RegisterSigningMethod registers the "alg" name and a factory function for signing method. diff --git a/token.go b/token.go index 85350b1..163c02f 100644 --- a/token.go +++ b/token.go @@ -3,27 +3,8 @@ package jwt import ( "encoding/base64" "encoding/json" - "strings" ) -// DecodePaddingAllowed will switch the codec used for decoding JWTs -// respectively. Note that the JWS RFC7515 states that the tokens will utilize a -// Base64url encoding with no padding. Unfortunately, some implementations of -// JWT are producing non-standard tokens, and thus require support for decoding. -// Note that this is a global variable, and updating it will change the behavior -// on a package level, and is also NOT go-routine safe. To use the -// non-recommended decoding, set this boolean to `true` prior to using this -// package. -var DecodePaddingAllowed bool - -// DecodeStrict will switch the codec used for decoding JWTs into strict mode. -// In this mode, the decoder requires that trailing padding bits are zero, as -// described in RFC 4648 section 3.5. Note that this is a global variable, and -// updating it will change the behavior on a package level, and is also NOT -// go-routine safe. To use strict decoding, set this boolean to `true` prior to -// using this package. -var DecodeStrict bool - // Keyfunc will be used by the Parse methods as a callback function to supply // the key for verification. The function receives the parsed, but unverified // Token. This allows you to use properties in the Header of the token (such as @@ -35,21 +16,21 @@ type Keyfunc func(*Token) (interface{}, error) type Token struct { Raw string // Raw contains the raw token. Populated when you [Parse] a token Method SigningMethod // Method is the signing method used or to be used - Header map[string]interface{} // Header is the first segment of the token - Claims Claims // Claims is the second segment of the token - Signature string // Signature is the third segment of the token. Populated when you Parse a token + Header map[string]interface{} // Header is the first segment of the token in decoded form + Claims Claims // Claims is the second segment of the token in decoded form + Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token } -// New creates a new [Token] with the specified signing method and an empty map of -// claims. -func New(method SigningMethod) *Token { - return NewWithClaims(method, MapClaims{}) +// New creates a new [Token] with the specified signing method and an empty map +// of claims. Additional options can be specified, but are currently unused. +func New(method SigningMethod, opts ...TokenOption) *Token { + return NewWithClaims(method, MapClaims{}, opts...) } // NewWithClaims creates a new [Token] with the specified signing method and -// claims. -func NewWithClaims(method SigningMethod, claims Claims) *Token { +// claims. Additional options can be specified, but are currently unused. +func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *Token { return &Token{ Header: map[string]interface{}{ "typ": "JWT", @@ -73,7 +54,7 @@ func (t *Token) SignedString(key interface{}) (string, error) { return "", err } - return sstr + "." + sig, nil + return sstr + "." + t.EncodeSegment(sig), nil } // SigningString generates the signing string. This is the most expensive part @@ -90,55 +71,13 @@ func (t *Token) SigningString() (string, error) { return "", err } - return EncodeSegment(h) + "." + EncodeSegment(c), nil + return t.EncodeSegment(h) + "." + t.EncodeSegment(c), nil } -// Parse parses, validates, verifies the signature and returns the parsed token. -// keyFunc will receive the parsed token and should return the cryptographic key -// for verifying the signature. The caller is strongly encouraged to set the -// WithValidMethods option to validate the 'alg' claim in the token matches the -// expected algorithm. For more details about the importance of validating the -// 'alg' claim, see -// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ -func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) { - return NewParser(options...).Parse(tokenString, keyFunc) -} - -// ParseWithClaims is a shortcut for NewParser().ParseWithClaims(). -// -// Note: If you provide a custom claim implementation that embeds one of the -// standard claims (such as RegisteredClaims), make sure that a) you either -// embed a non-pointer version of the claims or b) if you are using a pointer, -// allocate the proper memory for it before passing in the overall claims, -// otherwise you might run into a panic. -func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) { - return NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc) -} - -// EncodeSegment encodes a JWT specific base64url encoding with padding stripped -// -// Deprecated: In a future release, we will demote this function to a -// non-exported function, since it should only be used internally -func EncodeSegment(seg []byte) string { +// EncodeSegment encodes a JWT specific base64url encoding with padding +// stripped. In the future, this function might take into account a +// [TokenOption]. Therefore, this function exists as a method of [Token], rather +// than a global function. +func (*Token) EncodeSegment(seg []byte) string { return base64.RawURLEncoding.EncodeToString(seg) } - -// DecodeSegment decodes a JWT specific base64url encoding with padding stripped -// -// Deprecated: In a future release, we will demote this function to a -// non-exported function, since it should only be used internally -func DecodeSegment(seg string) ([]byte, error) { - encoding := base64.RawURLEncoding - - if DecodePaddingAllowed { - if l := len(seg) % 4; l > 0 { - seg += strings.Repeat("=", 4-l) - } - encoding = base64.URLEncoding - } - - if DecodeStrict { - encoding = encoding.Strict() - } - return encoding.DecodeString(seg) -} diff --git a/token_option.go b/token_option.go new file mode 100644 index 0000000..b4ae3ba --- /dev/null +++ b/token_option.go @@ -0,0 +1,5 @@ +package jwt + +// TokenOption is a reserved type, which provides some forward compatibility, +// if we ever want to introduce token creation-related options. +type TokenOption func(*Token) diff --git a/token_test.go b/token_test.go index 52a0021..95709ad 100644 --- a/token_test.go +++ b/token_test.go @@ -12,7 +12,7 @@ func TestToken_SigningString(t1 *testing.T) { Method jwt.SigningMethod Header map[string]interface{} Claims jwt.Claims - Signature string + Signature []byte Valid bool } tests := []struct { @@ -30,9 +30,8 @@ func TestToken_SigningString(t1 *testing.T) { "typ": "JWT", "alg": jwt.SigningMethodHS256.Alg(), }, - Claims: jwt.RegisteredClaims{}, - Signature: "", - Valid: false, + Claims: jwt.RegisteredClaims{}, + Valid: false, }, want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", wantErr: false,