moved tests into jwt_test package for better clarity

This commit is contained in:
Dave Grijalva 2014-12-28 12:24:41 -08:00
parent 93e34196bb
commit 6a9bcca4b5
3 changed files with 44 additions and 42 deletions

View File

@ -1,6 +1,7 @@
package jwt package jwt_test
import ( import (
"github.com/dgrijalva/jwt-go"
"io/ioutil" "io/ioutil"
"strings" "strings"
"testing" "testing"
@ -50,7 +51,7 @@ func TestHMACVerify(t *testing.T) {
for _, data := range hmacTestData { for _, data := range hmacTestData {
parts := strings.Split(data.tokenString, ".") parts := strings.Split(data.tokenString, ".")
method := GetSigningMethod(data.alg) method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], hmacTestKey) err := method.Verify(strings.Join(parts[0:2], "."), parts[2], hmacTestKey)
if data.valid && err != nil { if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err) t.Errorf("[%v] Error while verifying key: %v", data.name, err)
@ -65,7 +66,7 @@ func TestHMACSign(t *testing.T) {
for _, data := range hmacTestData { for _, data := range hmacTestData {
if data.valid { if data.valid {
parts := strings.Split(data.tokenString, ".") parts := strings.Split(data.tokenString, ".")
method := GetSigningMethod(data.alg) method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey) sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey)
if err != nil { if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err) t.Errorf("[%v] Error signing token: %v", data.name, err)

View File

@ -1,7 +1,8 @@
package jwt package jwt_test
import ( import (
"fmt" "fmt"
"github.com/dgrijalva/jwt-go"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"reflect" "reflect"
@ -11,19 +12,19 @@ import (
var ( var (
jwtTestDefaultKey []byte jwtTestDefaultKey []byte
defaultKeyFunc Keyfunc = func(t *Token) (interface{}, error) { return jwtTestDefaultKey, nil } defaultKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return jwtTestDefaultKey, nil }
emptyKeyFunc Keyfunc = func(t *Token) (interface{}, error) { return nil, nil } emptyKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return nil, nil }
errorKeyFunc Keyfunc = func(t *Token) (interface{}, error) { return nil, fmt.Errorf("error loading key") } errorKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return nil, fmt.Errorf("error loading key") }
nilKeyFunc Keyfunc = nil nilKeyFunc jwt.Keyfunc = nil
) )
var jwtTestData = []struct { var jwtTestData = []struct {
name string name string
tokenString string tokenString string
keyfunc Keyfunc keyfunc jwt.Keyfunc
claims map[string]interface{} claims map[string]interface{}
valid bool valid bool
validationError *ValidationError errors uint32
}{ }{
{ {
"basic", "basic",
@ -31,7 +32,7 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
map[string]interface{}{"foo": "bar"}, map[string]interface{}{"foo": "bar"},
true, true,
nil, 0,
}, },
{ {
"basic expired", "basic expired",
@ -39,7 +40,7 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
map[string]interface{}{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, map[string]interface{}{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
false, false,
&ValidationError{Errors: ValidationErrorExpired}, jwt.ValidationErrorExpired,
}, },
{ {
"basic nbf", "basic nbf",
@ -47,7 +48,7 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
false, false,
&ValidationError{Errors: ValidationErrorNotValidYet}, jwt.ValidationErrorNotValidYet,
}, },
{ {
"expired and nbf", "expired and nbf",
@ -55,7 +56,7 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, map[string]interface{}{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
false, false,
&ValidationError{Errors: ValidationErrorNotValidYet | ValidationErrorExpired}, jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
}, },
{ {
"basic invalid", "basic invalid",
@ -63,7 +64,7 @@ var jwtTestData = []struct {
defaultKeyFunc, defaultKeyFunc,
map[string]interface{}{"foo": "bar"}, map[string]interface{}{"foo": "bar"},
false, false,
&ValidationError{Errors: ValidationErrorSignatureInvalid}, jwt.ValidationErrorSignatureInvalid,
}, },
{ {
"basic nokeyfunc", "basic nokeyfunc",
@ -71,7 +72,7 @@ var jwtTestData = []struct {
nilKeyFunc, nilKeyFunc,
map[string]interface{}{"foo": "bar"}, map[string]interface{}{"foo": "bar"},
false, false,
&ValidationError{Errors: ValidationErrorUnverifiable}, jwt.ValidationErrorUnverifiable,
}, },
{ {
"basic nokey", "basic nokey",
@ -79,7 +80,7 @@ var jwtTestData = []struct {
emptyKeyFunc, emptyKeyFunc,
map[string]interface{}{"foo": "bar"}, map[string]interface{}{"foo": "bar"},
false, false,
&ValidationError{Errors: ValidationErrorSignatureInvalid}, jwt.ValidationErrorSignatureInvalid,
}, },
{ {
"basic errorkey", "basic errorkey",
@ -87,7 +88,7 @@ var jwtTestData = []struct {
errorKeyFunc, errorKeyFunc,
map[string]interface{}{"foo": "bar"}, map[string]interface{}{"foo": "bar"},
false, false,
&ValidationError{Errors: ValidationErrorUnverifiable}, jwt.ValidationErrorUnverifiable,
}, },
} }
@ -104,7 +105,7 @@ func makeSample(c map[string]interface{}) string {
panic(e.Error()) panic(e.Error())
} }
token := New(SigningMethodRS256) token := jwt.New(jwt.SigningMethodRS256)
token.Claims = c token.Claims = c
s, e := token.SignedString(key) s, e := token.SignedString(key)
@ -120,7 +121,7 @@ func TestJWT(t *testing.T) {
if data.tokenString == "" { if data.tokenString == "" {
data.tokenString = makeSample(data.claims) data.tokenString = makeSample(data.claims)
} }
token, err := Parse(data.tokenString, data.keyfunc) token, err := jwt.Parse(data.tokenString, data.keyfunc)
if !reflect.DeepEqual(data.claims, token.Claims) { if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
@ -131,13 +132,12 @@ func TestJWT(t *testing.T) {
if !data.valid && err == nil { if !data.valid && err == nil {
t.Errorf("[%v] Invalid token passed validation", data.name) t.Errorf("[%v] Invalid token passed validation", data.name)
} }
if data.validationError != nil { if data.errors != 0 {
if err == nil { if err == nil {
t.Errorf("[%v] Expecting error. Didn't get one.", data.name) t.Errorf("[%v] Expecting error. Didn't get one.", data.name)
} else { } else {
// perform deep equal without the string bit // compare the bitfield part of the error
err.(*ValidationError).err = "" if err.(*jwt.ValidationError).Errors != data.errors {
if !reflect.DeepEqual(data.validationError, err) {
t.Errorf("[%v] Errors don't match expectation", data.name) t.Errorf("[%v] Errors don't match expectation", data.name)
} }
@ -155,7 +155,7 @@ func TestParseRequest(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString))
token, err := ParseFromRequest(r, data.keyfunc) token, err := jwt.ParseFromRequest(r, data.keyfunc)
if !reflect.DeepEqual(data.claims, token.Claims) { if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)

View File

@ -1,6 +1,7 @@
package jwt package jwt_test
import ( import (
"github.com/dgrijalva/jwt-go"
"io/ioutil" "io/ioutil"
"strings" "strings"
"testing" "testing"
@ -49,7 +50,7 @@ func TestRSAVerify(t *testing.T) {
for _, data := range rsaTestData { for _, data := range rsaTestData {
parts := strings.Split(data.tokenString, ".") parts := strings.Split(data.tokenString, ".")
method := GetSigningMethod(data.alg) method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key) err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key)
if data.valid && err != nil { if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err) t.Errorf("[%v] Error while verifying key: %v", data.name, err)
@ -66,7 +67,7 @@ func TestRSASign(t *testing.T) {
for _, data := range rsaTestData { for _, data := range rsaTestData {
if data.valid { if data.valid {
parts := strings.Split(data.tokenString, ".") parts := strings.Split(data.tokenString, ".")
method := GetSigningMethod(data.alg) method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), key) sig, err := method.Sign(strings.Join(parts[0:2], "."), key)
if err != nil { if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err) t.Errorf("[%v] Error signing token: %v", data.name, err)
@ -80,13 +81,13 @@ func TestRSASign(t *testing.T) {
func TestRSAVerifyWithPreParsedPrivateKey(t *testing.T) { func TestRSAVerifyWithPreParsedPrivateKey(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key.pub") key, _ := ioutil.ReadFile("test/sample_key.pub")
parsedKey, err := ParseRSAPublicKeyFromPEM(key) parsedKey, err := jwt.ParseRSAPublicKeyFromPEM(key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testData := rsaTestData[0] testData := rsaTestData[0]
parts := strings.Split(testData.tokenString, ".") parts := strings.Split(testData.tokenString, ".")
err = SigningMethodRS256.Verify(strings.Join(parts[0:2], "."), parts[2], parsedKey) err = jwt.SigningMethodRS256.Verify(strings.Join(parts[0:2], "."), parts[2], parsedKey)
if err != nil { if err != nil {
t.Errorf("[%v] Error while verifying key: %v", testData.name, err) t.Errorf("[%v] Error while verifying key: %v", testData.name, err)
} }
@ -94,13 +95,13 @@ func TestRSAVerifyWithPreParsedPrivateKey(t *testing.T) {
func TestRSAWithPreParsedPrivateKey(t *testing.T) { func TestRSAWithPreParsedPrivateKey(t *testing.T) {
key, _ := ioutil.ReadFile("test/sample_key") key, _ := ioutil.ReadFile("test/sample_key")
parsedKey, err := ParseRSAPrivateKeyFromPEM(key) parsedKey, err := jwt.ParseRSAPrivateKeyFromPEM(key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testData := rsaTestData[0] testData := rsaTestData[0]
parts := strings.Split(testData.tokenString, ".") parts := strings.Split(testData.tokenString, ".")
sig, err := SigningMethodRS256.Sign(strings.Join(parts[0:2], "."), parsedKey) sig, err := jwt.SigningMethodRS256.Sign(strings.Join(parts[0:2], "."), parsedKey)
if err != nil { if err != nil {
t.Errorf("[%v] Error signing token: %v", testData.name, err) t.Errorf("[%v] Error signing token: %v", testData.name, err)
} }
@ -115,28 +116,28 @@ func TestRSAKeyParsing(t *testing.T) {
badKey := []byte("All your base are belong to key") badKey := []byte("All your base are belong to key")
// Test parsePrivateKey // Test parsePrivateKey
if _, e := ParseRSAPrivateKeyFromPEM(key); e != nil { if _, e := jwt.ParseRSAPrivateKeyFromPEM(key); e != nil {
t.Errorf("Failed to parse valid private key: %v", e) t.Errorf("Failed to parse valid private key: %v", e)
} }
if k, e := ParseRSAPrivateKeyFromPEM(pubKey); e == nil { if k, e := jwt.ParseRSAPrivateKeyFromPEM(pubKey); e == nil {
t.Errorf("Parsed public key as valid private key: %v", k) t.Errorf("Parsed public key as valid private key: %v", k)
} }
if k, e := ParseRSAPrivateKeyFromPEM(badKey); e == nil { if k, e := jwt.ParseRSAPrivateKeyFromPEM(badKey); e == nil {
t.Errorf("Parsed invalid key as valid private key: %v", k) t.Errorf("Parsed invalid key as valid private key: %v", k)
} }
// Test parsePublicKey // Test parsePublicKey
if _, e := ParseRSAPublicKeyFromPEM(pubKey); e != nil { if _, e := jwt.ParseRSAPublicKeyFromPEM(pubKey); e != nil {
t.Errorf("Failed to parse valid public key: %v", e) t.Errorf("Failed to parse valid public key: %v", e)
} }
if k, e := ParseRSAPublicKeyFromPEM(key); e == nil { if k, e := jwt.ParseRSAPublicKeyFromPEM(key); e == nil {
t.Errorf("Parsed private key as valid public key: %v", k) t.Errorf("Parsed private key as valid public key: %v", k)
} }
if k, e := ParseRSAPublicKeyFromPEM(badKey); e == nil { if k, e := jwt.ParseRSAPublicKeyFromPEM(badKey); e == nil {
t.Errorf("Parsed invalid key as valid private key: %v", k) t.Errorf("Parsed invalid key as valid private key: %v", k)
} }