base64 decoding fixins

This commit is contained in:
Dave Grijalva 2012-04-18 12:18:31 -07:00
parent 50d3e917c8
commit 18108cfd38
5 changed files with 71 additions and 18 deletions

24
jwt.go
View File

@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/json"
"time"
"net/http"
)
// A JWT Token
@ -26,7 +27,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
token = new(Token)
// parse Header
var headerBytes []byte
if headerBytes, err = base64.URLEncoding.DecodeString(parts[0]); err != nil {
if headerBytes, err = DecodeSegment(parts[0]); err != nil {
return
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {
@ -35,7 +36,7 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
// parse Claims
var claimBytes []byte
if claimBytes, err = base64.URLEncoding.DecodeString(parts[1]); err != nil {
if claimBytes, err = DecodeSegment(parts[1]); err != nil {
return
}
if err = json.Unmarshal(claimBytes, &token.Claims); err != nil {
@ -75,3 +76,22 @@ func Parse(tokenString string, keyFunc func(*Token)([]byte, error)) (token *Toke
}
return
}
func ParseFromRequest(req *http.Request, keyFunc func(*Token)([]byte, error))(token *Token, err error) {
return nil, nil
}
func DecodeSegment(seg string)([]byte, error) {
// len % 4
switch len(seg) % 4 {
case 2:
seg = seg + "=="
case 3:
seg = seg + "==="
}
return base64.URLEncoding.DecodeString(seg)
}

View File

@ -1,9 +1,51 @@
package jwt
import (
"os"
"io"
"bytes"
"testing"
"reflect"
)
func TestJWT(t *testing.T) {
var jwtTestData = []struct{
name string
tokenString string
claims map[string]interface{}
valid bool
}{
{
"basic: foo => bar",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
map[string]interface{}{"foo": "bar"},
true,
},
{
"basic invalid: foo => bar",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.EhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
map[string]interface{}{"foo": "bar"},
false,
},
}
func TestJWT(t *testing.T) {
file, _ := os.Open("test/sample_key.pub")
buf := new(bytes.Buffer)
io.Copy(buf, file)
key := buf.Bytes()
file.Close()
for _, data := range jwtTestData {
token, err := Parse(data.tokenString, func(t *Token)([]byte, error){ return key, nil })
if !reflect.DeepEqual(data.claims, token.Claims) {
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
}
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying token: %v", data.name, err)
}
if !data.valid && err == nil {
t.Errorf("[%v] Invalid token passed validation", data.name)
}
}
}

View File

@ -2,7 +2,6 @@ package jwt
import (
"errors"
"encoding/base64"
"encoding/pem"
"crypto"
"crypto/x509"
@ -19,17 +18,9 @@ func init() {
}
func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)(err error) {
// len % 4
switch len(signature) % 4 {
case 2:
signature = signature + "=="
case 3:
signature = signature + "==="
}
// Key
var sig []byte
if sig, err = base64.URLEncoding.DecodeString(signature); err == nil {
if sig, err = DecodeSegment(signature); err == nil {
var block *pem.Block
if block, _ = pem.Decode(key); block != nil {
var parsedKey interface{}
@ -50,6 +41,6 @@ func (m *SigningMethodRS256) Verify(signingString, signature string, key []byte)
return
}
func (m *SigningMethodRS256) Sign(token, key []byte)error {
func (m *SigningMethodRS256) Sign(token *Token, key []byte)error {
return nil
}

View File

@ -8,7 +8,7 @@ import (
"strings"
)
var testData = []struct{
var rsaTestData = []struct{
name string
tokenString string
claims map[string]interface{}
@ -35,7 +35,7 @@ func TestRS256Verify(t *testing.T) {
key := buf.Bytes()
file.Close()
for _, data := range testData {
for _, data := range rsaTestData {
parts := strings.Split(data.tokenString, ".")
method, _ := GetSigningMethod("RS256")

View File

@ -10,7 +10,7 @@ var signingMethods = map[string]func() SigningMethod{}
// Signing method
type SigningMethod interface {
Verify(signingString, signature string, key []byte)error
Sign(token, key []byte)error
Sign(token *Token, key []byte)error
}
func RegisterSigningMethod(alg string, f func() SigningMethod) {