mirror of https://github.com/golang-jwt/jwt.git
Refactor code by using switch instead of if-else (#318)
This commit is contained in:
parent
f53600aa9f
commit
873d96d0a0
|
@ -60,13 +60,14 @@ func main() {
|
||||||
|
|
||||||
// Figure out which thing to do and then do that
|
// Figure out which thing to do and then do that
|
||||||
func start() error {
|
func start() error {
|
||||||
if *flagSign != "" {
|
switch {
|
||||||
|
case *flagSign != "":
|
||||||
return signToken()
|
return signToken()
|
||||||
} else if *flagVerify != "" {
|
case *flagVerify != "":
|
||||||
return verifyToken()
|
return verifyToken()
|
||||||
} else if *flagShow != "" {
|
case *flagShow != "":
|
||||||
return showToken()
|
return showToken()
|
||||||
} else {
|
default:
|
||||||
flag.Usage()
|
flag.Usage()
|
||||||
return fmt.Errorf("none of the required flags are present. What do you want me to do?")
|
return fmt.Errorf("none of the required flags are present. What do you want me to do?")
|
||||||
}
|
}
|
||||||
|
@ -79,17 +80,18 @@ func loadData(p string) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var rdr io.Reader
|
var rdr io.Reader
|
||||||
if p == "-" {
|
switch p {
|
||||||
|
case "-":
|
||||||
rdr = os.Stdin
|
rdr = os.Stdin
|
||||||
} else if p == "+" {
|
case "+":
|
||||||
return []byte("{}"), nil
|
return []byte("{}"), nil
|
||||||
} else {
|
default:
|
||||||
if f, err := os.Open(p); err == nil {
|
f, err := os.Open(p)
|
||||||
rdr = f
|
if err != nil {
|
||||||
defer f.Close()
|
|
||||||
} else {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
rdr = f
|
||||||
|
defer f.Close()
|
||||||
}
|
}
|
||||||
return io.ReadAll(rdr)
|
return io.ReadAll(rdr)
|
||||||
}
|
}
|
||||||
|
@ -136,14 +138,16 @@ func verifyToken() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if isEs() {
|
switch {
|
||||||
|
case isEs():
|
||||||
return jwt.ParseECPublicKeyFromPEM(data)
|
return jwt.ParseECPublicKeyFromPEM(data)
|
||||||
} else if isRs() {
|
case isRs():
|
||||||
return jwt.ParseRSAPublicKeyFromPEM(data)
|
return jwt.ParseRSAPublicKeyFromPEM(data)
|
||||||
} else if isEd() {
|
case isEd():
|
||||||
return jwt.ParseEdPublicKeyFromPEM(data)
|
return jwt.ParseEdPublicKeyFromPEM(data)
|
||||||
|
default:
|
||||||
|
return data, nil
|
||||||
}
|
}
|
||||||
return data, nil
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Print some debug data
|
// Print some debug data
|
||||||
|
@ -221,40 +225,41 @@ func signToken() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isEs() {
|
switch {
|
||||||
if k, ok := key.([]byte); !ok {
|
case isEs():
|
||||||
|
k, ok := key.([]byte)
|
||||||
|
if !ok {
|
||||||
return fmt.Errorf("couldn't convert key data to key")
|
return fmt.Errorf("couldn't convert key data to key")
|
||||||
} else {
|
|
||||||
key, err = jwt.ParseECPrivateKeyFromPEM(k)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else if isRs() {
|
key, err = jwt.ParseECPrivateKeyFromPEM(k)
|
||||||
if k, ok := key.([]byte); !ok {
|
if err != nil {
|
||||||
return fmt.Errorf("couldn't convert key data to key")
|
return err
|
||||||
} else {
|
|
||||||
key, err = jwt.ParseRSAPrivateKeyFromPEM(k)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else if isEd() {
|
case isRs():
|
||||||
if k, ok := key.([]byte); !ok {
|
k, ok := key.([]byte)
|
||||||
|
if !ok {
|
||||||
return fmt.Errorf("couldn't convert key data to key")
|
return fmt.Errorf("couldn't convert key data to key")
|
||||||
} else {
|
}
|
||||||
key, err = jwt.ParseEdPrivateKeyFromPEM(k)
|
key, err = jwt.ParseRSAPrivateKeyFromPEM(k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
case isEd():
|
||||||
|
k, ok := key.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("couldn't convert key data to key")
|
||||||
|
}
|
||||||
|
key, err = jwt.ParseEdPrivateKeyFromPEM(k)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if out, err := token.SignedString(key); err == nil {
|
out, err := token.SignedString(key)
|
||||||
fmt.Println(out)
|
if err != nil {
|
||||||
} else {
|
|
||||||
return fmt.Errorf("error signing token: %w", err)
|
return fmt.Errorf("error signing token: %w", err)
|
||||||
}
|
}
|
||||||
|
fmt.Println(out)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -163,17 +163,18 @@ func ExampleParse_errorChecking() {
|
||||||
return []byte("AllYourBase"), nil
|
return []byte("AllYourBase"), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if token.Valid {
|
switch {
|
||||||
|
case token.Valid:
|
||||||
fmt.Println("You look nice today")
|
fmt.Println("You look nice today")
|
||||||
} else if errors.Is(err, jwt.ErrTokenMalformed) {
|
case errors.Is(err, jwt.ErrTokenMalformed):
|
||||||
fmt.Println("That's not even a token")
|
fmt.Println("That's not even a token")
|
||||||
} else if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
|
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
|
||||||
// Invalid signature
|
// Invalid signature
|
||||||
fmt.Println("Invalid signature")
|
fmt.Println("Invalid signature")
|
||||||
} else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) {
|
case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet):
|
||||||
// Token is either expired or not active yet
|
// Token is either expired or not active yet
|
||||||
fmt.Println("Timing is everything")
|
fmt.Println("Timing is everything")
|
||||||
} else {
|
default:
|
||||||
fmt.Println("Couldn't handle this token:", err)
|
fmt.Println("Couldn't handle this token:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
21
hmac_test.go
21
hmac_test.go
|
@ -66,16 +66,17 @@ func TestHMACVerify(t *testing.T) {
|
||||||
|
|
||||||
func TestHMACSign(t *testing.T) {
|
func TestHMACSign(t *testing.T) {
|
||||||
for _, data := range hmacTestData {
|
for _, data := range hmacTestData {
|
||||||
if data.valid {
|
if !data.valid {
|
||||||
parts := strings.Split(data.tokenString, ".")
|
continue
|
||||||
method := jwt.GetSigningMethod(data.alg)
|
}
|
||||||
sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey)
|
parts := strings.Split(data.tokenString, ".")
|
||||||
if err != nil {
|
method := jwt.GetSigningMethod(data.alg)
|
||||||
t.Errorf("[%v] Error signing token: %v", data.name, err)
|
sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey)
|
||||||
}
|
if err != nil {
|
||||||
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
|
t.Errorf("[%v] Error signing token: %v", data.name, err)
|
||||||
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
|
}
|
||||||
}
|
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
|
||||||
|
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
21
none_test.go
21
none_test.go
|
@ -59,16 +59,17 @@ func TestNoneVerify(t *testing.T) {
|
||||||
|
|
||||||
func TestNoneSign(t *testing.T) {
|
func TestNoneSign(t *testing.T) {
|
||||||
for _, data := range noneTestData {
|
for _, data := range noneTestData {
|
||||||
if data.valid {
|
if !data.valid {
|
||||||
parts := strings.Split(data.tokenString, ".")
|
continue
|
||||||
method := jwt.GetSigningMethod(data.alg)
|
}
|
||||||
sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key)
|
parts := strings.Split(data.tokenString, ".")
|
||||||
if err != nil {
|
method := jwt.GetSigningMethod(data.alg)
|
||||||
t.Errorf("[%v] Error signing token: %v", data.name, err)
|
sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key)
|
||||||
}
|
if err != nil {
|
||||||
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
|
t.Errorf("[%v] Error signing token: %v", data.name, err)
|
||||||
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
|
}
|
||||||
}
|
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
|
||||||
|
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -84,18 +84,19 @@ func TestRSAPSSSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, data := range rsaPSSTestData {
|
for _, data := range rsaPSSTestData {
|
||||||
if data.valid {
|
if !data.valid {
|
||||||
parts := strings.Split(data.tokenString, ".")
|
continue
|
||||||
method := jwt.GetSigningMethod(data.alg)
|
}
|
||||||
sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey)
|
parts := strings.Split(data.tokenString, ".")
|
||||||
if err != nil {
|
method := jwt.GetSigningMethod(data.alg)
|
||||||
t.Errorf("[%v] Error signing token: %v", data.name, err)
|
sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey)
|
||||||
}
|
if err != nil {
|
||||||
|
t.Errorf("[%v] Error signing token: %v", data.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
ssig := encodeSegment(sig)
|
ssig := encodeSegment(sig)
|
||||||
if ssig == parts[2] {
|
if ssig == parts[2] {
|
||||||
t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2])
|
t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2])
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue