From 97430c0b8b45d93df00c07b6db2961dd1145a4a0 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Sat, 5 Jul 2014 15:25:29 -0700 Subject: [PATCH] cleaned up and flattened RS256 implementation --- rs256.go | 88 ------------------------ rsa.go | 125 +++++++++++++++++++++++++++++++++++ rs256_test.go => rsa_test.go | 13 ++-- 3 files changed, 133 insertions(+), 93 deletions(-) delete mode 100644 rs256.go create mode 100644 rsa.go rename rs256_test.go => rsa_test.go (91%) diff --git a/rs256.go b/rs256.go deleted file mode 100644 index 05f24c4..0000000 --- a/rs256.go +++ /dev/null @@ -1,88 +0,0 @@ -package jwt - -import ( - "crypto" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/x509" - "encoding/pem" - "errors" -) - -type SigningMethodRS256 struct{} - -func init() { - RegisterSigningMethod("RS256", func() SigningMethod { - return new(SigningMethodRS256) - }) -} - -func (m *SigningMethodRS256) Alg() string { - return "RS256" -} - -func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) (err error) { - // Key - var sig []byte - if sig, err = DecodeSegment(signature); err == nil { - var block *pem.Block - if block, _ = pem.Decode(key); block != nil { - var parsedKey interface{} - if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { - parsedKey, err = x509.ParseCertificate(block.Bytes) - } - if err == nil { - if rsaKey, ok := parsedKey.(*rsa.PublicKey); ok { - hasher := sha256.New() - hasher.Write([]byte(signingString)) - - err = rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, hasher.Sum(nil), sig) - } else if cert, ok := parsedKey.(*x509.Certificate); ok { - err = cert.CheckSignature(x509.SHA256WithRSA, []byte(signingString), sig) - } else { - err = errors.New("Key is not a valid RSA public key") - } - } - } else { - err = errors.New("Could not parse key data") - } - } - return -} - -// Implements the Sign method from SigningMethod -// For this signing method, must be PEM encoded PKCS1 or PKCS8 RSA private key -func (m *SigningMethodRS256) Sign(signingString string, key []byte) (sig string, err error) { - // Key - var rsaKey *rsa.PrivateKey - if rsaKey, err = m.parsePrivateKey(key); err == nil { - hasher := sha256.New() - hasher.Write([]byte(signingString)) - - var sigBytes []byte - if sigBytes, err = rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hasher.Sum(nil)); err == nil { - sig = EncodeSegment(sigBytes) - } - } - return -} - -func (m *SigningMethodRS256) parsePrivateKey(key []byte) (pkey *rsa.PrivateKey, err error) { - var block *pem.Block - if block, _ = pem.Decode(key); block != nil { - var parsedKey interface{} - if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { - if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { - return nil, err - } - } - var ok bool - if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok { - err = errors.New("Key is not a valid RSA private key") - } - } else { - err = errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key") - } - return -} diff --git a/rsa.go b/rsa.go new file mode 100644 index 0000000..33b26cc --- /dev/null +++ b/rsa.go @@ -0,0 +1,125 @@ +package jwt + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "errors" +) + +type SigningMethodRS256 struct{} + +func init() { + RegisterSigningMethod("RS256", func() SigningMethod { + return new(SigningMethodRS256) + }) +} + +func (m *SigningMethodRS256) Alg() string { + return "RS256" +} + +func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) error { + var err error + + // Decode the signature + var sig []byte + if sig, err = DecodeSegment(signature); err != nil { + return err + } + + // Parse public key + var rsaKey *rsa.PublicKey + if rsaKey, err = m.parsePublicKey(key); err != nil { + return err + } + + // Create hasher + hasher := sha256.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + return rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, hasher.Sum(nil), sig) +} + +// Implements the Sign method from SigningMethod +// For this signing method, must be PEM encoded PKCS1 or PKCS8 RSA private key +func (m *SigningMethodRS256) Sign(signingString string, key []byte) (string, error) { + var err error + + // Key + var rsaKey *rsa.PrivateKey + if rsaKey, err = m.parsePrivateKey(key); err != nil { + return "", err + } + + // Create the hasher + hasher := sha256.New() + hasher.Write([]byte(signingString)) + + // Sign the string and return the encoded bytes + if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hasher.Sum(nil)); err == nil { + return EncodeSegment(sigBytes), nil + } else { + return "", err + } + +} + +// Parse PEM encoded PKCS1 or PKCS8 public key +func (m *SigningMethodRS256) parsePublicKey(key []byte) (*rsa.PublicKey, error) { + var err error + + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(key); block == nil { + return nil, errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key") + } + + // Parse the key + var parsedKey interface{} + if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { + if cert, err := x509.ParseCertificate(block.Bytes); err == nil { + return nil, err + } else { + parsedKey = cert.PublicKey + } + } + + var pkey *rsa.PublicKey + var ok bool + if pkey, ok = parsedKey.(*rsa.PublicKey); !ok { + return nil, errors.New("Key is not a valid RSA public key") + } + + return pkey, nil +} + +// Parse PEM encoded PKCS1 or PKCS8 private key +func (m *SigningMethodRS256) parsePrivateKey(key []byte) (*rsa.PrivateKey, error) { + var err error + + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(key); block == nil { + return nil, errors.New("Invalid Key: Key must be PEM encoded PKCS1 or PKCS8 private key") + } + + var parsedKey interface{} + if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { + if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { + return nil, err + } + } + + var pkey *rsa.PrivateKey + var ok bool + if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok { + return nil, errors.New("Key is not a valid RSA private key") + } + + return pkey, nil +} diff --git a/rs256_test.go b/rsa_test.go similarity index 91% rename from rs256_test.go rename to rsa_test.go index e0404e3..c71c98f 100644 --- a/rs256_test.go +++ b/rsa_test.go @@ -9,30 +9,33 @@ import ( var rsaTestData = []struct { name string tokenString string + alg string claims map[string]interface{} valid bool }{ { "basic: foo => bar", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + "RS256", map[string]interface{}{"foo": "bar"}, true, }, { "basic invalid: foo => bar", "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.EhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + "RS256", map[string]interface{}{"foo": "bar"}, false, }, } -func TestRS256Verify(t *testing.T) { +func TestRSAVerify(t *testing.T) { key, _ := ioutil.ReadFile("test/sample_key.pub") for _, data := range rsaTestData { parts := strings.Split(data.tokenString, ".") - method := GetSigningMethod("RS256") + method := GetSigningMethod(data.alg) err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key) if data.valid && err != nil { t.Errorf("[%v] Error while verifying key: %v", data.name, err) @@ -43,13 +46,13 @@ func TestRS256Verify(t *testing.T) { } } -func TestRS256Sign(t *testing.T) { +func TestRSASign(t *testing.T) { key, _ := ioutil.ReadFile("test/sample_key") for _, data := range rsaTestData { if data.valid { parts := strings.Split(data.tokenString, ".") - method := GetSigningMethod("RS256") + method := GetSigningMethod(data.alg) sig, err := method.Sign(strings.Join(parts[0:2], "."), key) if err != nil { t.Errorf("[%v] Error signing token: %v", data.name, err) @@ -61,7 +64,7 @@ func TestRS256Sign(t *testing.T) { } } -func TestKeyParsing(t *testing.T) { +func TestRSAKeyParsing(t *testing.T) { key, _ := ioutil.ReadFile("test/sample_key") pubKey, _ := ioutil.ReadFile("test/sample_key.pub") badKey := []byte("All your base are belong to key")