This commit is contained in:
Dave Grijalva 2012-04-18 12:59:37 -07:00
parent c416da491f
commit 0008f0b730
5 changed files with 54 additions and 55 deletions

33
jwt.go
View File

@ -1,12 +1,12 @@
package jwt
import (
"strings"
"errors"
"encoding/base64"
"encoding/json"
"time"
"errors"
"net/http"
"strings"
"time"
)
// A JWT Token
@ -21,7 +21,7 @@ type Token struct {
// Parse, validate, and return a token.
// keyFunc will receive the parsed token and should return the key for validating.
// If everything is kosher, err will be nil
func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Token, err error) {
func Parse(tokenString string, keyFunc func(*Token) ([]byte, error)) (token *Token, err error) {
parts := strings.Split(tokenString, ".")
if len(parts) == 3 {
token = new(Token)
@ -33,7 +33,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
return
}
// parse Claims
var claimBytes []byte
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
@ -42,7 +42,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
if err = json.Unmarshal(claimBytes, &token.Claims); err != nil {
return
}
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method, err = GetSigningMethod(method); err != nil {
@ -65,21 +65,20 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
if key, err = keyFunc(token); err != nil {
return
}
// Perform validation
if err = token.Method.Verify(strings.Join(parts[0:2], "."), parts[2], key); err == nil {
token.Valid = true
}
} else {
err = errors.New("Token contains an invalid number of segments")
}
return
}
func ParseFromRequest(req *http.Request, keyFunc func(*Token) ([]byte, error)) (token *Token, err error) {
func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(token *Token, err error) {
// Look for an Authorization header
if ah := req.Header.Get("Authorization"); ah != "" {
// Should be a bearer token
@ -87,19 +86,19 @@ func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(to
return Parse(ah[7:], keyFunc)
}
}
return nil, errors.New("No token present in request.")
}
func DecodeSegment(seg string)([]byte, error) {
func DecodeSegment(seg string) ([]byte, error) {
// len % 4
switch len(seg) % 4 {
case 2:
case 2:
seg = seg + "=="
case 3:
case 3:
seg = seg + "==="
}
return base64.URLEncoding.DecodeString(seg)
}
}

View File

@ -1,20 +1,20 @@
package jwt
import (
"os"
"io"
"bytes"
"testing"
"reflect"
"fmt"
"io"
"net/http"
"os"
"reflect"
"testing"
)
var jwtTestData = []struct{
name string
var jwtTestData = []struct {
name string
tokenString string
claims map[string]interface{}
valid bool
claims map[string]interface{}
valid bool
}{
{
"basic",
@ -36,10 +36,10 @@ func TestJWT(t *testing.T) {
io.Copy(buf, file)
key := buf.Bytes()
file.Close()
for _, data := range jwtTestData {
token, err := Parse(data.tokenString, func(t *Token)([]byte, error){ return key, nil })
token, err := Parse(data.tokenString, func(t *Token) ([]byte, error) { return key, nil })
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
}
@ -58,13 +58,13 @@ func TestParseRequest(t *testing.T) {
io.Copy(buf, file)
key := buf.Bytes()
file.Close()
// Bearer token request
for _, data := range jwtTestData {
r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString))
token, err := ParseFromRequest(r, func(t *Token)([]byte, error){ return key, nil })
token, err := ParseFromRequest(r, func(t *Token) ([]byte, error) { return key, nil })
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
}
@ -75,4 +75,4 @@ func TestParseRequest(t *testing.T) {
t.Errorf("[%v] Invalid token passed validation", data.name)
}
}
}
}

View File

@ -1,15 +1,15 @@
package jwt
import (
"errors"
"encoding/pem"
"crypto"
"crypto/x509"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"errors"
)
type SigningMethodRS256 struct {}
type SigningMethodRS256 struct{}
func init() {
RegisterSigningMethod("RS256", func() SigningMethod {
@ -17,7 +17,7 @@ func init() {
})
}
func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)(err error) {
func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) (err error) {
// Key
var sig []byte
if sig, err = DecodeSegment(signature); err == nil {
@ -28,7 +28,7 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)
if rsaKey, ok := parsedKey.(*rsa.PublicKey); ok {
hasher := sha256.New()
hasher.Write([]byte(signingString))
err = rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, hasher.Sum(nil), sig)
} else {
err = errors.New("Key is not a valid RSA public key")
@ -41,6 +41,6 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)
return
}
func (m *SigningMethodRS256) Sign(token *Token, key []byte)error {
func (m *SigningMethodRS256) Sign(token *Token, key []byte) error {
return nil
}
}

View File

@ -1,18 +1,18 @@
package jwt
import (
"os"
"io"
"bytes"
"testing"
"io"
"os"
"strings"
"testing"
)
var rsaTestData = []struct{
name string
var rsaTestData = []struct {
name string
tokenString string
claims map[string]interface{}
valid bool
claims map[string]interface{}
valid bool
}{
{
"basic: foo => bar",
@ -34,10 +34,10 @@ func TestRS256Verify(t *testing.T) {
io.Copy(buf, file)
key := buf.Bytes()
file.Close()
for _, data := range rsaTestData {
parts := strings.Split(data.tokenString, ".")
method, _ := GetSigningMethod("RS256")
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key)
if data.valid && err != nil {
@ -47,4 +47,4 @@ func TestRS256Verify(t *testing.T) {
t.Errorf("[%v] Invalid key passed validation", data.name)
}
}
}
}

View File

@ -1,27 +1,27 @@
package jwt
import (
"fmt"
"errors"
"fmt"
)
var signingMethods = map[string]func() SigningMethod{}
// Signing method
type SigningMethod interface {
Verify(signingString, signature string, key []byte)error
Sign(token *Token, key []byte)error
Verify(signingString, signature string, key []byte) error
Sign(token *Token, key []byte) error
}
func RegisterSigningMethod(alg string, f func() SigningMethod) {
signingMethods[alg] = f
}
func GetSigningMethod(alg string)(method SigningMethod, err error) {
func GetSigningMethod(alg string) (method SigningMethod, err error) {
if methodF, ok := signingMethods[alg]; ok {
method = methodF()
} else {
err = errors.New(fmt.Sprintf("Invalid signing method (alg): %v", method))
}
return
}
}