Compare commits

...

1 Commits

Author SHA1 Message Date
Christian Banse 73c8300abd Playing around with Go 1.18 generics 2022-02-18 00:59:12 +01:00
20 changed files with 121 additions and 137 deletions

View File

@ -128,7 +128,7 @@ func verifyToken() error {
}
// Parse the token. Load the key from command line option
token, err := jwt.Parse(string(tokData), func(t *jwt.Token) (interface{}, error) {
token, err := jwt.Parse[any](string(tokData), func(t *jwt.Token[any]) (any, error) {
if isNone() {
return jwt.UnsafeAllowNoneSignatureType, nil
}
@ -206,7 +206,7 @@ func signToken() error {
}
// get the signing alg
alg := jwt.GetSigningMethod(*flagAlg)
alg := jwt.GetSigningMethod[any](*flagAlg)
if alg == nil {
return fmt.Errorf("couldn't find signing method: %v", *flagAlg)
}
@ -273,7 +273,7 @@ func showToken() error {
fmt.Fprintf(os.Stderr, "Token len: %v bytes\n", len(tokData))
}
token, err := jwt.Parse(string(tokData), nil)
token, err := jwt.Parse[any](string(tokData), nil)
if token == nil {
return fmt.Errorf("malformed token: %w", err)
}

View File

@ -32,19 +32,19 @@ var (
func init() {
// ES256
SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod[*ecdsa.PublicKey] {
return SigningMethodES256
})
// ES384
SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod[*ecdsa.PublicKey] {
return SigningMethodES384
})
// ES512
SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod[*ecdsa.PublicKey] {
return SigningMethodES512
})
}
@ -55,7 +55,7 @@ func (m *SigningMethodECDSA) Alg() string {
// Verify implements token verification for the SigningMethod.
// For this verify method, key must be an ecdsa.PublicKey struct
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
func (m *SigningMethodECDSA) Verify(signingString, signature string, key *ecdsa.PublicKey) error {
var err error
// Decode the signature
@ -64,15 +64,6 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa
return err
}
// Get the key
var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) {
case *ecdsa.PublicKey:
ecdsaKey = k
default:
return ErrInvalidKeyType
}
if len(sig) != 2*m.KeySize {
return ErrECDSAVerification
}
@ -88,7 +79,7 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa
hasher.Write([]byte(signingString))
// Verify the signature
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus {
if verifystatus := ecdsa.Verify(key, hasher.Sum(nil), r, s); verifystatus {
return nil
}
@ -115,6 +106,9 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
hasher := m.Hash.New()
hasher.Write([]byte(signingString))
var s crypto.Signer = ecdsaKey
s.Sign(rand.Reader, hasher.Sum(nil), nil)
// Sign the string and return r, s
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
curveBits := ecdsaKey.Curve.Params().BitSize

View File

@ -64,7 +64,7 @@ func TestECDSAVerify(t *testing.T) {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
method := jwt.GetSigningMethod[*ecdsa.PublicKey](data.alg)
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ecdsaKey)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
@ -88,7 +88,7 @@ func TestECDSASign(t *testing.T) {
if data.valid {
parts := strings.Split(data.tokenString, ".")
toSign := strings.Join(parts[0:2], ".")
method := jwt.GetSigningMethod(data.alg)
method := jwt.GetSigningMethod[*ecdsa.PublicKey](data.alg)
sig, err := method.Sign(toSign, ecdsaKey)
if err != nil {
@ -98,7 +98,7 @@ func TestECDSASign(t *testing.T) {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
}
err = method.Verify(toSign, sig, ecdsaKey.Public())
err = method.Verify(toSign, sig, ecdsaKey.Public().(*ecdsa.PublicKey))
if err != nil {
t.Errorf("[%v] Sign produced an invalid signature: %v", data.name, err)
}
@ -133,7 +133,7 @@ func BenchmarkECDSASigning(b *testing.B) {
b.Fatalf("Unable to parse ECDSA private key: %v", err)
}
method := jwt.GetSigningMethod(data.alg)
method := jwt.GetSigningMethod[*ecdsa.PublicKey](data.alg)
b.Run(data.name, func(b *testing.B) {
benchmarkSigning(b, method, ecdsaKey)

View File

@ -23,7 +23,7 @@ var (
func init() {
SigningMethodEdDSA = &SigningMethodEd25519{}
RegisterSigningMethod(SigningMethodEdDSA.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodEdDSA.Alg(), func() SigningMethod[ed25519.PublicKey] {
return SigningMethodEdDSA
})
}
@ -34,14 +34,13 @@ func (m *SigningMethodEd25519) Alg() string {
// Verify implements token verification for the SigningMethod.
// For this verify method, key must be an ed25519.PublicKey
func (m *SigningMethodEd25519) Verify(signingString, signature string, key interface{}) error {
func (m *SigningMethodEd25519) Verify(signingString, signature string, key ed25519.PublicKey) error {
var err error
var ed25519Key ed25519.PublicKey
var ok bool
if ed25519Key, ok = key.(ed25519.PublicKey); !ok {
/*if ed25519Key, ok = key.(ed25519.PublicKey); !ok {
return ErrInvalidKeyType
}
}*/
if len(ed25519Key) != ed25519.PublicKeySize {
return ErrInvalidKey

View File

@ -6,6 +6,7 @@ import (
"testing"
"github.com/golang-jwt/jwt/v4"
"golang.org/x/crypto/ed25519"
)
var ed25519TestData = []struct {
@ -47,7 +48,7 @@ func TestEd25519Verify(t *testing.T) {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
method := jwt.GetSigningMethod[*ed25519.PublicKey](data.alg)
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ed25519Key)
if data.valid && err != nil {
@ -71,7 +72,7 @@ func TestEd25519Sign(t *testing.T) {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
method := jwt.GetSigningMethod[*ed25519.PublicKey](data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), ed25519Key)
if err != nil {

View File

@ -39,7 +39,7 @@ func ParseEdPrivateKeyFromPEM(key []byte) (crypto.PrivateKey, error) {
}
// ParseEdPublicKeyFromPEM parses a PEM-encoded Edwards curve public key
func ParseEdPublicKeyFromPEM(key []byte) (crypto.PublicKey, error) {
func ParseEdPublicKeyFromPEM(key []byte) (*ed25519.PublicKey, error) {
var err error
// Parse PEM block
@ -60,5 +60,5 @@ func ParseEdPublicKeyFromPEM(key []byte) (crypto.PublicKey, error) {
return nil, ErrNotEdPublicKey
}
return pkey, nil
return &pkey, nil
}

2
go.mod
View File

@ -1,3 +1,3 @@
module github.com/golang-jwt/jwt/v4
go 1.15
go 1.18

16
hmac.go
View File

@ -24,19 +24,19 @@ var (
func init() {
// HS256
SigningMethodHS256 = &SigningMethodHMAC{"HS256", crypto.SHA256}
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodHS256.Alg(), func() SigningMethod[[]byte] {
return SigningMethodHS256
})
// HS384
SigningMethodHS384 = &SigningMethodHMAC{"HS384", crypto.SHA384}
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodHS384.Alg(), func() SigningMethod[[]byte] {
return SigningMethodHS384
})
// HS512
SigningMethodHS512 = &SigningMethodHMAC{"HS512", crypto.SHA512}
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodHS512.Alg(), func() SigningMethod[[]byte] {
return SigningMethodHS512
})
}
@ -46,13 +46,7 @@ func (m *SigningMethodHMAC) Alg() string {
}
// Verify implements token verification for the SigningMethod. Returns nil if the signature is valid.
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
// Verify the key is the right type
keyBytes, ok := key.([]byte)
if !ok {
return ErrInvalidKeyType
}
func (m *SigningMethodHMAC) Verify(signingString, signature string, key []byte) error {
// Decode signature, for comparison
sig, err := DecodeSegment(signature)
if err != nil {
@ -67,7 +61,7 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac
// This signing method is symmetric, so we validate the signature
// by reproducing the signature from the signing string and key, then
// comparing that against the provided signature.
hasher := hmac.New(m.Hash.New, keyBytes)
hasher := hmac.New(m.Hash.New, key)
hasher.Write([]byte(signingString))
if !hmac.Equal(sig, hasher.Sum(nil)) {
return ErrSignatureInvalid

View File

@ -15,7 +15,7 @@ func init() {
SigningMethodNone = &signingMethodNone{}
NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid)
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod[unsafeNoneMagicConstant] {
return SigningMethodNone
})
}
@ -25,12 +25,7 @@ func (m *signingMethodNone) Alg() string {
}
// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
// accepting 'none' signing method
if _, ok := key.(unsafeNoneMagicConstant); !ok {
return NoneSignatureTypeDisallowedError
}
func (m *signingMethodNone) Verify(signingString, signature string, key unsafeNoneMagicConstant) (err error) {
// If signing method is none, signature must be an empty string
if signature != "" {
return NewValidationError(

View File

@ -7,7 +7,7 @@ import (
"strings"
)
type Parser struct {
type Parser[T Key] struct {
// If populated, only these methods will be considered valid.
//
// Deprecated: In future releases, this field will not be exported anymore and should be set with an option to NewParser instead.
@ -25,8 +25,8 @@ type Parser struct {
}
// NewParser creates a new Parser with the specified options
func NewParser(options ...ParserOption) *Parser {
p := &Parser{}
func NewParser[T Key](options ...ParserOption[T]) *Parser[T] {
p := &Parser[T]{}
// loop through our parsing options and apply them
for _, option := range options {
@ -38,11 +38,11 @@ func NewParser(options ...ParserOption) *Parser {
// Parse parses, validates, verifies the signature and returns the parsed token.
// keyFunc will receive the parsed token and should return the key for validating.
func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
func (p *Parser[T]) Parse(tokenString string, keyFunc Keyfunc[T]) (*Token[T], error) {
return p.ParseWithClaims(tokenString, MapClaims{}, keyFunc)
}
func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) {
func (p *Parser[T]) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc[T]) (*Token[T], error) {
token, parts, err := p.ParseUnverified(tokenString, claims)
if err != nil {
return token, err
@ -65,7 +65,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
}
// Lookup key
var key interface{}
var key T
if keyFunc == nil {
// keyFunc was not provided. short circuiting validation
return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable)
@ -115,13 +115,13 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
//
// It's only ever useful in cases where you know the signature is valid (because it has
// been checked previously in the stack) and you want to extract values from it.
func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) {
func (p *Parser[T]) ParseUnverified(tokenString string, claims Claims) (token *Token[T], parts []string, err error) {
parts = strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
}
token = &Token{Raw: tokenString}
token = &Token[T]{Raw: tokenString}
// parse Header
var headerBytes []byte
@ -159,7 +159,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
// Lookup signature method
if method, ok := token.Header["alg"].(string); ok {
if token.Method = GetSigningMethod(method); token.Method == nil {
if token.Method = GetSigningMethod[T](method); token.Method == nil {
return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable)
}
} else {

View File

@ -3,27 +3,27 @@ package jwt
// ParserOption is used to implement functional-style options that modify the behavior of the parser. To add
// new options, just create a function (ideally beginning with With or Without) that returns an anonymous function that
// takes a *Parser type as input and manipulates its configuration accordingly.
type ParserOption func(*Parser)
type ParserOption[T Key] func(*Parser[T])
// WithValidMethods is an option to supply algorithm methods that the parser will check. Only those methods will be considered valid.
// It is heavily encouraged to use this option in order to prevent attacks such as https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/.
func WithValidMethods(methods []string) ParserOption {
return func(p *Parser) {
func WithValidMethods[T Key](methods []string) ParserOption[T] {
return func(p *Parser[T]) {
p.ValidMethods = methods
}
}
// WithJSONNumber is an option to configure the underlying JSON parser with UseNumber
func WithJSONNumber() ParserOption {
return func(p *Parser) {
func WithJSONNumber[T Key]() ParserOption[T] {
return func(p *Parser[T]) {
p.UseJSONNumber = true
}
}
// WithoutClaimsValidation is an option to disable claims validation. This option should only be used if you exactly know
// what you are doing.
func WithoutClaimsValidation() ParserOption {
return func(p *Parser) {
func WithoutClaimsValidation[T Key]() ParserOption[T] {
return func(p *Parser[T]) {
p.SkipClaimsValidation = true
}
}

View File

@ -12,9 +12,9 @@ import (
// the logic for extracting a token. Several useful implementations are provided.
//
// You can provide options to modify parsing behavior
func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc, options ...ParseFromRequestOption) (token *jwt.Token, err error) {
func ParseFromRequest[T jwt.Key](req *http.Request, extractor Extractor, keyFunc jwt.Keyfunc[T], options ...ParseFromRequestOption[T]) (token *jwt.Token[T], err error) {
// Create basic parser struct
p := &fromRequestParser{req, extractor, nil, nil}
p := &fromRequestParser[T]{req, extractor, nil, nil}
// Handle options
for _, option := range options {
@ -26,7 +26,7 @@ func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfun
p.claims = jwt.MapClaims{}
}
if p.parser == nil {
p.parser = &jwt.Parser{}
p.parser = &jwt.Parser[T]{}
}
// perform extract
@ -42,29 +42,29 @@ func ParseFromRequest(req *http.Request, extractor Extractor, keyFunc jwt.Keyfun
// ParseFromRequestWithClaims is an alias for ParseFromRequest but with custom Claims type.
//
// Deprecated: use ParseFromRequest and the WithClaims option
func ParseFromRequestWithClaims(req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc) (token *jwt.Token, err error) {
return ParseFromRequest(req, extractor, keyFunc, WithClaims(claims))
func ParseFromRequestWithClaims[T jwt.Key](req *http.Request, extractor Extractor, claims jwt.Claims, keyFunc jwt.Keyfunc[T]) (token *jwt.Token[T], err error) {
return ParseFromRequest(req, extractor, keyFunc, WithClaims[T](claims))
}
type fromRequestParser struct {
type fromRequestParser[T jwt.Key] struct {
req *http.Request
extractor Extractor
claims jwt.Claims
parser *jwt.Parser
parser *jwt.Parser[T]
}
type ParseFromRequestOption func(*fromRequestParser)
type ParseFromRequestOption[T jwt.Key] func(*fromRequestParser[T])
// WithClaims parses with custom claims
func WithClaims(claims jwt.Claims) ParseFromRequestOption {
return func(p *fromRequestParser) {
func WithClaims[T jwt.Key](claims jwt.Claims) ParseFromRequestOption[T] {
return func(p *fromRequestParser[T]) {
p.claims = claims
}
}
// WithParser parses using a custom parser
func WithParser(parser *jwt.Parser) ParseFromRequestOption {
return func(p *fromRequestParser) {
func WithParser[T jwt.Key](parser *jwt.Parser[T]) ParseFromRequestOption[T] {
return func(p *fromRequestParser[T]) {
p.parser = parser
}
}

View File

@ -1,6 +1,7 @@
package request
import (
"crypto/rsa"
"fmt"
"net/http"
"net/url"
@ -58,14 +59,14 @@ func TestParseRequest(t *testing.T) {
// load keys from disk
privateKey := test.LoadRSAPrivateKeyFromDisk("../test/sample_key")
publicKey := test.LoadRSAPublicKeyFromDisk("../test/sample_key.pub")
keyfunc := func(*jwt.Token) (interface{}, error) {
keyfunc := func(*jwt.Token[*rsa.PublicKey]) (*rsa.PublicKey, error) {
return publicKey, nil
}
// Bearer token request
for _, data := range requestTestData {
// Make token from claims
tokenString := test.MakeSampleToken(data.claims, jwt.SigningMethodRS256, privateKey)
tokenString := test.MakeSampleToken[*rsa.PublicKey](data.claims, jwt.SigningMethodRS256, privateKey)
// Make query string
for k, vv := range data.query {

17
rsa.go
View File

@ -23,19 +23,19 @@ var (
func init() {
// RS256
SigningMethodRS256 = &SigningMethodRSA{"RS256", crypto.SHA256}
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodRS256.Alg(), func() SigningMethod[*rsa.PublicKey] {
return SigningMethodRS256
})
// RS384
SigningMethodRS384 = &SigningMethodRSA{"RS384", crypto.SHA384}
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodRS384.Alg(), func() SigningMethod[*rsa.PublicKey] {
return SigningMethodRS384
})
// RS512
SigningMethodRS512 = &SigningMethodRSA{"RS512", crypto.SHA512}
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodRS512.Alg(), func() SigningMethod[*rsa.PublicKey] {
return SigningMethodRS512
})
}
@ -46,7 +46,7 @@ func (m *SigningMethodRSA) Alg() string {
// Verify implements token verification for the SigningMethod
// For this signing method, must be an *rsa.PublicKey structure.
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
func (m *SigningMethodRSA) Verify(signingString, signature string, key *rsa.PublicKey) error {
var err error
// Decode the signature
@ -55,13 +55,6 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface
return err
}
var rsaKey *rsa.PublicKey
var ok bool
if rsaKey, ok = key.(*rsa.PublicKey); !ok {
return ErrInvalidKeyType
}
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
@ -70,7 +63,7 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface
hasher.Write([]byte(signingString))
// Verify the signature
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
return rsa.VerifyPKCS1v15(key, m.Hash, hasher.Sum(nil), sig)
}
// Sign implements token signing for the SigningMethod

View File

@ -1,3 +1,4 @@
//go:build go1.4
// +build go1.4
package jwt
@ -40,7 +41,7 @@ func init() {
SaltLength: rsa.PSSSaltLengthAuto,
},
}
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodPS256.Alg(), func() SigningMethod[*rsa.PublicKey] {
return SigningMethodPS256
})
@ -57,7 +58,7 @@ func init() {
SaltLength: rsa.PSSSaltLengthAuto,
},
}
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodPS384.Alg(), func() SigningMethod[*rsa.PublicKey] {
return SigningMethodPS384
})
@ -74,14 +75,14 @@ func init() {
SaltLength: rsa.PSSSaltLengthAuto,
},
}
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod {
RegisterSigningMethod(SigningMethodPS512.Alg(), func() SigningMethod[*rsa.PublicKey] {
return SigningMethodPS512
})
}
// Verify implements token verification for the SigningMethod.
// For this verify method, key must be an rsa.PublicKey struct
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key *rsa.PublicKey) error {
var err error
// Decode the signature
@ -90,14 +91,6 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf
return err
}
var rsaKey *rsa.PublicKey
switch k := key.(type) {
case *rsa.PublicKey:
rsaKey = k
default:
return ErrInvalidKey
}
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
@ -110,7 +103,7 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf
opts = m.VerifyOptions
}
return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts)
return rsa.VerifyPSS(key, m.Hash, hasher.Sum(nil), sig, opts)
}
// Sign implements token signing for the SigningMethod.

View File

@ -1,6 +1,7 @@
package jwt_test
import (
"crypto/rsa"
"io/ioutil"
"strings"
"testing"
@ -52,7 +53,7 @@ func TestRSAVerify(t *testing.T) {
for _, data := range rsaTestData {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
method := jwt.GetSigningMethod[*rsa.PublicKey](data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], key)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
@ -70,7 +71,7 @@ func TestRSASign(t *testing.T) {
for _, data := range rsaTestData {
if data.valid {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
method := jwt.GetSigningMethod[*rsa.PublicKey](data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), key)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)

View File

@ -4,33 +4,33 @@ import (
"sync"
)
var signingMethods = map[string]func() SigningMethod{}
//var signingMethods = map[string]SigningMethod[Key]{}
var signingMethodLock = new(sync.RWMutex)
// SigningMethod can be used add new methods for signing or verifying tokens.
type SigningMethod interface {
Verify(signingString, signature string, key interface{}) error // Returns nil if signature is valid
Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error
Alg() string // returns the alg identifier for this method (example: 'HS256')
type SigningMethod[T Key] interface {
Verify(signingString, signature string, key T) error // Returns nil if signature is valid
Sign(signingString string, key interface{}) (string, error) // Returns encoded signature or error
Alg() string // returns the alg identifier for this method (example: 'HS256')
}
// RegisterSigningMethod registers the "alg" name and a factory function for signing method.
// This is typically done during init() in the method's implementation
func RegisterSigningMethod(alg string, f func() SigningMethod) {
func RegisterSigningMethod[T Key](alg string, f func() SigningMethod[T]) {
signingMethodLock.Lock()
defer signingMethodLock.Unlock()
signingMethods[alg] = f
//signingMethods[alg] = f
}
// GetSigningMethod retrieves a signing method from an "alg" string
func GetSigningMethod(alg string) (method SigningMethod) {
func GetSigningMethod[T Key](alg string) (method SigningMethod[T]) {
signingMethodLock.RLock()
defer signingMethodLock.RUnlock()
if methodF, ok := signingMethods[alg]; ok {
/*if methodF, ok := signingMethods[alg]; ok {
method = methodF()
}
}*/
return
}
@ -39,8 +39,8 @@ func GetAlgorithms() (algs []string) {
signingMethodLock.RLock()
defer signingMethodLock.RUnlock()
for alg := range signingMethods {
/*for alg := range signingMethods {
algs = append(algs, alg)
}
}*/
return
}

View File

@ -33,7 +33,7 @@ func LoadRSAPublicKeyFromDisk(location string) *rsa.PublicKey {
}
// MakeSampleToken creates and returns a encoded JWT token that has been signed with the specified cryptographic key.
func MakeSampleToken(c jwt.Claims, method jwt.SigningMethod, key interface{}) string {
func MakeSampleToken[T jwt.Key](c jwt.Claims, method jwt.SigningMethod[T], key interface{}) string {
token := jwt.NewWithClaims(method, c)
s, e := token.SignedString(key)

View File

@ -1,13 +1,16 @@
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"strings"
"time"
)
// DecodePaddingAllowed will switch the codec used for decoding JWTs respectively. Note that the JWS RFC7515
// states that the tokens will utilize a Base64url encoding with no padding. Unfortunately, some implementations
// of JWT are producing non-standard tokens, and thus require support for decoding. Note that this is a global
@ -24,13 +27,23 @@ var TimeFunc = time.Now
// the key for verification. The function receives the parsed,
// but unverified Token. This allows you to use properties in the
// Header of the token (such as `kid`) to identify which key to use.
type Keyfunc func(*Token) (interface{}, error)
type Keyfunc[T Key] func(*Token[T]) (T, error)
type Key interface {
[]byte |
*rsa.PublicKey |
*ecdsa.PublicKey |
*ed25519.PublicKey |
ed25519.PublicKey |
unsafeNoneMagicConstant |
crypto.PublicKey // TODO: get rid of any in the future
}
// Token represents a JWT Token. Different fields will be used depending on whether you're
// creating or parsing/verifying a token.
type Token struct {
type Token[T Key] struct {
Raw string // The raw token. Populated when you Parse a token
Method SigningMethod // The signing method used or to be used
Method SigningMethod[T] // The signing method used or to be used
Header map[string]interface{} // The first segment of the token
Claims Claims // The second segment of the token
Signature string // The third segment of the token. Populated when you Parse a token
@ -38,13 +51,13 @@ type Token struct {
}
// New creates a new Token with the specified signing method and an empty map of claims.
func New(method SigningMethod) *Token {
func New[T Key](method SigningMethod[T]) *Token[T] {
return NewWithClaims(method, MapClaims{})
}
// NewWithClaims creates a new Token with the specified signing method and claims.
func NewWithClaims(method SigningMethod, claims Claims) *Token {
return &Token{
func NewWithClaims[T Key](method SigningMethod[T], claims Claims) *Token[T] {
return &Token[T]{
Header: map[string]interface{}{
"typ": "JWT",
"alg": method.Alg(),
@ -56,7 +69,7 @@ func NewWithClaims(method SigningMethod, claims Claims) *Token {
// SignedString creates and returns a complete, signed JWT.
// The token is signed using the SigningMethod specified in the token.
func (t *Token) SignedString(key interface{}) (string, error) {
func (t *Token[T]) SignedString(key interface{}) (string, error) {
var sig, sstr string
var err error
if sstr, err = t.SigningString(); err != nil {
@ -72,7 +85,7 @@ func (t *Token) SignedString(key interface{}) (string, error) {
// most expensive part of the whole deal. Unless you
// need this for something special, just go straight for
// the SignedString.
func (t *Token) SigningString() (string, error) {
func (t *Token[T]) SigningString() (string, error) {
var err error
var jsonValue []byte
@ -96,11 +109,11 @@ func (t *Token) SigningString() (string, error) {
// validate the 'alg' claim in the token matches the expected algorithm.
// For more details about the importance of validating the 'alg' claim,
// see https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
func Parse[T Key](tokenString string, keyFunc Keyfunc[T], options ...ParserOption[T]) (*Token[T], error) {
return NewParser(options...).Parse(tokenString, keyFunc)
}
func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
func ParseWithClaims[T Key](tokenString string, claims Claims, keyFunc Keyfunc[T], options ...ParserOption[T]) (*Token[T], error) {
return NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc)
}

View File

@ -9,7 +9,7 @@ import (
func TestToken_SigningString(t1 *testing.T) {
type fields struct {
Raw string
Method jwt.SigningMethod
Method jwt.SigningMethod[[]byte]
Header map[string]interface{}
Claims jwt.Claims
Signature string
@ -34,13 +34,13 @@ func TestToken_SigningString(t1 *testing.T) {
Signature: "",
Valid: false,
},
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",
wantErr: false,
},
}
for _, tt := range tests {
t1.Run(tt.name, func(t1 *testing.T) {
t := &jwt.Token{
t := &jwt.Token[[]byte]{
Raw: tt.fields.Raw,
Method: tt.fields.Method,
Header: tt.fields.Header,
@ -61,18 +61,18 @@ func TestToken_SigningString(t1 *testing.T) {
}
func BenchmarkToken_SigningString(b *testing.B) {
t := &jwt.Token{
Method: jwt.SigningMethodHS256,
t := &jwt.Token[[]byte]{
Method: jwt.SigningMethodHS256,
Header: map[string]interface{}{
"typ": "JWT",
"alg": jwt.SigningMethodHS256.Alg(),
},
Claims: jwt.StandardClaims{},
Claims: jwt.StandardClaims{},
}
b.Run("BenchmarkToken_SigningString", func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i<b.N; i++ {
for i := 0; i < b.N; i++ {
t.SigningString()
}
})