diff --git a/rs256.go b/rs256.go index 05f24c4..33b26cc 100644 --- a/rs256.go +++ b/rs256.go @@ -22,67 +22,104 @@ func (m *SigningMethodRS256) Alg() string { return "RS256" } -func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) (err error) { - // Key - var sig []byte - if sig, err = DecodeSegment(signature); err == nil { - var block *pem.Block - if block, _ = pem.Decode(key); block != nil { - var parsedKey interface{} - if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { - parsedKey, err = x509.ParseCertificate(block.Bytes) - } - if err == nil { - if rsaKey, ok := parsedKey.(*rsa.PublicKey); ok { - hasher := sha256.New() - hasher.Write([]byte(signingString)) +func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) error { + var err error - err = rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, hasher.Sum(nil), sig) - } else if cert, ok := parsedKey.(*x509.Certificate); ok { - err = cert.CheckSignature(x509.SHA256WithRSA, []byte(signingString), sig) - } else { - err = errors.New("Key is not a valid RSA public key") - } - } - } else { - err = errors.New("Could not parse key data") - } + // Decode the signature + var sig []byte + if sig, err = DecodeSegment(signature); err != nil { + return err } - return + + // Parse public key + var rsaKey *rsa.PublicKey + if rsaKey, err = m.parsePublicKey(key); err != nil { + return err + } + + // Create hasher + hasher := sha256.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + return rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, hasher.Sum(nil), sig) } // Implements the Sign method from SigningMethod // For this signing method, must be PEM encoded PKCS1 or PKCS8 RSA private key -func (m *SigningMethodRS256) Sign(signingString string, key []byte) (sig string, err error) { +func (m *SigningMethodRS256) Sign(signingString string, key []byte) (string, error) { + var err error + // Key var rsaKey *rsa.PrivateKey - if rsaKey, err = m.parsePrivateKey(key); err == nil { - hasher := sha256.New() - hasher.Write([]byte(signingString)) - - var sigBytes []byte - if sigBytes, err = rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hasher.Sum(nil)); err == nil { - sig = EncodeSegment(sigBytes) - } + if rsaKey, err = m.parsePrivateKey(key); err != nil { + return "", err } - return -} -func (m *SigningMethodRS256) parsePrivateKey(key []byte) (pkey *rsa.PrivateKey, err error) { - var block *pem.Block - if block, _ = pem.Decode(key); block != nil { - var parsedKey interface{} - if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { - if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { - return nil, err - } - } - var ok bool - if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok { - err = errors.New("Key is not a valid RSA private key") - } + // Create the hasher + hasher := sha256.New() + hasher.Write([]byte(signingString)) + + // Sign the string and return the encoded bytes + if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hasher.Sum(nil)); err == nil { + return EncodeSegment(sigBytes), nil } else { - err = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key") + return "", err } - return + +} + +// Parse PEM encoded PKCS1 or PKCS8 public key +func (m *SigningMethodRS256) parsePublicKey(key []byte) (*rsa.PublicKey, error) { + var err error + + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(key); block == nil { + return nil, errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key") + } + + // Parse the key + var parsedKey interface{} + if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { + if cert, err := x509.ParseCertificate(block.Bytes); err == nil { + return nil, err + } else { + parsedKey = cert.PublicKey + } + } + + var pkey *rsa.PublicKey + var ok bool + if pkey, ok = parsedKey.(*rsa.PublicKey); !ok { + return nil, errors.New("Key is not a valid RSA public key") + } + + return pkey, nil +} + +// Parse PEM encoded PKCS1 or PKCS8 private key +func (m *SigningMethodRS256) parsePrivateKey(key []byte) (*rsa.PrivateKey, error) { + var err error + + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(key); block == nil { + return nil, errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key") + } + + var parsedKey interface{} + if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { + if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { + return nil, err + } + } + + var pkey *rsa.PrivateKey + var ok bool + if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok { + return nil, errors.New("Key is not a valid RSA private key") + } + + return pkey, nil } diff --git a/rs256_test.go b/rs256_test.go index e0404e3..a0fbd5a 100644 --- a/rs256_test.go +++ b/rs256_test.go @@ -61,7 +61,7 @@ func TestRS256Sign(t *testing.T) { } } -func TestKeyParsing(t *testing.T) { +func TestRSAKeyParsing(t *testing.T) { key, _ := ioutil.ReadFile("test/sample_key") pubKey, _ := ioutil.ReadFile("test/sample_key.pub") badKey := []byte("All your base are belong to key")