can now pass a PrivateKey to SigningMethodRSA.Sign

Conflicts:
	rsa.go
This commit is contained in:
Simon Jefford 2014-08-06 11:13:23 +01:00 committed by Dave Grijalva
parent c9b532b51b
commit dc2f34cdb1
2 changed files with 37 additions and 17 deletions

20
rsa.go
View File

@ -71,17 +71,22 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface
} }
// Implements the Sign method from SigningMethod // Implements the Sign method from SigningMethod
// For this signing method, must be PEM encoded PKCS1 or PKCS8 RSA private key // For this signing method, must be either a PEM encoded PKCS1 or PKCS8 RSA private key as
// []byte, or an rsa.PrivateKey structure.
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) { func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
var err error var err error
if keyBytes, ok := key.([]byte); ok {
// Key
var rsaKey *rsa.PrivateKey var rsaKey *rsa.PrivateKey
if rsaKey, err = m.parsePrivateKey(keyBytes); err != nil {
switch k := key.(type) {
case []byte:
if rsaKey, err = m.parsePrivateKey(k); err != nil {
return "", err return "", err
} }
case *rsa.PrivateKey:
rsaKey = k
default:
return "", ErrInvalidKey
}
// Create the hasher // Create the hasher
hasher := m.Hash.New() hasher := m.Hash.New()
hasher.Write([]byte(signingString)) hasher.Write([]byte(signingString))
@ -92,9 +97,6 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string,
} else { } else {
return "", err return "", err
} }
}
return "", ErrInvalidKey
} }
// Parse PEM encoded PKCS1 or PKCS8 public key // Parse PEM encoded PKCS1 or PKCS8 public key

View File

@ -78,6 +78,24 @@ func TestRSASign(t *testing.T) {
} }
} }
func TestRSAWithPreParsedPrivateKey(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key")
method := GetSigningMethod("RS256").(*SigningMethodRSA)
parsedKey, err := method.parsePrivateKey(key)
if err != nil {
t.Fatal(err)
}
testData := rsaTestData[0]
parts := strings.Split(testData.tokenString, ".")
sig, err := method.Sign(strings.Join(parts[0:2], "."), parsedKey)
if err != nil {
t.Errorf("[%v] Error signing token: %v", testData.name, err)
}
if sig != parts[2] {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", testData.name, sig, parts[2])
}
}
func TestRSAKeyParsing(t *testing.T) { func TestRSAKeyParsing(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key") key, _ := ioutil.ReadFile("test/sample_key")
pubKey, _ := ioutil.ReadFile("test/sample_key.pub") pubKey, _ := ioutil.ReadFile("test/sample_key.pub")