mirror of https://github.com/golang-jwt/jwt.git
pass keys as interface{} rather than []byte
This will allow clients to pass, for example, their own instances of rsa.PublicKey if the key is not specified as some flavour of X509 cert. For example, Salesforce just specify the modulus and exponent (https://login.salesforce.com/id/keys)
This commit is contained in:
parent
0ed08007c3
commit
23cb3af02c
|
@ -115,7 +115,7 @@ func verifyToken() error {
|
|||
}
|
||||
|
||||
// Parse the token. Load the key from command line option
|
||||
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) ([]byte, error) {
|
||||
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) {
|
||||
return loadData(*flagKey)
|
||||
})
|
||||
|
||||
|
|
39
hmac.go
39
hmac.go
|
@ -42,24 +42,31 @@ func (m *SigningMethodHMAC) Alg() string {
|
|||
return m.Name
|
||||
}
|
||||
|
||||
func (m *SigningMethodHMAC) Verify(signingString, signature string, key []byte) error {
|
||||
// Key
|
||||
var sig []byte
|
||||
var err error
|
||||
if sig, err = DecodeSegment(signature); err == nil {
|
||||
hasher := hmac.New(m.Hash.New, key)
|
||||
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
|
||||
if keyBytes, ok := key.([]byte); ok {
|
||||
var sig []byte
|
||||
var err error
|
||||
if sig, err = DecodeSegment(signature); err == nil {
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
if !bytes.Equal(sig, hasher.Sum(nil)) {
|
||||
err = errors.New("Signature is invalid")
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
|
||||
if keyBytes, ok := key.([]byte); ok {
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
if !bytes.Equal(sig, hasher.Sum(nil)) {
|
||||
err = errors.New("Signature is invalid")
|
||||
}
|
||||
return EncodeSegment(hasher.Sum(nil)), nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *SigningMethodHMAC) Sign(signingString string, key []byte) (string, error) {
|
||||
hasher := hmac.New(m.Hash.New, key)
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
return EncodeSegment(hasher.Sum(nil)), nil
|
||||
return "", ErrInvalidKey
|
||||
}
|
||||
|
|
4
jwt.go
4
jwt.go
|
@ -18,7 +18,7 @@ var TimeFunc = time.Now
|
|||
// the key for verification. The function receives the parsed,
|
||||
// but unverified Token. This allows you to use propries in the
|
||||
// Header of the token (such as `kid`) to identify which key to use.
|
||||
type Keyfunc func(*Token) ([]byte, error)
|
||||
type Keyfunc func(*Token) (interface{}, error)
|
||||
|
||||
// A JWT Token. Different fields will be used depending on whether you're
|
||||
// creating or parsing/verifying a token.
|
||||
|
@ -120,7 +120,7 @@ func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
|
|||
}
|
||||
|
||||
// Lookup key
|
||||
var key []byte
|
||||
var key interface{}
|
||||
if key, err = keyFunc(token); err != nil {
|
||||
return token, &ValidationError{err: err.Error(), Errors: ValidationErrorUnverifiable}
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ func TestJWT(t *testing.T) {
|
|||
if data.tokenString == "" {
|
||||
data.tokenString = makeSample(data.claims)
|
||||
}
|
||||
token, err := Parse(data.tokenString, func(t *Token) ([]byte, error) { return key, nil })
|
||||
token, err := Parse(data.tokenString, func(t *Token) (interface{}, error) { return key, nil })
|
||||
|
||||
if !reflect.DeepEqual(data.claims, token.Claims) {
|
||||
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
|
||||
|
@ -120,7 +120,7 @@ func TestParseRequest(t *testing.T) {
|
|||
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", data.tokenString))
|
||||
token, err := ParseFromRequest(r, func(t *Token) ([]byte, error) { return key, nil })
|
||||
token, err := ParseFromRequest(r, func(t *Token) (interface{}, error) { return key, nil })
|
||||
|
||||
if !reflect.DeepEqual(data.claims, token.Claims) {
|
||||
t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
|
||||
|
|
63
rsa.go
63
rsa.go
|
@ -18,6 +18,7 @@ var (
|
|||
SigningMethodRS256 *SigningMethodRSA
|
||||
SigningMethodRS384 *SigningMethodRSA
|
||||
SigningMethodRS512 *SigningMethodRSA
|
||||
ErrInvalidKey = errors.New("An invalid key was passed. Expected a []byte")
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -44,7 +45,7 @@ func (m *SigningMethodRSA) Alg() string {
|
|||
return m.Name
|
||||
}
|
||||
|
||||
func (m *SigningMethodRSA) Verify(signingString, signature string, key []byte) error {
|
||||
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
|
||||
var err error
|
||||
|
||||
// Decode the signature
|
||||
|
@ -53,42 +54,48 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key []byte) e
|
|||
return err
|
||||
}
|
||||
|
||||
// Parse public key
|
||||
var rsaKey *rsa.PublicKey
|
||||
if rsaKey, err = m.parsePublicKey(key); err != nil {
|
||||
return err
|
||||
if keyBytes, ok := key.([]byte); ok {
|
||||
var rsaKey *rsa.PublicKey
|
||||
if rsaKey, err = m.parsePublicKey(keyBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create hasher
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Verify the signature
|
||||
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
|
||||
} else {
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
// Create hasher
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Verify the signature
|
||||
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
|
||||
}
|
||||
|
||||
// Implements the Sign method from SigningMethod
|
||||
// For this signing method, must be PEM encoded PKCS1 or PKCS8 RSA private key
|
||||
func (m *SigningMethodRSA) Sign(signingString string, key []byte) (string, error) {
|
||||
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
|
||||
var err error
|
||||
|
||||
// Parse private key
|
||||
var rsaKey *rsa.PrivateKey
|
||||
if rsaKey, err = m.parsePrivateKey(key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the hasher
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Sign the string and return the encoded bytes
|
||||
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
|
||||
return EncodeSegment(sigBytes), nil
|
||||
} else {
|
||||
return "", err
|
||||
if keyBytes, ok := key.([]byte); ok {
|
||||
// Key
|
||||
var rsaKey *rsa.PrivateKey
|
||||
if rsaKey, err = m.parsePrivateKey(keyBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create the hasher
|
||||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Sign the string and return the encoded bytes
|
||||
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
|
||||
return EncodeSegment(sigBytes), nil
|
||||
} else {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return "", ErrInvalidKey
|
||||
}
|
||||
|
||||
// Parse PEM encoded PKCS1 or PKCS8 public key
|
||||
|
|
|
@ -4,8 +4,8 @@ var signingMethods = map[string]func() SigningMethod{}
|
|||
|
||||
// Signing method
|
||||
type SigningMethod interface {
|
||||
Verify(signingString, signature string, key []byte) error
|
||||
Sign(signingString string, key []byte) (string, error)
|
||||
Verify(signingString, signature string, key interface{}) error
|
||||
Sign(signingString string, key interface{}) (string, error)
|
||||
Alg() string
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue