forked from mirror/jwt
Compare commits
1 Commits
Author | SHA1 | Date |
---|---|---|
Christian Banse | 73c8300abd |
|
@ -128,7 +128,7 @@ func verifyToken() error {
|
|||
}
|
||||
|
||||
// Parse the token. Load the key from command line option
|
||||
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) {
|
||||
token, err := jwt.Parse[any](string(tokData), func(t *jwt.Token[any]) (any, error) {
|
||||
if isNone() {
|
||||
return jwt.UnsafeAllowNoneSignatureType, nil
|
||||
}
|
||||
|
@ -206,7 +206,7 @@ func signToken() error {
|
|||
}
|
||||
|
||||
// get the signing alg
|
||||
alg := jwt.GetSigningMethod(*flagAlg)
|
||||
alg := jwt.GetSigningMethod[any](*flagAlg)
|
||||
if alg == nil {
|
||||
return fmt.Errorf("couldn't find signing method: %v", *flagAlg)
|
||||
}
|
||||
|
@ -273,7 +273,7 @@ func showToken() error {
|
|||
fmt.Fprintf(os.Stderr, "Token len: %v bytes\n", len(tokData))
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(string(tokData), nil)
|
||||
token, err := jwt.Parse[any](string(tokData), nil)
|
||||
if token == nil {
|
||||
return fmt.Errorf("malformed token: %w", err)
|
||||
}
|
||||
|
|
22
ecdsa.go
22
ecdsa.go
|
@ -32,19 +32,19 @@ var (
|
|||
func init() {
|
||||
// ES256
|
||||
SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
|
||||
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod[*ecdsa.PublicKey] {
|
||||
return SigningMethodES256
|
||||
})
|
||||
|
||||
// ES384
|
||||
SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
|
||||
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod[*ecdsa.PublicKey] {
|
||||
return SigningMethodES384
|
||||
})
|
||||
|
||||
// ES512
|
||||
SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
|
||||
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod[*ecdsa.PublicKey] {
|
||||
return SigningMethodES512
|
||||
})
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func (m *SigningMethodECDSA) Alg() string {
|
|||
|
||||
// Verify implements token verification for the SigningMethod.
|
||||
// For this verify method, key must be an ecdsa.PublicKey struct
|
||||
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
|
||||
func (m *SigningMethodECDSA) Verify(signingString, signature string, key *ecdsa.PublicKey) error {
|
||||
var err error
|
||||
|
||||
// Decode the signature
|
||||
|
@ -64,15 +64,6 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa
|
|||
return err
|
||||
}
|
||||
|
||||
// Get the key
|
||||
var ecdsaKey *ecdsa.PublicKey
|
||||
switch k := key.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
ecdsaKey = k
|
||||
default:
|
||||
return ErrInvalidKeyType
|
||||
}
|
||||
|
||||
if len(sig) != 2*m.KeySize {
|
||||
return ErrECDSAVerification
|
||||
}
|
||||
|
@ -88,7 +79,7 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa
|
|||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Verify the signature
|
||||
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus {
|
||||
if verifystatus := ecdsa.Verify(key, hasher.Sum(nil), r, s); verifystatus {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -115,6 +106,9 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
|
|||
hasher := m.Hash.New()
|
||||
hasher.Write([]byte(signingString))
|
||||
|
||||
var s crypto.Signer = ecdsaKey
|
||||
s.Sign(rand.Reader, hasher.Sum(nil), nil)
|
||||
|
||||
// Sign the string and return r, s
|
||||
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
|
||||
curveBits := ecdsaKey.Curve.Params().BitSize
|
||||
|
|
|
@ -64,7 +64,7 @@ func TestECDSAVerify(t *testing.T) {
|
|||
|
||||
parts := strings.Split(data.tokenString, ".")
|
||||
|
||||
method := jwt.GetSigningMethod(data.alg)
|
||||
method := jwt.GetSigningMethod[*ecdsa.PublicKey](data.alg)
|
||||
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ecdsaKey)
|
||||
if data.valid && err != nil {
|
||||
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
|
||||
|
@ -88,7 +88,7 @@ func TestECDSASign(t *testing.T) {
|
|||
if data.valid {
|
||||
parts := strings.Split(data.tokenString, ".")
|
||||
toSign := strings.Join(parts[0:2], ".")
|
||||
method := jwt.GetSigningMethod(data.alg)
|
||||
method := jwt.GetSigningMethod[*ecdsa.PublicKey](data.alg)
|
||||
sig, err := method.Sign(toSign, ecdsaKey)
|
||||
|
||||
if err != nil {
|
||||
|
@ -98,7 +98,7 @@ func TestECDSASign(t *testing.T) {
|
|||
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
|
||||
}
|
||||
|
||||
err = method.Verify(toSign, sig, ecdsaKey.Public())
|
||||
err = method.Verify(toSign, sig, ecdsaKey.Public().(*ecdsa.PublicKey))
|
||||
if err != nil {
|
||||
t.Errorf("[%v] Sign produced an invalid signature: %v", data.name, err)
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ func BenchmarkECDSASigning(b *testing.B) {
|
|||
b.Fatalf("Unable to parse ECDSA private key: %v", err)
|
||||
}
|
||||
|
||||
method := jwt.GetSigningMethod(data.alg)
|
||||
method := jwt.GetSigningMethod[*ecdsa.PublicKey](data.alg)
|
||||
|
||||
b.Run(data.name, func(b *testing.B) {
|
||||
benchmarkSigning(b, method, ecdsaKey)
|
||||
|
|
|
@ -23,7 +23,7 @@ var (
|
|||
|
||||
func init() {
|
||||
SigningMethodEdDSA = &SigningMethodEd25519{}
|
||||
RegisterSigningMethod(SigningMethodEdDSA.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodEdDSA.Alg(), func() SigningMethod[ed25519.PublicKey] {
|
||||
return SigningMethodEdDSA
|
||||
})
|
||||
}
|
||||
|
@ -34,14 +34,13 @@ func (m *SigningMethodEd25519) Alg() string {
|
|||
|
||||
// Verify implements token verification for the SigningMethod.
|
||||
// For this verify method, key must be an ed25519.PublicKey
|
||||
func (m *SigningMethodEd25519) Verify(signingString, signature string, key interface{}) error {
|
||||
func (m *SigningMethodEd25519) Verify(signingString, signature string, key ed25519.PublicKey) error {
|
||||
var err error
|
||||
var ed25519Key ed25519.PublicKey
|
||||
var ok bool
|
||||
|
||||
if ed25519Key, ok = key.(ed25519.PublicKey); !ok {
|
||||
/*if ed25519Key, ok = key.(ed25519.PublicKey); !ok {
|
||||
return ErrInvalidKeyType
|
||||
}
|
||||
}*/
|
||||
|
||||
if len(ed25519Key) != ed25519.PublicKeySize {
|
||||
return ErrInvalidKey
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
var ed25519TestData = []struct {
|
||||
|
@ -47,7 +48,7 @@ func TestEd25519Verify(t *testing.T) {
|
|||
|
||||
parts := strings.Split(data.tokenString, ".")
|
||||
|
||||
method := jwt.GetSigningMethod(data.alg)
|
||||
method := jwt.GetSigningMethod[*ed25519.PublicKey](data.alg)
|
||||
|
||||
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ed25519Key)
|
||||
if data.valid && err != nil {
|
||||
|
@ -71,7 +72,7 @@ func TestEd25519Sign(t *testing.T) {
|
|||
|
||||
parts := strings.Split(data.tokenString, ".")
|
||||
|
||||
method := jwt.GetSigningMethod(data.alg)
|
||||
method := jwt.GetSigningMethod[*ed25519.PublicKey](data.alg)
|
||||
|
||||
sig, err := method.Sign(strings.Join(parts[0:2], "."), ed25519Key)
|
||||
if err != nil {
|
||||
|
|
|
@ -39,7 +39,7 @@ func ParseEdPrivateKeyFromPEM(key []byte) (crypto.PrivateKey, error) {
|
|||
}
|
||||
|
||||
// ParseEdPublicKeyFromPEM parses a PEM-encoded Edwards curve public key
|
||||
func ParseEdPublicKeyFromPEM(key []byte) (crypto.PublicKey, error) {
|
||||
func ParseEdPublicKeyFromPEM(key []byte) (*ed25519.PublicKey, error) {
|
||||
var err error
|
||||
|
||||
// Parse PEM block
|
||||
|
@ -60,5 +60,5 @@ func ParseEdPublicKeyFromPEM(key []byte) (crypto.PublicKey, error) {
|
|||
return nil, ErrNotEdPublicKey
|
||||
}
|
||||
|
||||
return pkey, nil
|
||||
return &pkey, nil
|
||||
}
|
||||
|
|
16
hmac.go
16
hmac.go
|
@ -24,19 +24,19 @@ var (
|
|||
func init() {
|
||||
// HS256
|
||||
SigningMethodHS256 = &SigningMethodHMAC{"HS256", crypto.SHA256}
|
||||
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod[[]byte] {
|
||||
return SigningMethodHS256
|
||||
})
|
||||
|
||||
// HS384
|
||||
SigningMethodHS384 = &SigningMethodHMAC{"HS384", crypto.SHA384}
|
||||
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod[[]byte] {
|
||||
return SigningMethodHS384
|
||||
})
|
||||
|
||||
// HS512
|
||||
SigningMethodHS512 = &SigningMethodHMAC{"HS512", crypto.SHA512}
|
||||
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod[[]byte] {
|
||||
return SigningMethodHS512
|
||||
})
|
||||
}
|
||||
|
@ -46,13 +46,7 @@ func (m *SigningMethodHMAC) Alg() string {
|
|||
}
|
||||
|
||||
// Verify implements token verification for the SigningMethod. Returns nil if the signature is valid.
|
||||
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
|
||||
// Verify the key is the right type
|
||||
keyBytes, ok := key.([]byte)
|
||||
if !ok {
|
||||
return ErrInvalidKeyType
|
||||
}
|
||||
|
||||
func (m *SigningMethodHMAC) Verify(signingString, signature string, key []byte) error {
|
||||
// Decode signature, for comparison
|
||||
sig, err := DecodeSegment(signature)
|
||||
if err != nil {
|
||||
|
@ -67,7 +61,7 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac
|
|||
// This signing method is symmetric, so we validate the signature
|
||||
// by reproducing the signature from the signing string and key, then
|
||||
// comparing that against the provided signature.
|
||||
hasher := hmac.New(m.Hash.New, keyBytes)
|
||||
hasher := hmac.New(m.Hash.New, key)
|
||||
hasher.Write([]byte(signingString))
|
||||
if !hmac.Equal(sig, hasher.Sum(nil)) {
|
||||
return ErrSignatureInvalid
|
||||
|
|
9
none.go
9
none.go
|
@ -15,7 +15,7 @@ func init() {
|
|||
SigningMethodNone = &signingMethodNone{}
|
||||
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid)
|
||||
|
||||
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod[unsafeNoneMagicConstant] {
|
||||
return SigningMethodNone
|
||||
})
|
||||
}
|
||||
|
@ -25,12 +25,7 @@ func (m *signingMethodNone) Alg() string {
|
|||
}
|
||||
|
||||
// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
|
||||
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
|
||||
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
|
||||
// accepting 'none' signing method
|
||||
if _, ok := key.(unsafeNoneMagicConstant); !ok {
|
||||
return NoneSignatureTypeDisallowedError
|
||||
}
|
||||
func (m *signingMethodNone) Verify(signingString, signature string, key unsafeNoneMagicConstant) (err error) {
|
||||
// If signing method is none, signature must be an empty string
|
||||
if signature != "" {
|
||||
return NewValidationError(
|
||||
|
|
18
parser.go
18
parser.go
|
@ -7,7 +7,7 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
type Parser[T Key] struct {
|
||||
// If populated, only these methods will be considered valid.
|
||||
//
|
||||
// Deprecated: In future releases, this field will not be exported anymore and should be set with an option to NewParser instead.
|
||||
|
@ -25,8 +25,8 @@ type Parser struct {
|
|||
}
|
||||
|
||||
// NewParser creates a new Parser with the specified options
|
||||
func NewParser(options ...ParserOption) *Parser {
|
||||
p := &Parser{}
|
||||
func NewParser[T Key](options ...ParserOption[T]) *Parser[T] {
|
||||
p := &Parser[T]{}
|
||||
|
||||
// loop through our parsing options and apply them
|
||||
for _, option := range options {
|
||||
|
@ -38,11 +38,11 @@ func NewParser(options ...ParserOption) *Parser {
|
|||
|
||||
// Parse parses, validates, verifies the signature and returns the parsed token.
|
||||
// keyFunc will receive the parsed token and should return the key for validating.
|
||||
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
|
||||
func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) {
|
||||
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
|
||||
}
|
||||
|
||||
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
|
||||
func (p *Parser[T]) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc[T]) (*Token[T], error) {
|
||||
token, parts, err := p.ParseUnverified(tokenString, claims)
|
||||
if err != nil {
|
||||
return token, err
|
||||
|
@ -65,7 +65,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
|||
}
|
||||
|
||||
// Lookup key
|
||||
var key interface{}
|
||||
var key T
|
||||
if keyFunc == nil {
|
||||
// keyFunc was not provided. short circuiting validation
|
||||
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable)
|
||||
|
@ -115,13 +115,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
|
|||
//
|
||||
// It's only ever useful in cases where you know the signature is valid (because it has
|
||||
// been checked previously in the stack) and you want to extract values from it.
|
||||
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
|
||||
func (p *Parser[T]) ParseUnverified(tokenString string, claims Claims) (token *Token[T], parts []string, err error) {
|
||||
parts = strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
|
||||
}
|
||||
|
||||
token = &Token{Raw: tokenString}
|
||||
token = &Token[T]{Raw: tokenString}
|
||||
|
||||
// parse Header
|
||||
var headerBytes []byte
|
||||
|
@ -159,7 +159,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
|
|||
|
||||
// Lookup signature method
|
||||
if method, ok := token.Header["alg"].(string); ok {
|
||||
if token.Method = GetSigningMethod(method); token.Method == nil {
|
||||
if token.Method = GetSigningMethod[T](method); token.Method == nil {
|
||||
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -3,27 +3,27 @@ package jwt
|
|||
// ParserOption is used to implement functional-style options that modify the behavior of the parser. To add
|
||||
// new options, just create a function (ideally beginning with With or Without) that returns an anonymous function that
|
||||
// takes a *Parser type as input and manipulates its configuration accordingly.
|
||||
type ParserOption func(*Parser)
|
||||
type ParserOption[T Key] func(*Parser[T])
|
||||
|
||||
// WithValidMethods is an option to supply algorithm methods that the parser will check. Only those methods will be considered valid.
|
||||
// It is heavily encouraged to use this option in order to prevent attacks such as https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/.
|
||||
func WithValidMethods(methods []string) ParserOption {
|
||||
return func(p *Parser) {
|
||||
func WithValidMethods[T Key](methods []string) ParserOption[T] {
|
||||
return func(p *Parser[T]) {
|
||||
p.ValidMethods = methods
|
||||
}
|
||||
}
|
||||
|
||||
// WithJSONNumber is an option to configure the underlying JSON parser with UseNumber
|
||||
func WithJSONNumber() ParserOption {
|
||||
return func(p *Parser) {
|
||||
func WithJSONNumber[T Key]() ParserOption[T] {
|
||||
return func(p *Parser[T]) {
|
||||
p.UseJSONNumber = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithoutClaimsValidation is an option to disable claims validation. This option should only be used if you exactly know
|
||||
// what you are doing.
|
||||
func WithoutClaimsValidation() ParserOption {
|
||||
return func(p *Parser) {
|
||||
func WithoutClaimsValidation[T Key]() ParserOption[T] {
|
||||
return func(p *Parser[T]) {
|
||||
p.SkipClaimsValidation = true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,9 +12,9 @@ import (
|
|||
// the logic for extracting a token. Several useful implementations are provided.
|
||||
//
|
||||
// You can provide options to modify parsing behavior
|
||||
func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc, options ...ParseFromRequestOption) (token *jwt.Token, err error) {
|
||||
func ParseFromRequest[T jwt.Key](req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc[T], options ...ParseFromRequestOption[T]) (token *jwt.Token[T], err error) {
|
||||
// Create basic parser struct
|
||||
p := &fromRequestParser{req, extractor, nil, nil}
|
||||
p := &fromRequestParser[T]{req, extractor, nil, nil}
|
||||
|
||||
// Handle options
|
||||
for _, option := range options {
|
||||
|
@ -26,7 +26,7 @@ func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfun
|
|||
p.claims = jwt.MapClaims{}
|
||||
}
|
||||
if p.parser == nil {
|
||||
p.parser = &jwt.Parser{}
|
||||
p.parser = &jwt.Parser[T]{}
|
||||
}
|
||||
|
||||
// perform extract
|
||||
|
@ -42,29 +42,29 @@ func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfun
|
|||
// ParseFromRequestWithClaims is an alias for ParseFromRequest but with custom Claims type.
|
||||
//
|
||||
// Deprecated: use ParseFromRequest and the WithClaims option
|
||||
func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
|
||||
return ParseFromRequest(req, extractor, keyFunc, WithClaims(claims))
|
||||
func ParseFromRequestWithClaims[T jwt.Key](req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc[T]) (token *jwt.Token[T], err error) {
|
||||
return ParseFromRequest(req, extractor, keyFunc, WithClaims[T](claims))
|
||||
}
|
||||
|
||||
type fromRequestParser struct {
|
||||
type fromRequestParser[T jwt.Key] struct {
|
||||
req *http.Request
|
||||
extractor Extractor
|
||||
claims jwt.Claims
|
||||
parser *jwt.Parser
|
||||
parser *jwt.Parser[T]
|
||||
}
|
||||
|
||||
type ParseFromRequestOption func(*fromRequestParser)
|
||||
type ParseFromRequestOption[T jwt.Key] func(*fromRequestParser[T])
|
||||
|
||||
// WithClaims parses with custom claims
|
||||
func WithClaims(claims jwt.Claims) ParseFromRequestOption {
|
||||
return func(p *fromRequestParser) {
|
||||
func WithClaims[T jwt.Key](claims jwt.Claims) ParseFromRequestOption[T] {
|
||||
return func(p *fromRequestParser[T]) {
|
||||
p.claims = claims
|
||||
}
|
||||
}
|
||||
|
||||
// WithParser parses using a custom parser
|
||||
func WithParser(parser *jwt.Parser) ParseFromRequestOption {
|
||||
return func(p *fromRequestParser) {
|
||||
func WithParser[T jwt.Key](parser *jwt.Parser[T]) ParseFromRequestOption[T] {
|
||||
return func(p *fromRequestParser[T]) {
|
||||
p.parser = parser
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package request
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -58,14 +59,14 @@ func TestParseRequest(t *testing.T) {
|
|||
// load keys from disk
|
||||
privateKey := test.LoadRSAPrivateKeyFromDisk("../test/sample_key")
|
||||
publicKey := test.LoadRSAPublicKeyFromDisk("../test/sample_key.pub")
|
||||
keyfunc := func(*jwt.Token) (interface{}, error) {
|
||||
keyfunc := func(*jwt.Token[*rsa.PublicKey]) (*rsa.PublicKey, error) {
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
// Bearer token request
|
||||
for _, data := range requestTestData {
|
||||
// Make token from claims
|
||||
tokenString := test.MakeSampleToken(data.claims, jwt.SigningMethodRS256, privateKey)
|
||||
tokenString := test.MakeSampleToken[*rsa.PublicKey](data.claims, jwt.SigningMethodRS256, privateKey)
|
||||
|
||||
// Make query string
|
||||
for k, vv := range data.query {
|
||||
|
|
17
rsa.go
17
rsa.go
|
@ -23,19 +23,19 @@ var (
|
|||
func init() {
|
||||
// RS256
|
||||
SigningMethodRS256 = &SigningMethodRSA{"RS256", crypto.SHA256}
|
||||
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod[*rsa.PublicKey] {
|
||||
return SigningMethodRS256
|
||||
})
|
||||
|
||||
// RS384
|
||||
SigningMethodRS384 = &SigningMethodRSA{"RS384", crypto.SHA384}
|
||||
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod[*rsa.PublicKey] {
|
||||
return SigningMethodRS384
|
||||
})
|
||||
|
||||
// RS512
|
||||
SigningMethodRS512 = &SigningMethodRSA{"RS512", crypto.SHA512}
|
||||
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod[*rsa.PublicKey] {
|
||||
return SigningMethodRS512
|
||||
})
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ func (m *SigningMethodRSA) Alg() string {
|
|||
|
||||
// Verify implements token verification for the SigningMethod
|
||||
// For this signing method, must be an *rsa.PublicKey structure.
|
||||
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
|
||||
func (m *SigningMethodRSA) Verify(signingString, signature string, key *rsa.PublicKey) error {
|
||||
var err error
|
||||
|
||||
// Decode the signature
|
||||
|
@ -55,13 +55,6 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface
|
|||
return err
|
||||
}
|
||||
|
||||
var rsaKey *rsa.PublicKey
|
||||
var ok bool
|
||||
|
||||
if rsaKey, ok = key.(*rsa.PublicKey); !ok {
|
||||
return ErrInvalidKeyType
|
||||
}
|
||||
|
||||
// Create hasher
|
||||
if !m.Hash.Available() {
|
||||
return ErrHashUnavailable
|
||||
|
@ -70,7 +63,7 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface
|
|||
hasher.Write([]byte(signingString))
|
||||
|
||||
// Verify the signature
|
||||
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
|
||||
return rsa.VerifyPKCS1v15(key, m.Hash, hasher.Sum(nil), sig)
|
||||
}
|
||||
|
||||
// Sign implements token signing for the SigningMethod
|
||||
|
|
19
rsa_pss.go
19
rsa_pss.go
|
@ -1,3 +1,4 @@
|
|||
//go:build go1.4
|
||||
// +build go1.4
|
||||
|
||||
package jwt
|
||||
|
@ -40,7 +41,7 @@ func init() {
|
|||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
},
|
||||
}
|
||||
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod[*rsa.PublicKey] {
|
||||
return SigningMethodPS256
|
||||
})
|
||||
|
||||
|
@ -57,7 +58,7 @@ func init() {
|
|||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
},
|
||||
}
|
||||
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod[*rsa.PublicKey] {
|
||||
return SigningMethodPS384
|
||||
})
|
||||
|
||||
|
@ -74,14 +75,14 @@ func init() {
|
|||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
},
|
||||
}
|
||||
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod {
|
||||
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod[*rsa.PublicKey] {
|
||||
return SigningMethodPS512
|
||||
})
|
||||
}
|
||||
|
||||
// Verify implements token verification for the SigningMethod.
|
||||
// For this verify method, key must be an rsa.PublicKey struct
|
||||
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
|
||||
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key *rsa.PublicKey) error {
|
||||
var err error
|
||||
|
||||
// Decode the signature
|
||||
|
@ -90,14 +91,6 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf
|
|||
return err
|
||||
}
|
||||
|
||||
var rsaKey *rsa.PublicKey
|
||||
switch k := key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
rsaKey = k
|
||||
default:
|
||||
return ErrInvalidKey
|
||||
}
|
||||
|
||||
// Create hasher
|
||||
if !m.Hash.Available() {
|
||||
return ErrHashUnavailable
|
||||
|
@ -110,7 +103,7 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf
|
|||
opts = m.VerifyOptions
|
||||
}
|
||||
|
||||
return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts)
|
||||
return rsa.VerifyPSS(key, m.Hash, hasher.Sum(nil), sig, opts)
|
||||
}
|
||||
|
||||
// Sign implements token signing for the SigningMethod.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package jwt_test
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -52,7 +53,7 @@ func TestRSAVerify(t *testing.T) {
|
|||
for _, data := range rsaTestData {
|
||||
parts := strings.Split(data.tokenString, ".")
|
||||
|
||||
method := jwt.GetSigningMethod(data.alg)
|
||||
method := jwt.GetSigningMethod[*rsa.PublicKey](data.alg)
|
||||
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key)
|
||||
if data.valid && err != nil {
|
||||
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
|
||||
|
@ -70,7 +71,7 @@ func TestRSASign(t *testing.T) {
|
|||
for _, data := range rsaTestData {
|
||||
if data.valid {
|
||||
parts := strings.Split(data.tokenString, ".")
|
||||
method := jwt.GetSigningMethod(data.alg)
|
||||
method := jwt.GetSigningMethod[*rsa.PublicKey](data.alg)
|
||||
sig, err := method.Sign(strings.Join(parts[0:2], "."), key)
|
||||
if err != nil {
|
||||
t.Errorf("[%v] Error signing token: %v", data.name, err)
|
||||
|
|
|
@ -4,33 +4,33 @@ import (
|
|||
"sync"
|
||||
)
|
||||
|
||||
var signingMethods = map[string]func() SigningMethod{}
|
||||
//var signingMethods = map[string]SigningMethod[Key]{}
|
||||
var signingMethodLock = new(sync.RWMutex)
|
||||
|
||||
// SigningMethod can be used add new methods for signing or verifying tokens.
|
||||
type SigningMethod interface {
|
||||
Verify(signingString, signature string, key interface{}) error // Returns nil if signature is valid
|
||||
Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error
|
||||
Alg() string // returns the alg identifier for this method (example: 'HS256')
|
||||
type SigningMethod[T Key] interface {
|
||||
Verify(signingString, signature string, key T) error // Returns nil if signature is valid
|
||||
Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error
|
||||
Alg() string // returns the alg identifier for this method (example: 'HS256')
|
||||
}
|
||||
|
||||
// RegisterSigningMethod registers the "alg" name and a factory function for signing method.
|
||||
// This is typically done during init() in the method's implementation
|
||||
func RegisterSigningMethod(alg string, f func() SigningMethod) {
|
||||
func RegisterSigningMethod[T Key](alg string, f func() SigningMethod[T]) {
|
||||
signingMethodLock.Lock()
|
||||
defer signingMethodLock.Unlock()
|
||||
|
||||
signingMethods[alg] = f
|
||||
//signingMethods[alg] = f
|
||||
}
|
||||
|
||||
// GetSigningMethod retrieves a signing method from an "alg" string
|
||||
func GetSigningMethod(alg string) (method SigningMethod) {
|
||||
func GetSigningMethod[T Key](alg string) (method SigningMethod[T]) {
|
||||
signingMethodLock.RLock()
|
||||
defer signingMethodLock.RUnlock()
|
||||
|
||||
if methodF, ok := signingMethods[alg]; ok {
|
||||
/*if methodF, ok := signingMethods[alg]; ok {
|
||||
method = methodF()
|
||||
}
|
||||
}*/
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -39,8 +39,8 @@ func GetAlgorithms() (algs []string) {
|
|||
signingMethodLock.RLock()
|
||||
defer signingMethodLock.RUnlock()
|
||||
|
||||
for alg := range signingMethods {
|
||||
/*for alg := range signingMethods {
|
||||
algs = append(algs, alg)
|
||||
}
|
||||
}*/
|
||||
return
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ func LoadRSAPublicKeyFromDisk(location string) *rsa.PublicKey {
|
|||
}
|
||||
|
||||
// MakeSampleToken creates and returns a encoded JWT token that has been signed with the specified cryptographic key.
|
||||
func MakeSampleToken(c jwt.Claims, method jwt.SigningMethod, key interface{}) string {
|
||||
func MakeSampleToken[T jwt.Key](c jwt.Claims, method jwt.SigningMethod[T], key interface{}) string {
|
||||
token := jwt.NewWithClaims(method, c)
|
||||
s, e := token.SignedString(key)
|
||||
|
||||
|
|
35
token.go
35
token.go
|
@ -1,13 +1,16 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
// DecodePaddingAllowed will switch the codec used for decoding JWTs respectively. Note that the JWS RFC7515
|
||||
// states that the tokens will utilize a Base64url encoding with no padding. Unfortunately, some implementations
|
||||
// of JWT are producing non-standard tokens, and thus require support for decoding. Note that this is a global
|
||||
|
@ -24,13 +27,23 @@ var TimeFunc = time.Now
|
|||
// the key for verification. The function receives the parsed,
|
||||
// but unverified Token. This allows you to use properties in the
|
||||
// Header of the token (such as `kid`) to identify which key to use.
|
||||
type Keyfunc func(*Token) (interface{}, error)
|
||||
type Keyfunc[T Key] func(*Token[T]) (T, error)
|
||||
|
||||
type Key interface {
|
||||
[]byte |
|
||||
*rsa.PublicKey |
|
||||
*ecdsa.PublicKey |
|
||||
*ed25519.PublicKey |
|
||||
ed25519.PublicKey |
|
||||
unsafeNoneMagicConstant |
|
||||
crypto.PublicKey // TODO: get rid of any in the future
|
||||
}
|
||||
|
||||
// Token represents a JWT Token. Different fields will be used depending on whether you're
|
||||
// creating or parsing/verifying a token.
|
||||
type Token struct {
|
||||
type Token[T Key] struct {
|
||||
Raw string // The raw token. Populated when you Parse a token
|
||||
Method SigningMethod // The signing method used or to be used
|
||||
Method SigningMethod[T] // The signing method used or to be used
|
||||
Header map[string]interface{} // The first segment of the token
|
||||
Claims Claims // The second segment of the token
|
||||
Signature string // The third segment of the token. Populated when you Parse a token
|
||||
|
@ -38,13 +51,13 @@ type Token struct {
|
|||
}
|
||||
|
||||
// New creates a new Token with the specified signing method and an empty map of claims.
|
||||
func New(method SigningMethod) *Token {
|
||||
func New[T Key](method SigningMethod[T]) *Token[T] {
|
||||
return NewWithClaims(method, MapClaims{})
|
||||
}
|
||||
|
||||
// NewWithClaims creates a new Token with the specified signing method and claims.
|
||||
func NewWithClaims(method SigningMethod, claims Claims) *Token {
|
||||
return &Token{
|
||||
func NewWithClaims[T Key](method SigningMethod[T], claims Claims) *Token[T] {
|
||||
return &Token[T]{
|
||||
Header: map[string]interface{}{
|
||||
"typ": "JWT",
|
||||
"alg": method.Alg(),
|
||||
|
@ -56,7 +69,7 @@ func NewWithClaims(method SigningMethod, claims Claims) *Token {
|
|||
|
||||
// SignedString creates and returns a complete, signed JWT.
|
||||
// The token is signed using the SigningMethod specified in the token.
|
||||
func (t *Token) SignedString(key interface{}) (string, error) {
|
||||
func (t *Token[T]) SignedString(key interface{}) (string, error) {
|
||||
var sig, sstr string
|
||||
var err error
|
||||
if sstr, err = t.SigningString(); err != nil {
|
||||
|
@ -72,7 +85,7 @@ func (t *Token) SignedString(key interface{}) (string, error) {
|
|||
// most expensive part of the whole deal. Unless you
|
||||
// need this for something special, just go straight for
|
||||
// the SignedString.
|
||||
func (t *Token) SigningString() (string, error) {
|
||||
func (t *Token[T]) SigningString() (string, error) {
|
||||
var err error
|
||||
var jsonValue []byte
|
||||
|
||||
|
@ -96,11 +109,11 @@ func (t *Token) SigningString() (string, error) {
|
|||
// validate the 'alg' claim in the token matches the expected algorithm.
|
||||
// For more details about the importance of validating the 'alg' claim,
|
||||
// see https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
|
||||
func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
|
||||
func Parse[T Key](tokenString string, keyFunc Keyfunc[T], options ...ParserOption[T]) (*Token[T], error) {
|
||||
return NewParser(options...).Parse(tokenString, keyFunc)
|
||||
}
|
||||
|
||||
func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
|
||||
func ParseWithClaims[T Key](tokenString string, claims Claims, keyFunc Keyfunc[T], options ...ParserOption[T]) (*Token[T], error) {
|
||||
return NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc)
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
func TestToken_SigningString(t1 *testing.T) {
|
||||
type fields struct {
|
||||
Raw string
|
||||
Method jwt.SigningMethod
|
||||
Method jwt.SigningMethod[[]byte]
|
||||
Header map[string]interface{}
|
||||
Claims jwt.Claims
|
||||
Signature string
|
||||
|
@ -34,13 +34,13 @@ func TestToken_SigningString(t1 *testing.T) {
|
|||
Signature: "",
|
||||
Valid: false,
|
||||
},
|
||||
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",
|
||||
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t1.Run(tt.name, func(t1 *testing.T) {
|
||||
t := &jwt.Token{
|
||||
t := &jwt.Token[[]byte]{
|
||||
Raw: tt.fields.Raw,
|
||||
Method: tt.fields.Method,
|
||||
Header: tt.fields.Header,
|
||||
|
@ -61,18 +61,18 @@ func TestToken_SigningString(t1 *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkToken_SigningString(b *testing.B) {
|
||||
t := &jwt.Token{
|
||||
Method: jwt.SigningMethodHS256,
|
||||
t := &jwt.Token[[]byte]{
|
||||
Method: jwt.SigningMethodHS256,
|
||||
Header: map[string]interface{}{
|
||||
"typ": "JWT",
|
||||
"alg": jwt.SigningMethodHS256.Alg(),
|
||||
},
|
||||
Claims: jwt.StandardClaims{},
|
||||
Claims: jwt.StandardClaims{},
|
||||
}
|
||||
b.Run("BenchmarkToken_SigningString", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i<b.N; i++ {
|
||||
for i := 0; i < b.N; i++ {
|
||||
t.SigningString()
|
||||
}
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue