This commit is contained in:
Dave Grijalva 2012-04-18 12:59:37 -07:00
parent c416da491f
commit 0008f0b730
5 changed files with 54 additions and 55 deletions

17
jwt.go
View File

@ -1,12 +1,12 @@
package jwt package jwt
import ( import (
"strings"
"errors"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"time" "errors"
"net/http" "net/http"
"strings"
"time"
) )
// A JWT Token // A JWT Token
@ -21,7 +21,7 @@ type Token struct {
// Parse, validate, and return a token. // Parse, validate, and return a token.
// keyFunc will receive the parsed token and should return the key for validating. // keyFunc will receive the parsed token and should return the key for validating.
// If everything is kosher, err will be nil // If everything is kosher, err will be nil
func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Token, err error) { func Parse(tokenString string, keyFunc func(*Token) ([]byte, error)) (token *Token, err error) {
parts := strings.Split(tokenString, ".") parts := strings.Split(tokenString, ".")
if len(parts) == 3 { if len(parts) == 3 {
token = new(Token) token = new(Token)
@ -77,8 +77,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
return return
} }
func ParseFromRequest(req *http.Request, keyFunc func(*Token) ([]byte, error)) (token *Token, err error) {
func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(token *Token, err error) {
// Look for an Authorization header // Look for an Authorization header
if ah := req.Header.Get("Authorization"); ah != "" { if ah := req.Header.Get("Authorization"); ah != "" {
@ -92,12 +91,12 @@ func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(to
} }
func DecodeSegment(seg string)([]byte, error) { func DecodeSegment(seg string) ([]byte, error) {
// len % 4 // len % 4
switch len(seg) % 4 { switch len(seg) % 4 {
case 2: case 2:
seg = seg + "==" seg = seg + "=="
case 3: case 3:
seg = seg + "===" seg = seg + "==="
} }

View File

@ -1,20 +1,20 @@
package jwt package jwt
import ( import (
"os"
"io"
"bytes" "bytes"
"testing"
"reflect"
"fmt" "fmt"
"io"
"net/http" "net/http"
"os"
"reflect"
"testing"
) )
var jwtTestData = []struct{ var jwtTestData = []struct {
name string name string
tokenString string tokenString string
claims map[string]interface{} claims map[string]interface{}
valid bool valid bool
}{ }{
{ {
"basic", "basic",
@ -38,7 +38,7 @@ func TestJWT(t *testing.T) {
file.Close() file.Close()
for _, data := range jwtTestData { for _, data := range jwtTestData {
token, err := Parse(data.tokenString, func(t *Token)([]byte, error){ return key, nil }) token, err := Parse(data.tokenString, func(t *Token) ([]byte, error) { return key, nil })
if !reflect.DeepEqual(data.claims, token.Claims) { if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
@ -63,7 +63,7 @@ func TestParseRequest(t *testing.T) {
for _, data := range jwtTestData { for _, data := range jwtTestData {
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString))
token, err := ParseFromRequest(r, func(t *Token)([]byte, error){ return key, nil }) token, err := ParseFromRequest(r, func(t *Token) ([]byte, error) { return key, nil })
if !reflect.DeepEqual(data.claims, token.Claims) { if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)

View File

@ -1,15 +1,15 @@
package jwt package jwt
import ( import (
"errors"
"encoding/pem"
"crypto" "crypto"
"crypto/x509"
"crypto/rsa" "crypto/rsa"
"crypto/sha256" "crypto/sha256"
"crypto/x509"
"encoding/pem"
"errors"
) )
type SigningMethodRS256 struct {} type SigningMethodRS256 struct{}
func init() { func init() {
RegisterSigningMethod("RS256", func() SigningMethod { RegisterSigningMethod("RS256", func() SigningMethod {
@ -17,7 +17,7 @@ func init() {
}) })
} }
func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)(err error) { func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte) (err error) {
// Key // Key
var sig []byte var sig []byte
if sig, err = DecodeSegment(signature); err == nil { if sig, err = DecodeSegment(signature); err == nil {
@ -41,6 +41,6 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)
return return
} }
func (m *SigningMethodRS256) Sign(token *Token, key []byte)error { func (m *SigningMethodRS256) Sign(token *Token, key []byte) error {
return nil return nil
} }

View File

@ -1,18 +1,18 @@
package jwt package jwt
import ( import (
"os"
"io"
"bytes" "bytes"
"testing" "io"
"os"
"strings" "strings"
"testing"
) )
var rsaTestData = []struct{ var rsaTestData = []struct {
name string name string
tokenString string tokenString string
claims map[string]interface{} claims map[string]interface{}
valid bool valid bool
}{ }{
{ {
"basic: foo => bar", "basic: foo => bar",

View File

@ -1,23 +1,23 @@
package jwt package jwt
import ( import (
"fmt"
"errors" "errors"
"fmt"
) )
var signingMethods = map[string]func() SigningMethod{} var signingMethods = map[string]func() SigningMethod{}
// Signing method // Signing method
type SigningMethod interface { type SigningMethod interface {
Verify(signingString, signature string, key []byte)error Verify(signingString, signature string, key []byte) error
Sign(token *Token, key []byte)error Sign(token *Token, key []byte) error
} }
func RegisterSigningMethod(alg string, f func() SigningMethod) { func RegisterSigningMethod(alg string, f func() SigningMethod) {
signingMethods[alg] = f signingMethods[alg] = f
} }
func GetSigningMethod(alg string)(method SigningMethod, err error) { func GetSigningMethod(alg string) (method SigningMethod, err error) {
if methodF, ok := signingMethods[alg]; ok { if methodF, ok := signingMethods[alg]; ok {
method = methodF() method = methodF()
} else { } else {