diff --git a/rsa.go b/rsa.go index 0a10988..0bfdda5 100644 --- a/rsa.go +++ b/rsa.go @@ -71,30 +71,32 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface } // Implements the Sign method from SigningMethod -// For this signing method, must be PEM encoded PKCS1 or PKCS8 RSA private key +// For this signing method, must be either a PEM encoded PKCS1 or PKCS8 RSA private key as +// []byte, or an rsa.PrivateKey structure. func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) { var err error + var rsaKey *rsa.PrivateKey - if keyBytes, ok := key.([]byte); ok { - // Key - var rsaKey *rsa.PrivateKey - if rsaKey, err = m.parsePrivateKey(keyBytes); err != nil { - return "", err - } - - // Create the hasher - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - - // Sign the string and return the encoded bytes - if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil { - return EncodeSegment(sigBytes), nil - } else { + switch k := key.(type) { + case []byte: + if rsaKey, err = m.parsePrivateKey(k); err != nil { return "", err } + case *rsa.PrivateKey: + rsaKey = k + default: + return "", ErrInvalidKey } + // Create the hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) - return "", ErrInvalidKey + // Sign the string and return the encoded bytes + if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil { + return EncodeSegment(sigBytes), nil + } else { + return "", err + } } // Parse PEM encoded PKCS1 or PKCS8 public key diff --git a/rsa_test.go b/rsa_test.go index 109c82d..b2f2653 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -78,6 +78,24 @@ func TestRSASign(t *testing.T) { } } +func TestRSAWithPreParsedPrivateKey(t *testing.T) { + key, _ := ioutil.ReadFile("test/sample_key") + method := GetSigningMethod("RS256").(*SigningMethodRSA) + parsedKey, err := method.parsePrivateKey(key) + if err != nil { + t.Fatal(err) + } + testData := rsaTestData[0] + parts := strings.Split(testData.tokenString, ".") + sig, err := method.Sign(strings.Join(parts[0:2], "."), parsedKey) + if err != nil { + t.Errorf("[%v] Error signing token: %v", testData.name, err) + } + if sig != parts[2] { + t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", testData.name, sig, parts[2]) + } +} + func TestRSAKeyParsing(t *testing.T) { key, _ := ioutil.ReadFile("test/sample_key") pubKey, _ := ioutil.ReadFile("test/sample_key.pub")