From d736b8f86034ce7cf0b603d95151ed8d6ce7d610 Mon Sep 17 00:00:00 2001 From: Dave Grijalva Date: Fri, 6 Jul 2012 15:43:17 -0700 Subject: [PATCH] added signing support to RS256 and HS256 --- jwt.go | 26 ++++++++++++++++++++++++++ rs256.go | 37 +++++++++++++++++++++++++++++++++++-- rs256_test.go | 23 +++++++++++++++++++++++ sha256.go | 11 +++++++++-- sha256_test.go | 16 ++++++++++++++++ signing_method.go | 3 ++- 6 files changed, 111 insertions(+), 5 deletions(-) diff --git a/jwt.go b/jwt.go index f4147d4..8cf55c1 100644 --- a/jwt.go +++ b/jwt.go @@ -18,6 +18,28 @@ type Token struct { Valid bool } +func New(method SigningMethod)*Token { + return &Token{ + Header: map[string]interface{}{ + "typ": "JWT", + "alg": method.Alg(), + }, + Claims: make(map[string]interface{}), + } +} + +func Sign(key []byte) error { + return nil +} + +func SigningString()string { + return "" +} + +func String()string { + return "" +} + // Parse, validate, and return a token. // keyFunc will receive the parsed token and should return the key for validating. // If everything is kosher, err will be nil @@ -91,6 +113,10 @@ func ParseFromRequest(req *http.Request, keyFunc func(*Token) ([]byte, error)) ( } +func EncodeSegment(seg []byte)string { + return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=") +} + func DecodeSegment(seg string) ([]byte, error) { // len % 4 switch len(seg) % 4 { diff --git a/rs256.go b/rs256.go index 530d37d..5e364bf 100644 --- a/rs256.go +++ b/rs256.go @@ -5,6 +5,7 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" + "crypto/rand" "encoding/pem" "errors" ) @@ -17,6 +18,10 @@ func init() { }) } +func (m *SigningMethodRS256) Alg()string { + return "RS256" +} + func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) (err error) { // Key var sig []byte @@ -41,6 +46,34 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) return } -func (m *SigningMethodRS256) Sign(token *Token, key []byte) error { - return nil +func (m *SigningMethodRS256) Sign(signingString string, key []byte)(sig string, err error) { + // Key + var rsaKey *rsa.PrivateKey + if rsaKey, err = m.parsePrivateKey(key); err == nil { + hasher := sha256.New() + hasher.Write([]byte(signingString)) + + var sigBytes []byte + if sigBytes, err = rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, hasher.Sum(nil)); err == nil { + sig = EncodeSegment(sigBytes) + } + } + return } + +func (m *SigningMethodRS256) parsePrivateKey(key []byte)(pkey *rsa.PrivateKey, err error) { + var block *pem.Block + if block, _ = pem.Decode(key); block != nil { + var parsedKey interface{} + if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { + if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { + return nil, err + } + } + var ok bool + if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok { + err = errors.New("Key is not a valid RSA private key") + } + } + return +} \ No newline at end of file diff --git a/rs256_test.go b/rs256_test.go index 2c71b4f..81cae60 100644 --- a/rs256_test.go +++ b/rs256_test.go @@ -48,3 +48,26 @@ func TestRS256Verify(t *testing.T) { } } } + + +func TestRS256Sign(t *testing.T) { + file, _ := os.Open("test/sample_key") + buf := new(bytes.Buffer) + io.Copy(buf, file) + key := buf.Bytes() + file.Close() + + for _, data := range rsaTestData { + if data.valid { + parts := strings.Split(data.tokenString, ".") + method, _ := GetSigningMethod("RS256") + sig, err := method.Sign(strings.Join(parts[0:2], "."), key) + if err != nil { + t.Errorf("[%v] Error signing token: %v", data.name, err) + } + if sig != parts[2] { + t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) + } + } + } +} \ No newline at end of file diff --git a/sha256.go b/sha256.go index 1ce1e7e..67192a6 100644 --- a/sha256.go +++ b/sha256.go @@ -15,6 +15,10 @@ func init() { }) } +func (m *SigningMethodHS256) Alg()string { + return "HS256" +} + func (m *SigningMethodHS256) Verify(signingString, signature string, key []byte) (err error) { // Key var sig []byte @@ -29,6 +33,9 @@ func (m *SigningMethodHS256) Verify(signingString, signature string, key []byte) return } -func (m *SigningMethodHS256) Sign(token *Token, key []byte) error { - return nil +func (m *SigningMethodHS256) Sign(signingString string, key []byte)(string, error) { + hasher := hmac.New(sha256.New, key) + hasher.Write([]byte(signingString)) + + return EncodeSegment(hasher.Sum(nil)), nil } diff --git a/sha256_test.go b/sha256_test.go index 530ab91..27a7d06 100644 --- a/sha256_test.go +++ b/sha256_test.go @@ -46,3 +46,19 @@ func TestHS256Verify(t *testing.T) { } } } + +func TestHS256Sign(t *testing.T) { + for _, data := range sha256TestData { + if data.valid { + parts := strings.Split(data.tokenString, ".") + method, _ := GetSigningMethod("HS256") + sig, err := method.Sign(strings.Join(parts[0:2], "."), sha256TestKey) + if err != nil { + t.Errorf("[%v] Error signing token: %v", data.name, err) + } + if sig != parts[2] { + t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) + } + } + } +} \ No newline at end of file diff --git a/signing_method.go b/signing_method.go index 3f19a85..88e670f 100644 --- a/signing_method.go +++ b/signing_method.go @@ -10,7 +10,8 @@ var signingMethods = map[string]func() SigningMethod{} // Signing method type SigningMethod interface { Verify(signingString, signature string, key []byte) error - Sign(token *Token, key []byte) error + Sign(signingString string, key []byte)(string, error) + Alg() string } func RegisterSigningMethod(alg string, f func() SigningMethod) {