jwt/ecdsa.go

153 lines
3.7 KiB
Go
Raw Normal View History

2015-07-16 21:26:45 +03:00
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"errors"
"math/big"
)
// ErrECDSAVerification is the error reported when an ECDSA signature is
// rejected (wrong length or failed verification). crypto/ecdsa, unlike
// crypto/rsa, does not export such an error, so it is declared here.
var ErrECDSAVerification = errors.New("crypto/ecdsa: verification error")
// SigningMethodECDSA implements the ECDSA family of signing methods.
// It expects an *ecdsa.PrivateKey for signing and an *ecdsa.PublicKey for
// verification.
//
// Keep the field order stable: code may construct this type with positional
// composite literals.
type SigningMethodECDSA struct {
	Name string      // algorithm identifier returned by Alg
	Hash crypto.Hash // hash function applied to the signing string
	// KeySize is the byte length of each of r and s in an encoded signature;
	// CurveBits is the bit size the signing key's curve must have.
	KeySize, CurveBits int
}
// Specific instances for EC256 and company
var (
SigningMethodES256 *SigningMethodECDSA
SigningMethodES384 *SigningMethodECDSA
SigningMethodES512 *SigningMethodECDSA
)
func init() {
// ES256
2015-09-17 05:53:08 +03:00
SigningMethodES256 = &SigningMethodECDSA{"ES256", crypto.SHA256, 32, 256}
2015-07-16 21:26:45 +03:00
RegisterSigningMethod(SigningMethodES256.Alg(), func() SigningMethod {
return SigningMethodES256
})
// ES384
2015-09-17 05:53:08 +03:00
SigningMethodES384 = &SigningMethodECDSA{"ES384", crypto.SHA384, 48, 384}
2015-07-16 21:26:45 +03:00
RegisterSigningMethod(SigningMethodES384.Alg(), func() SigningMethod {
return SigningMethodES384
})
// ES512
2015-09-17 05:53:08 +03:00
SigningMethodES512 = &SigningMethodECDSA{"ES512", crypto.SHA512, 66, 521}
2015-07-16 21:26:45 +03:00
RegisterSigningMethod(SigningMethodES512.Alg(), func() SigningMethod {
return SigningMethodES512
})
}
// Alg reports the algorithm name of this signing method (its Name field),
// e.g. "ES256".
func (m *SigningMethodECDSA) Alg() string { return m.Name }
// Verify Implements the Verify method from SigningMethod
2015-07-16 21:26:45 +03:00
// For this verify method, key must be an ecdsa.PublicKey struct
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
var err error
// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}
// Get the key
var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) {
case *ecdsa.PublicKey:
ecdsaKey = k
default:
return ErrInvalidKeyType
2015-07-16 21:26:45 +03:00
}
2015-09-17 05:53:08 +03:00
if len(sig) != 2*m.KeySize {
return ErrECDSAVerification
2015-07-16 21:26:45 +03:00
}
2015-09-17 05:53:08 +03:00
r := big.NewInt(0).SetBytes(sig[:m.KeySize])
s := big.NewInt(0).SetBytes(sig[m.KeySize:])
2015-07-16 21:26:45 +03:00
// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
if _, err = hasher.Write([]byte(signingString)); err != nil {
return err
}
2015-07-16 21:26:45 +03:00
// Verify the signature
2021-05-28 02:26:21 +03:00
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus {
2015-07-16 21:26:45 +03:00
return nil
}
2021-05-28 02:26:21 +03:00
return ErrECDSAVerification
2015-07-16 21:26:45 +03:00
}
// Sign Implements the Sign method from SigningMethod
2015-07-16 21:26:45 +03:00
// For this signing method, key must be an ecdsa.PrivateKey struct
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
// Get the key
var ecdsaKey *ecdsa.PrivateKey
switch k := key.(type) {
case *ecdsa.PrivateKey:
ecdsaKey = k
default:
return "", ErrInvalidKeyType
2015-07-16 21:26:45 +03:00
}
// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
}
hasher := m.Hash.New()
if _, err := hasher.Write([]byte(signingString)); err != nil {
return "", err
}
2015-07-16 21:26:45 +03:00
// Sign the string and return r, s
if r, s, err := ecdsa.Sign(rand.Reader, ecdsaKey, hasher.Sum(nil)); err == nil {
2015-09-17 05:53:08 +03:00
curveBits := ecdsaKey.Curve.Params().BitSize
if m.CurveBits != curveBits {
return "", ErrInvalidKey
2015-07-16 21:26:45 +03:00
}
2015-09-17 05:53:08 +03:00
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes += 1
}
// We serialize the outpus (r and s) into big-endian byte arrays and pad
// them with zeros on the left to make sure the sizes work out. Both arrays
// must be keyBytes long, and the output must be 2*keyBytes long.
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
out := append(rBytesPadded, sBytesPadded...)
return EncodeSegment(out), nil
2015-07-16 21:26:45 +03:00
} else {
return "", err
}
}