Refactor code by using switch instead of if-else (#318)

This commit is contained in:
Oleksandr Redko 2023-07-18 09:44:48 +03:00 committed by GitHub
parent f53600aa9f
commit 873d96d0a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 85 additions and 76 deletions

View File

@ -60,13 +60,14 @@ func main() {
// Figure out which thing to do and then do that // Figure out which thing to do and then do that
func start() error { func start() error {
if *flagSign != "" { switch {
case *flagSign != "":
return signToken() return signToken()
} else if *flagVerify != "" { case *flagVerify != "":
return verifyToken() return verifyToken()
} else if *flagShow != "" { case *flagShow != "":
return showToken() return showToken()
} else { default:
flag.Usage() flag.Usage()
return fmt.Errorf("none of the required flags are present. What do you want me to do?") return fmt.Errorf("none of the required flags are present. What do you want me to do?")
} }
@ -79,17 +80,18 @@ func loadData(p string) ([]byte, error) {
} }
var rdr io.Reader var rdr io.Reader
if p == "-" { switch p {
case "-":
rdr = os.Stdin rdr = os.Stdin
} else if p == "+" { case "+":
return []byte("{}"), nil return []byte("{}"), nil
} else { default:
if f, err := os.Open(p); err == nil { f, err := os.Open(p)
rdr = f if err != nil {
defer f.Close()
} else {
return nil, err return nil, err
} }
rdr = f
defer f.Close()
} }
return io.ReadAll(rdr) return io.ReadAll(rdr)
} }
@ -136,14 +138,16 @@ func verifyToken() error {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if isEs() { switch {
case isEs():
return jwt.ParseECPublicKeyFromPEM(data) return jwt.ParseECPublicKeyFromPEM(data)
} else if isRs() { case isRs():
return jwt.ParseRSAPublicKeyFromPEM(data) return jwt.ParseRSAPublicKeyFromPEM(data)
} else if isEd() { case isEd():
return jwt.ParseEdPublicKeyFromPEM(data) return jwt.ParseEdPublicKeyFromPEM(data)
default:
return data, nil
} }
return data, nil
}) })
// Print some debug data // Print some debug data
@ -221,40 +225,41 @@ func signToken() error {
} }
} }
if isEs() { switch {
if k, ok := key.([]byte); !ok { case isEs():
k, ok := key.([]byte)
if !ok {
return fmt.Errorf("couldn't convert key data to key") return fmt.Errorf("couldn't convert key data to key")
} else {
key, err = jwt.ParseECPrivateKeyFromPEM(k)
if err != nil {
return err
}
} }
} else if isRs() { key, err = jwt.ParseECPrivateKeyFromPEM(k)
if k, ok := key.([]byte); !ok { if err != nil {
return fmt.Errorf("couldn't convert key data to key") return err
} else {
key, err = jwt.ParseRSAPrivateKeyFromPEM(k)
if err != nil {
return err
}
} }
} else if isEd() { case isRs():
if k, ok := key.([]byte); !ok { k, ok := key.([]byte)
if !ok {
return fmt.Errorf("couldn't convert key data to key") return fmt.Errorf("couldn't convert key data to key")
} else { }
key, err = jwt.ParseEdPrivateKeyFromPEM(k) key, err = jwt.ParseRSAPrivateKeyFromPEM(k)
if err != nil { if err != nil {
return err return err
} }
case isEd():
k, ok := key.([]byte)
if !ok {
return fmt.Errorf("couldn't convert key data to key")
}
key, err = jwt.ParseEdPrivateKeyFromPEM(k)
if err != nil {
return err
} }
} }
if out, err := token.SignedString(key); err == nil { out, err := token.SignedString(key)
fmt.Println(out) if err != nil {
} else {
return fmt.Errorf("error signing token: %w", err) return fmt.Errorf("error signing token: %w", err)
} }
fmt.Println(out)
return nil return nil
} }

View File

@ -163,17 +163,18 @@ func ExampleParse_errorChecking() {
return []byte("AllYourBase"), nil return []byte("AllYourBase"), nil
}) })
if token.Valid { switch {
case token.Valid:
fmt.Println("You look nice today") fmt.Println("You look nice today")
} else if errors.Is(err, jwt.ErrTokenMalformed) { case errors.Is(err, jwt.ErrTokenMalformed):
fmt.Println("That's not even a token") fmt.Println("That's not even a token")
} else if errors.Is(err, jwt.ErrTokenSignatureInvalid) { case errors.Is(err, jwt.ErrTokenSignatureInvalid):
// Invalid signature // Invalid signature
fmt.Println("Invalid signature") fmt.Println("Invalid signature")
} else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) { case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet):
// Token is either expired or not active yet // Token is either expired or not active yet
fmt.Println("Timing is everything") fmt.Println("Timing is everything")
} else { default:
fmt.Println("Couldn't handle this token:", err) fmt.Println("Couldn't handle this token:", err)
} }

View File

@ -66,16 +66,17 @@ func TestHMACVerify(t *testing.T) {
func TestHMACSign(t *testing.T) { func TestHMACSign(t *testing.T) {
for _, data := range hmacTestData { for _, data := range hmacTestData {
if data.valid { if !data.valid {
parts := strings.Split(data.tokenString, ".") continue
method := jwt.GetSigningMethod(data.alg) }
sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey) parts := strings.Split(data.tokenString, ".")
if err != nil { method := jwt.GetSigningMethod(data.alg)
t.Errorf("[%v] Error signing token: %v", data.name, err) sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey)
} if err != nil {
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { t.Errorf("[%v] Error signing token: %v", data.name, err)
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) }
} if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
} }
} }
} }

View File

@ -59,16 +59,17 @@ func TestNoneVerify(t *testing.T) {
func TestNoneSign(t *testing.T) { func TestNoneSign(t *testing.T) {
for _, data := range noneTestData { for _, data := range noneTestData {
if data.valid { if !data.valid {
parts := strings.Split(data.tokenString, ".") continue
method := jwt.GetSigningMethod(data.alg) }
sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key) parts := strings.Split(data.tokenString, ".")
if err != nil { method := jwt.GetSigningMethod(data.alg)
t.Errorf("[%v] Error signing token: %v", data.name, err) sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key)
} if err != nil {
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) { t.Errorf("[%v] Error signing token: %v", data.name, err)
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2]) }
} if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
} }
} }
} }

View File

@ -84,18 +84,19 @@ func TestRSAPSSSign(t *testing.T) {
} }
for _, data := range rsaPSSTestData { for _, data := range rsaPSSTestData {
if data.valid { if !data.valid {
parts := strings.Split(data.tokenString, ".") continue
method := jwt.GetSigningMethod(data.alg) }
sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey) parts := strings.Split(data.tokenString, ".")
if err != nil { method := jwt.GetSigningMethod(data.alg)
t.Errorf("[%v] Error signing token: %v", data.name, err) sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey)
} if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
ssig := encodeSegment(sig) ssig := encodeSegment(sig)
if ssig == parts[2] { if ssig == parts[2] {
t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2]) t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2])
}
} }
} }
} }