diff --git a/jwt.go b/jwt.go index 4e3f393..14cdd5b 100644 --- a/jwt.go +++ b/jwt.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/json" "time" + "net/http" ) // A JWT Token @@ -26,7 +27,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke token = new(Token) // parse Header var headerBytes []byte - if headerBytes, err = base64.URLEncoding.DecodeString(parts[0]); err != nil { + if headerBytes, err = DecodeSegment(parts[0]); err != nil { return } if err = json.Unmarshal(headerBytes, &token.Header); err != nil { @@ -35,7 +36,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke // parse Claims var claimBytes []byte - if claimBytes, err = base64.URLEncoding.DecodeString(parts[1]); err != nil { + if claimBytes, err = DecodeSegment(parts[1]); err != nil { return } if err = json.Unmarshal(claimBytes, &token.Claims); err != nil { @@ -75,3 +76,22 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke } return } + + +func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(token *Token, err error) { + + return nil, nil + +} + +func DecodeSegment(seg string)([]byte, error) { + // len % 4 + switch len(seg) % 4 { + case 2: + seg = seg + "==" + case 3: + seg = seg + "===" + } + + return base64.URLEncoding.DecodeString(seg) +} \ No newline at end of file diff --git a/jwt_test.go b/jwt_test.go index 16e05ae..eaf45b2 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -1,9 +1,51 @@ package jwt import ( + "os" + "io" + "bytes" "testing" + "reflect" ) -func TestJWT(t *testing.T) { - +var jwtTestData = []struct{ + name string + tokenString string + claims map[string]interface{} + valid bool +}{ + { + "basic: foo => bar", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + map[string]interface{}{"foo": "bar"}, + true, + }, + { + "basic invalid: foo => bar", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.EhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + map[string]interface{}{"foo": "bar"}, + false, + }, +} + +func TestJWT(t *testing.T) { + file, _ := os.Open("test/sample_key.pub") + buf := new(bytes.Buffer) + io.Copy(buf, file) + key := buf.Bytes() + file.Close() + + for _, data := range jwtTestData { + token, err := Parse(data.tokenString, func(t *Token)([]byte, error){ return key, nil }) + + if !reflect.DeepEqual(data.claims, token.Claims) { + t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) + } + if data.valid && err != nil { + t.Errorf("[%v] Error while verifying token: %v", data.name, err) + } + if !data.valid && err == nil { + t.Errorf("[%v] Invalid token passed validation", data.name) + } + } } diff --git a/rs256.go b/rs256.go index 11340bf..2fc2fa4 100644 --- a/rs256.go +++ b/rs256.go @@ -2,7 +2,6 @@ package jwt import ( "errors" - "encoding/base64" "encoding/pem" "crypto" "crypto/x509" @@ -19,17 +18,9 @@ func init() { } func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)(err error) { - // len % 4 - switch len(signature) % 4 { - case 2: - signature = signature + "==" - case 3: - signature = signature + "===" - } - // Key var sig []byte - if sig, err = base64.URLEncoding.DecodeString(signature); err == nil { + if sig, err = DecodeSegment(signature); err == nil { var block *pem.Block if block, _ = pem.Decode(key); block != nil { var parsedKey interface{} @@ -50,6 +41,6 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) return } -func (m *SigningMethodRS256) Sign(token, key []byte)error { +func (m *SigningMethodRS256) Sign(token *Token, key []byte)error { return nil } \ No newline at end of file diff --git a/rs256_test.go b/rs256_test.go index c829bfb..cf0bbe1 100644 --- a/rs256_test.go +++ b/rs256_test.go @@ -8,7 +8,7 @@ import ( "strings" ) -var testData = []struct{ +var rsaTestData = []struct{ name string tokenString string claims map[string]interface{} @@ -35,7 +35,7 @@ func TestRS256Verify(t *testing.T) { key := buf.Bytes() file.Close() - for _, data := range testData { + for _, data := range rsaTestData { parts := strings.Split(data.tokenString, ".") method, _ := GetSigningMethod("RS256") diff --git a/signing_method.go b/signing_method.go index ab0d032..90516aa 100644 --- a/signing_method.go +++ b/signing_method.go @@ -10,7 +10,7 @@ var signingMethods = map[string]func() SigningMethod{} // Signing method type SigningMethod interface { Verify(signingString, signature string, key []byte)error - Sign(token, key []byte)error + Sign(token *Token, key []byte)error } func RegisterSigningMethod(alg string, f func() SigningMethod) {