From 873d96d0a0787c65eda77ac708f557fbc3f9f590 Mon Sep 17 00:00:00 2001 From: Oleksandr Redko Date: Tue, 18 Jul 2023 09:44:48 +0300 Subject: [PATCH] Refactor code by using switch instead of if-else (#318) --- cmd/jwt/main.go | 85 ++++++++++++++++++++++++++----------------------- example_test.go | 11 ++++--- hmac_test.go | 21 ++++++------ none_test.go | 21 ++++++------ rsa_pss_test.go | 23 ++++++------- 5 files changed, 85 insertions(+), 76 deletions(-) diff --git a/cmd/jwt/main.go b/cmd/jwt/main.go index f1e49a9..0f49c70 100644 --- a/cmd/jwt/main.go +++ b/cmd/jwt/main.go @@ -60,13 +60,14 @@ func main() { // Figure out which thing to do and then do that func start() error { - if *flagSign != "" { + switch { + case *flagSign != "": return signToken() - } else if *flagVerify != "" { + case *flagVerify != "": return verifyToken() - } else if *flagShow != "" { + case *flagShow != "": return showToken() - } else { + default: flag.Usage() return fmt.Errorf("none of the required flags are present. What do you want me to do?") } @@ -79,17 +80,18 @@ func loadData(p string) ([]byte, error) { } var rdr io.Reader - if p == "-" { + switch p { + case "-": rdr = os.Stdin - } else if p == "+" { + case "+": return []byte("{}"), nil - } else { - if f, err := os.Open(p); err == nil { - rdr = f - defer f.Close() - } else { + default: + f, err := os.Open(p) + if err != nil { return nil, err } + rdr = f + defer f.Close() } return io.ReadAll(rdr) } @@ -136,14 +138,16 @@ func verifyToken() error { if err != nil { return nil, err } - if isEs() { + switch { + case isEs(): return jwt.ParseECPublicKeyFromPEM(data) - } else if isRs() { + case isRs(): return jwt.ParseRSAPublicKeyFromPEM(data) - } else if isEd() { + case isEd(): return jwt.ParseEdPublicKeyFromPEM(data) + default: + return data, nil } - return data, nil }) // Print some debug data @@ -221,40 +225,41 @@ func signToken() error { } } - if isEs() { - if k, ok := key.([]byte); !ok { + switch { + case isEs(): + k, ok := key.([]byte) + if !ok { return fmt.Errorf("couldn't convert key data to key") - } else { - key, err = jwt.ParseECPrivateKeyFromPEM(k) - if err != nil { - return err - } } - } else if isRs() { - if k, ok := key.([]byte); !ok { - return fmt.Errorf("couldn't convert key data to key") - } else { - key, err = jwt.ParseRSAPrivateKeyFromPEM(k) - if err != nil { - return err - } + key, err = jwt.ParseECPrivateKeyFromPEM(k) + if err != nil { + return err } - } else if isEd() { - if k, ok := key.([]byte); !ok { + case isRs(): + k, ok := key.([]byte) + if !ok { return fmt.Errorf("couldn't convert key data to key") - } else { - key, err = jwt.ParseEdPrivateKeyFromPEM(k) - if err != nil { - return err - } + } + key, err = jwt.ParseRSAPrivateKeyFromPEM(k) + if err != nil { + return err + } + case isEd(): + k, ok := key.([]byte) + if !ok { + return fmt.Errorf("couldn't convert key data to key") + } + key, err = jwt.ParseEdPrivateKeyFromPEM(k) + if err != nil { + return err } } - if out, err := token.SignedString(key); err == nil { - fmt.Println(out) - } else { + out, err := token.SignedString(key) + if err != nil { return fmt.Errorf("error signing token: %w", err) } + fmt.Println(out) return nil } diff --git a/example_test.go b/example_test.go index c2a55bd..0df3124 100644 --- a/example_test.go +++ b/example_test.go @@ -163,17 +163,18 @@ func ExampleParse_errorChecking() { return []byte("AllYourBase"), nil }) - if token.Valid { + switch { + case token.Valid: fmt.Println("You look nice today") - } else if errors.Is(err, jwt.ErrTokenMalformed) { + case errors.Is(err, jwt.ErrTokenMalformed): fmt.Println("That's not even a token") - } else if errors.Is(err, jwt.ErrTokenSignatureInvalid) { + case errors.Is(err, jwt.ErrTokenSignatureInvalid): // Invalid signature fmt.Println("Invalid signature") - } else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) { + case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet): // Token is either expired or not active yet fmt.Println("Timing is everything") - } else { + default: fmt.Println("Couldn't handle this token:", err) } diff --git a/hmac_test.go b/hmac_test.go index 264a2a4..3eb0380 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -66,16 +66,17 @@ func TestHMACVerify(t *testing.T) { func TestHMACSign(t *testing.T) { for _, data := range hmacTestData { - if data.valid { - parts := strings.Split(data.tokenString, ".") - method := jwt.GetSigningMethod(data.alg) - sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey) - if err != nil { - t.Errorf("[%v] Error signing token: %v", data.name, err) - } - if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { - t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) - } + if !data.valid { + continue + } + parts := strings.Split(data.tokenString, ".") + method := jwt.GetSigningMethod(data.alg) + sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey) + if err != nil { + t.Errorf("[%v] Error signing token: %v", data.name, err) + } + if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { + t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) } } } diff --git a/none_test.go b/none_test.go index d370cf8..f126b14 100644 --- a/none_test.go +++ b/none_test.go @@ -59,16 +59,17 @@ func TestNoneVerify(t *testing.T) { func TestNoneSign(t *testing.T) { for _, data := range noneTestData { - if data.valid { - parts := strings.Split(data.tokenString, ".") - method := jwt.GetSigningMethod(data.alg) - sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key) - if err != nil { - t.Errorf("[%v] Error signing token: %v", data.name, err) - } - if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { - t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) - } + if !data.valid { + continue + } + parts := strings.Split(data.tokenString, ".") + method := jwt.GetSigningMethod(data.alg) + sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key) + if err != nil { + t.Errorf("[%v] Error signing token: %v", data.name, err) + } + if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { + t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) } } } diff --git a/rsa_pss_test.go b/rsa_pss_test.go index 9707a75..536cde6 100644 --- a/rsa_pss_test.go +++ b/rsa_pss_test.go @@ -84,18 +84,19 @@ func TestRSAPSSSign(t *testing.T) { } for _, data := range rsaPSSTestData { - if data.valid { - parts := strings.Split(data.tokenString, ".") - method := jwt.GetSigningMethod(data.alg) - sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey) - if err != nil { - t.Errorf("[%v] Error signing token: %v", data.name, err) - } + if !data.valid { + continue + } + parts := strings.Split(data.tokenString, ".") + method := jwt.GetSigningMethod(data.alg) + sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey) + if err != nil { + t.Errorf("[%v] Error signing token: %v", data.name, err) + } - ssig := encodeSegment(sig) - if ssig == parts[2] { - t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2]) - } + ssig := encodeSegment(sig) + if ssig == parts[2] { + t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2]) } } }