diff --git a/rsa_pss.go b/rsa_pss.go index 10ee9db..c014708 100644 --- a/rsa_pss.go +++ b/rsa_pss.go @@ -12,9 +12,14 @@ import ( type SigningMethodRSAPSS struct { *SigningMethodRSA Options *rsa.PSSOptions + // VerifyOptions is optional. If set overrides Options for rsa.VerifyPPS. + // Used to accept tokens signed with rsa.PSSSaltLengthAuto, what doesn't follow + // https://tools.ietf.org/html/rfc7518#section-3.5 but was used previously. + // See https://github.com/dgrijalva/jwt-go/issues/285#issuecomment-437451244 for details. + VerifyOptions *rsa.PSSOptions } -// Specific instances for RS/PS and company +// Specific instances for RS/PS and company. var ( SigningMethodPS256 *SigningMethodRSAPSS SigningMethodPS384 *SigningMethodRSAPSS @@ -24,13 +29,15 @@ var ( func init() { // PS256 SigningMethodPS256 = &SigningMethodRSAPSS{ - &SigningMethodRSA{ + SigningMethodRSA: &SigningMethodRSA{ Name: "PS256", Hash: crypto.SHA256, }, - &rsa.PSSOptions{ + Options: &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }, + VerifyOptions: &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, - Hash: crypto.SHA256, }, } RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod { @@ -39,13 +46,15 @@ func init() { // PS384 SigningMethodPS384 = &SigningMethodRSAPSS{ - &SigningMethodRSA{ + SigningMethodRSA: &SigningMethodRSA{ Name: "PS384", Hash: crypto.SHA384, }, - &rsa.PSSOptions{ + Options: &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }, + VerifyOptions: &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, - Hash: crypto.SHA384, }, } RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod { @@ -54,13 +63,15 @@ func init() { // PS512 SigningMethodPS512 = &SigningMethodRSAPSS{ - &SigningMethodRSA{ + SigningMethodRSA: &SigningMethodRSA{ Name: "PS512", Hash: crypto.SHA512, }, - &rsa.PSSOptions{ + Options: &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }, + VerifyOptions: &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, - Hash: crypto.SHA512, }, } RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod { @@ -94,7 +105,12 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf hasher := m.Hash.New() hasher.Write([]byte(signingString)) - return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, m.Options) + opts := m.Options + if m.VerifyOptions != nil { + opts = m.VerifyOptions + } + + return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) } // Implements the Sign method from SigningMethod diff --git a/rsa_pss_test.go b/rsa_pss_test.go index 9045aaf..e0134d9 100644 --- a/rsa_pss_test.go +++ b/rsa_pss_test.go @@ -7,8 +7,10 @@ import ( "io/ioutil" "strings" "testing" + "time" "github.com/dgrijalva/jwt-go" + "github.com/dgrijalva/jwt-go/test" ) var rsaPSSTestData = []struct { @@ -94,3 +96,55 @@ func TestRSAPSSSign(t *testing.T) { } } } + +func TestRSAPSSSaltLengthCompatibility(t *testing.T) { + // Fails token verify, if salt length is auto. + ps256SaltLengthEqualsHash := &jwt.SigningMethodRSAPSS{ + SigningMethodRSA: jwt.SigningMethodPS256.SigningMethodRSA, + Options: &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }, + } + + // Behaves as before https://github.com/dgrijalva/jwt-go/issues/285 fix. + ps256SaltLengthAuto := &jwt.SigningMethodRSAPSS{ + SigningMethodRSA: jwt.SigningMethodPS256.SigningMethodRSA, + Options: &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthAuto, + }, + } + if !verify(jwt.SigningMethodPS256, makeToken(ps256SaltLengthEqualsHash)) { + t.Error("SigningMethodPS256 should accept salt length that is defined in RFC") + } + if !verify(ps256SaltLengthEqualsHash, makeToken(jwt.SigningMethodPS256)) { + t.Error("Sign by SigningMethodPS256 should have salt length that is defined in RFC") + } + if !verify(jwt.SigningMethodPS256, makeToken(ps256SaltLengthAuto)) { + t.Error("SigningMethodPS256 should accept auto salt length to be compatible with previous versions") + } + if !verify(ps256SaltLengthAuto, makeToken(jwt.SigningMethodPS256)) { + t.Error("Sign by SigningMethodPS256 should be accepted by previous versions") + } + if verify(ps256SaltLengthEqualsHash, makeToken(ps256SaltLengthAuto)) { + t.Error("Auto salt length should be not accepted, when RFC salt length is required") + } +} + +func makeToken(method jwt.SigningMethod) string { + token := jwt.NewWithClaims(method, jwt.StandardClaims{ + Issuer: "example", + IssuedAt: time.Now().Unix(), + }) + privateKey := test.LoadRSAPrivateKeyFromDisk("test/sample_key") + signed, err := token.SignedString(privateKey) + if err != nil { + panic(err) + } + return signed +} + +func verify(signingMethod jwt.SigningMethod, token string) bool { + segments := strings.Split(token, ".") + err := signingMethod.Verify(strings.Join(segments[:2], "."), segments[2], test.LoadRSAPublicKeyFromDisk("test/sample_key.pub")) + return err == nil +}