forked from mirror/jwt
gofmt
This commit is contained in:
parent
c416da491f
commit
0008f0b730
33
jwt.go
33
jwt.go
|
@ -1,12 +1,12 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"errors"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"time"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A JWT Token
|
||||
|
@ -21,7 +21,7 @@ type Token struct {
|
|||
// Parse, validate, and return a token.
|
||||
// keyFunc will receive the parsed token and should return the key for validating.
|
||||
// If everything is kosher, err will be nil
|
||||
func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Token, err error) {
|
||||
func Parse(tokenString string, keyFunc func(*Token) ([]byte, error)) (token *Token, err error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) == 3 {
|
||||
token = new(Token)
|
||||
|
@ -33,7 +33,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
|
|||
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// parse Claims
|
||||
var claimBytes []byte
|
||||
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
|
||||
|
@ -42,7 +42,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
|
|||
if err = json.Unmarshal(claimBytes, &token.Claims); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Lookup signature method
|
||||
if method, ok := token.Header["alg"].(string); ok {
|
||||
if token.Method, err = GetSigningMethod(method); err != nil {
|
||||
|
@ -65,21 +65,20 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
|
|||
if key, err = keyFunc(token); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Perform validation
|
||||
if err = token.Method.Verify(strings.Join(parts[0:2], "."), parts[2], key); err == nil {
|
||||
token.Valid = true
|
||||
}
|
||||
|
||||
|
||||
} else {
|
||||
err = errors.New("Token contains an invalid number of segments")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ParseFromRequest(req *http.Request, keyFunc func(*Token) ([]byte, error)) (token *Token, err error) {
|
||||
|
||||
func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(token *Token, err error) {
|
||||
|
||||
// Look for an Authorization header
|
||||
if ah := req.Header.Get("Authorization"); ah != "" {
|
||||
// Should be a bearer token
|
||||
|
@ -87,19 +86,19 @@ func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(to
|
|||
return Parse(ah[7:], keyFunc)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return nil, errors.New("No token present in request.")
|
||||
|
||||
|
||||
}
|
||||
|
||||
func DecodeSegment(seg string)([]byte, error) {
|
||||
func DecodeSegment(seg string) ([]byte, error) {
|
||||
// len % 4
|
||||
switch len(seg) % 4 {
|
||||
case 2:
|
||||
case 2:
|
||||
seg = seg + "=="
|
||||
case 3:
|
||||
case 3:
|
||||
seg = seg + "==="
|
||||
}
|
||||
|
||||
|
||||
return base64.URLEncoding.DecodeString(seg)
|
||||
}
|
||||
}
|
||||
|
|
30
jwt_test.go
30
jwt_test.go
|
@ -1,20 +1,20 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"os"
|
||||
"io"
|
||||
"bytes"
|
||||
"testing"
|
||||
"reflect"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var jwtTestData = []struct{
|
||||
name string
|
||||
var jwtTestData = []struct {
|
||||
name string
|
||||
tokenString string
|
||||
claims map[string]interface{}
|
||||
valid bool
|
||||
claims map[string]interface{}
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
"basic",
|
||||
|
@ -36,10 +36,10 @@ func TestJWT(t *testing.T) {
|
|||
io.Copy(buf, file)
|
||||
key := buf.Bytes()
|
||||
file.Close()
|
||||
|
||||
|
||||
for _, data := range jwtTestData {
|
||||
token, err := Parse(data.tokenString, func(t *Token)([]byte, error){ return key, nil })
|
||||
|
||||
token, err := Parse(data.tokenString, func(t *Token) ([]byte, error) { return key, nil })
|
||||
|
||||
if !reflect.DeepEqual(data.claims, token.Claims) {
|
||||
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
|
||||
}
|
||||
|
@ -58,13 +58,13 @@ func TestParseRequest(t *testing.T) {
|
|||
io.Copy(buf, file)
|
||||
key := buf.Bytes()
|
||||
file.Close()
|
||||
|
||||
|
||||
// Bearer token request
|
||||
for _, data := range jwtTestData {
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString))
|
||||
token, err := ParseFromRequest(r, func(t *Token)([]byte, error){ return key, nil })
|
||||
|
||||
token, err := ParseFromRequest(r, func(t *Token) ([]byte, error) { return key, nil })
|
||||
|
||||
if !reflect.DeepEqual(data.claims, token.Claims) {
|
||||
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
|
||||
}
|
||||
|
@ -75,4 +75,4 @@ func TestParseRequest(t *testing.T) {
|
|||
t.Errorf("[%v] Invalid token passed validation", data.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
16
rs256.go
16
rs256.go
|
@ -1,15 +1,15 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"encoding/pem"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type SigningMethodRS256 struct {}
|
||||
type SigningMethodRS256 struct{}
|
||||
|
||||
func init() {
|
||||
RegisterSigningMethod("RS256", func() SigningMethod {
|
||||
|
@ -17,7 +17,7 @@ func init() {
|
|||
})
|
||||
}
|
||||
|
||||
func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)(err error) {
|
||||
func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) (err error) {
|
||||
// Key
|
||||
var sig []byte
|
||||
if sig, err = DecodeSegment(signature); err == nil {
|
||||
|
@ -28,7 +28,7 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)
|
|||
if rsaKey, ok := parsedKey.(*rsa.PublicKey); ok {
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
|
||||
err = rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, hasher.Sum(nil), sig)
|
||||
} else {
|
||||
err = errors.New("Key is not a valid RSA public key")
|
||||
|
@ -41,6 +41,6 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)
|
|||
return
|
||||
}
|
||||
|
||||
func (m *SigningMethodRS256) Sign(token *Token, key []byte)error {
|
||||
func (m *SigningMethodRS256) Sign(token *Token, key []byte) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"os"
|
||||
"io"
|
||||
"bytes"
|
||||
"testing"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var rsaTestData = []struct{
|
||||
name string
|
||||
var rsaTestData = []struct {
|
||||
name string
|
||||
tokenString string
|
||||
claims map[string]interface{}
|
||||
valid bool
|
||||
claims map[string]interface{}
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
"basic: foo => bar",
|
||||
|
@ -34,10 +34,10 @@ func TestRS256Verify(t *testing.T) {
|
|||
io.Copy(buf, file)
|
||||
key := buf.Bytes()
|
||||
file.Close()
|
||||
|
||||
|
||||
for _, data := range rsaTestData {
|
||||
parts := strings.Split(data.tokenString, ".")
|
||||
|
||||
|
||||
method, _ := GetSigningMethod("RS256")
|
||||
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key)
|
||||
if data.valid && err != nil {
|
||||
|
@ -47,4 +47,4 @@ func TestRS256Verify(t *testing.T) {
|
|||
t.Errorf("[%v] Invalid key passed validation", data.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,27 +1,27 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var signingMethods = map[string]func() SigningMethod{}
|
||||
|
||||
// Signing method
|
||||
type SigningMethod interface {
|
||||
Verify(signingString, signature string, key []byte)error
|
||||
Sign(token *Token, key []byte)error
|
||||
Verify(signingString, signature string, key []byte) error
|
||||
Sign(token *Token, key []byte) error
|
||||
}
|
||||
|
||||
func RegisterSigningMethod(alg string, f func() SigningMethod) {
|
||||
signingMethods[alg] = f
|
||||
}
|
||||
|
||||
func GetSigningMethod(alg string)(method SigningMethod, err error) {
|
||||
func GetSigningMethod(alg string) (method SigningMethod, err error) {
|
||||
if methodF, ok := signingMethods[alg]; ok {
|
||||
method = methodF()
|
||||
} else {
|
||||
err = errors.New(fmt.Sprintf("Invalid signing method (alg): %v", method))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue