diff --git a/jwt.go b/jwt.go index 0b7fc5d..f4147d4 100644 --- a/jwt.go +++ b/jwt.go @@ -1,12 +1,12 @@ package jwt import ( - "strings" - "errors" "encoding/base64" "encoding/json" - "time" + "errors" "net/http" + "strings" + "time" ) // A JWT Token @@ -21,7 +21,7 @@ type Token struct { // Parse, validate, and return a token. // keyFunc will receive the parsed token and should return the key for validating. // If everything is kosher, err will be nil -func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Token, err error) { +func Parse(tokenString string, keyFunc func(*Token) ([]byte, error)) (token *Token, err error) { parts := strings.Split(tokenString, ".") if len(parts) == 3 { token = new(Token) @@ -33,7 +33,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke if err = json.Unmarshal(headerBytes, &token.Header); err != nil { return } - + // parse Claims var claimBytes []byte if claimBytes, err = DecodeSegment(parts[1]); err != nil { @@ -42,7 +42,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke if err = json.Unmarshal(claimBytes, &token.Claims); err != nil { return } - + // Lookup signature method if method, ok := token.Header["alg"].(string); ok { if token.Method, err = GetSigningMethod(method); err != nil { @@ -65,21 +65,20 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke if key, err = keyFunc(token); err != nil { return } - + // Perform validation if err = token.Method.Verify(strings.Join(parts[0:2], "."), parts[2], key); err == nil { token.Valid = true } - + } else { err = errors.New("Token contains an invalid number of segments") } return } +func ParseFromRequest(req *http.Request, keyFunc func(*Token) ([]byte, error)) (token *Token, err error) { -func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(token *Token, err error) { - // Look for an Authorization header if ah := req.Header.Get("Authorization"); ah != "" { // Should be a bearer token @@ -87,19 +86,19 @@ func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(to return Parse(ah[7:], keyFunc) } } - + return nil, errors.New("No token present in request.") - + } -func DecodeSegment(seg string)([]byte, error) { +func DecodeSegment(seg string) ([]byte, error) { // len % 4 switch len(seg) % 4 { - case 2: + case 2: seg = seg + "==" - case 3: + case 3: seg = seg + "===" } - + return base64.URLEncoding.DecodeString(seg) -} \ No newline at end of file +} diff --git a/jwt_test.go b/jwt_test.go index b612a95..11b366e 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -1,20 +1,20 @@ package jwt import ( - "os" - "io" "bytes" - "testing" - "reflect" "fmt" + "io" "net/http" + "os" + "reflect" + "testing" ) -var jwtTestData = []struct{ - name string +var jwtTestData = []struct { + name string tokenString string - claims map[string]interface{} - valid bool + claims map[string]interface{} + valid bool }{ { "basic", @@ -36,10 +36,10 @@ func TestJWT(t *testing.T) { io.Copy(buf, file) key := buf.Bytes() file.Close() - + for _, data := range jwtTestData { - token, err := Parse(data.tokenString, func(t *Token)([]byte, error){ return key, nil }) - + token, err := Parse(data.tokenString, func(t *Token) ([]byte, error) { return key, nil }) + if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) } @@ -58,13 +58,13 @@ func TestParseRequest(t *testing.T) { io.Copy(buf, file) key := buf.Bytes() file.Close() - + // Bearer token request for _, data := range jwtTestData { r, _ := http.NewRequest("GET", "/", nil) r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString)) - token, err := ParseFromRequest(r, func(t *Token)([]byte, error){ return key, nil }) - + token, err := ParseFromRequest(r, func(t *Token) ([]byte, error) { return key, nil }) + if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) } @@ -75,4 +75,4 @@ func TestParseRequest(t *testing.T) { t.Errorf("[%v] Invalid token passed validation", data.name) } } -} \ No newline at end of file +} diff --git a/rs256.go b/rs256.go index 2fc2fa4..530d37d 100644 --- a/rs256.go +++ b/rs256.go @@ -1,15 +1,15 @@ package jwt import ( - "errors" - "encoding/pem" "crypto" - "crypto/x509" "crypto/rsa" "crypto/sha256" + "crypto/x509" + "encoding/pem" + "errors" ) -type SigningMethodRS256 struct {} +type SigningMethodRS256 struct{} func init() { RegisterSigningMethod("RS256", func() SigningMethod { @@ -17,7 +17,7 @@ func init() { }) } -func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)(err error) { +func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) (err error) { // Key var sig []byte if sig, err = DecodeSegment(signature); err == nil { @@ -28,7 +28,7 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) if rsaKey, ok := parsedKey.(*rsa.PublicKey); ok { hasher := sha256.New() hasher.Write([]byte(signingString)) - + err = rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, hasher.Sum(nil), sig) } else { err = errors.New("Key is not a valid RSA public key") @@ -41,6 +41,6 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) return } -func (m *SigningMethodRS256) Sign(token *Token, key []byte)error { +func (m *SigningMethodRS256) Sign(token *Token, key []byte) error { return nil -} \ No newline at end of file +} diff --git a/rs256_test.go b/rs256_test.go index cf0bbe1..2c71b4f 100644 --- a/rs256_test.go +++ b/rs256_test.go @@ -1,18 +1,18 @@ package jwt import ( - "os" - "io" "bytes" - "testing" + "io" + "os" "strings" + "testing" ) -var rsaTestData = []struct{ - name string +var rsaTestData = []struct { + name string tokenString string - claims map[string]interface{} - valid bool + claims map[string]interface{} + valid bool }{ { "basic: foo => bar", @@ -34,10 +34,10 @@ func TestRS256Verify(t *testing.T) { io.Copy(buf, file) key := buf.Bytes() file.Close() - + for _, data := range rsaTestData { parts := strings.Split(data.tokenString, ".") - + method, _ := GetSigningMethod("RS256") err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key) if data.valid && err != nil { @@ -47,4 +47,4 @@ func TestRS256Verify(t *testing.T) { t.Errorf("[%v] Invalid key passed validation", data.name) } } -} \ No newline at end of file +} diff --git a/signing_method.go b/signing_method.go index 90516aa..3f19a85 100644 --- a/signing_method.go +++ b/signing_method.go @@ -1,27 +1,27 @@ package jwt import ( - "fmt" "errors" + "fmt" ) var signingMethods = map[string]func() SigningMethod{} // Signing method type SigningMethod interface { - Verify(signingString, signature string, key []byte)error - Sign(token *Token, key []byte)error + Verify(signingString, signature string, key []byte) error + Sign(token *Token, key []byte) error } func RegisterSigningMethod(alg string, f func() SigningMethod) { signingMethods[alg] = f } -func GetSigningMethod(alg string)(method SigningMethod, err error) { +func GetSigningMethod(alg string) (method SigningMethod, err error) { if methodF, ok := signingMethods[alg]; ok { method = methodF() } else { err = errors.New(fmt.Sprintf("Invalid signing method (alg): %v", method)) } return -} \ No newline at end of file +}