Implementing `Is(err) bool` to support Go 1.13 style error checking (#136)

This commit is contained in:
Christian Banse 2022-01-19 22:55:19 +01:00 committed by GitHub
parent 0fb40d3824
commit 78a18c0808
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 23 deletions

View File

@ -56,17 +56,17 @@ func (c RegisteredClaims) Valid() error {
// default value in Go, let's not fail the verification for them. // default value in Go, let's not fail the verification for them.
if !c.VerifyExpiresAt(now, false) { if !c.VerifyExpiresAt(now, false) {
delta := now.Sub(c.ExpiresAt.Time) delta := now.Sub(c.ExpiresAt.Time)
vErr.Inner = fmt.Errorf("token is expired by %v", delta) vErr.Inner = fmt.Errorf("%s by %v", delta, ErrTokenExpired)
vErr.Errors |= ValidationErrorExpired vErr.Errors |= ValidationErrorExpired
} }
if !c.VerifyIssuedAt(now, false) { if !c.VerifyIssuedAt(now, false) {
vErr.Inner = fmt.Errorf("token used before issued") vErr.Inner = ErrTokenUsedBeforeIssued
vErr.Errors |= ValidationErrorIssuedAt vErr.Errors |= ValidationErrorIssuedAt
} }
if !c.VerifyNotBefore(now, false) { if !c.VerifyNotBefore(now, false) {
vErr.Inner = fmt.Errorf("token is not valid yet") vErr.Inner = ErrTokenNotValidYet
vErr.Errors |= ValidationErrorNotValidYet vErr.Errors |= ValidationErrorNotValidYet
} }
@ -149,17 +149,17 @@ func (c StandardClaims) Valid() error {
// default value in Go, let's not fail the verification for them. // default value in Go, let's not fail the verification for them.
if !c.VerifyExpiresAt(now, false) { if !c.VerifyExpiresAt(now, false) {
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0)) delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
vErr.Inner = fmt.Errorf("token is expired by %v", delta) vErr.Inner = fmt.Errorf("%s by %v", delta, ErrTokenExpired)
vErr.Errors |= ValidationErrorExpired vErr.Errors |= ValidationErrorExpired
} }
if !c.VerifyIssuedAt(now, false) { if !c.VerifyIssuedAt(now, false) {
vErr.Inner = fmt.Errorf("token used before issued") vErr.Inner = ErrTokenUsedBeforeIssued
vErr.Errors |= ValidationErrorIssuedAt vErr.Errors |= ValidationErrorIssuedAt
} }
if !c.VerifyNotBefore(now, false) { if !c.VerifyNotBefore(now, false) {
vErr.Inner = fmt.Errorf("token is not valid yet") vErr.Inner = ErrTokenNotValidYet
vErr.Errors |= ValidationErrorNotValidYet vErr.Errors |= ValidationErrorNotValidYet
} }

View File

@ -9,6 +9,18 @@ var (
ErrInvalidKey = errors.New("key is invalid") ErrInvalidKey = errors.New("key is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type") ErrInvalidKeyType = errors.New("key is of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable") ErrHashUnavailable = errors.New("the requested hash function is unavailable")
ErrTokenMalformed = errors.New("token is malformed")
ErrTokenUnverifiable = errors.New("token is unverifiable")
ErrTokenSignatureInvalid = errors.New("token signature is invalid")
ErrTokenInvalidAudience = errors.New("token has invalid audience")
ErrTokenExpired = errors.New("token is expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrTokenInvalidIssuer = errors.New("token has invalid issuer")
ErrTokenNotValidYet = errors.New("token is not valid yet")
ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims")
) )
// The errors that might occur when parsing and validating a token // The errors that might occur when parsing and validating a token
@ -62,3 +74,39 @@ func (e *ValidationError) Unwrap() error {
func (e *ValidationError) valid() bool { func (e *ValidationError) valid() bool {
return e.Errors == 0 return e.Errors == 0
} }
// Is checks if this ValidationError is of the supplied error. We are first checking for the exact error message
// by comparing the inner error message. If that fails, we compare using the error flags. This way we can use
// custom error messages (mainly for backwards compatability) and still leverage errors.Is using the global error variables.
func (e *ValidationError) Is(err error) bool {
// Check, if our inner error is a direct match
if errors.Is(errors.Unwrap(e), err) {
return true
}
// Otherwise, we need to match using our error flags
switch err {
case ErrTokenMalformed:
return e.Errors&ValidationErrorMalformed != 0
case ErrTokenUnverifiable:
return e.Errors&ValidationErrorUnverifiable != 0
case ErrTokenSignatureInvalid:
return e.Errors&ValidationErrorSignatureInvalid != 0
case ErrTokenInvalidAudience:
return e.Errors&ValidationErrorAudience != 0
case ErrTokenExpired:
return e.Errors&ValidationErrorExpired != 0
case ErrTokenUsedBeforeIssued:
return e.Errors&ValidationErrorIssuedAt != 0
case ErrTokenInvalidIssuer:
return e.Errors&ValidationErrorIssuer != 0
case ErrTokenNotValidYet:
return e.Errors&ValidationErrorNotValidYet != 0
case ErrTokenInvalidId:
return e.Errors&ValidationErrorId != 0
case ErrTokenInvalidClaims:
return e.Errors&ValidationErrorClaimsInvalid != 0
}
return false
}

View File

@ -1,6 +1,7 @@
package jwt_test package jwt_test
import ( import (
"errors"
"fmt" "fmt"
"time" "time"
@ -103,18 +104,14 @@ func ExampleParse_errorChecking() {
if token.Valid { if token.Valid {
fmt.Println("You look nice today") fmt.Println("You look nice today")
} else if ve, ok := err.(*jwt.ValidationError); ok { } else if errors.Is(err, jwt.ErrTokenMalformed) {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
fmt.Println("That's not even a token") fmt.Println("That's not even a token")
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { } else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) {
// Token is either expired or not active yet // Token is either expired or not active yet
fmt.Println("Timing is everything") fmt.Println("Timing is everything")
} else { } else {
fmt.Println("Couldn't handle this token:", err) fmt.Println("Couldn't handle this token:", err)
} }
} else {
fmt.Println("Couldn't handle this token:", err)
}
// Output: Timing is everything // Output: Timing is everything
} }

View File

@ -126,16 +126,19 @@ func (m MapClaims) Valid() error {
now := TimeFunc().Unix() now := TimeFunc().Unix()
if !m.VerifyExpiresAt(now, false) { if !m.VerifyExpiresAt(now, false) {
// TODO(oxisto): this should be replaced with ErrTokenExpired
vErr.Inner = errors.New("Token is expired") vErr.Inner = errors.New("Token is expired")
vErr.Errors |= ValidationErrorExpired vErr.Errors |= ValidationErrorExpired
} }
if !m.VerifyIssuedAt(now, false) { if !m.VerifyIssuedAt(now, false) {
// TODO(oxisto): this should be replaced with ErrTokenUsedBeforeIssued
vErr.Inner = errors.New("Token used before issued") vErr.Inner = errors.New("Token used before issued")
vErr.Errors |= ValidationErrorIssuedAt vErr.Errors |= ValidationErrorIssuedAt
} }
if !m.VerifyNotBefore(now, false) { if !m.VerifyNotBefore(now, false) {
// TODO(oxisto): this should be replaced with ErrTokenNotValidYet
vErr.Inner = errors.New("Token is not valid yet") vErr.Inner = errors.New("Token is not valid yet")
vErr.Errors |= ValidationErrorNotValidYet vErr.Errors |= ValidationErrorNotValidYet
} }

View File

@ -4,6 +4,7 @@ import (
"crypto" "crypto"
"crypto/rsa" "crypto/rsa"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
@ -51,6 +52,7 @@ var jwtTestData = []struct {
claims jwt.Claims claims jwt.Claims
valid bool valid bool
errors uint32 errors uint32
err []error
parser *jwt.Parser parser *jwt.Parser
signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose
}{ }{
@ -62,6 +64,7 @@ var jwtTestData = []struct {
true, true,
0, 0,
nil, nil,
nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
{ {
@ -71,6 +74,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
false, false,
jwt.ValidationErrorExpired, jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenExpired},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -81,6 +85,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
false, false,
jwt.ValidationErrorNotValidYet, jwt.ValidationErrorNotValidYet,
[]error{jwt.ErrTokenNotValidYet},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -91,6 +96,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
false, false,
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenNotValidYet},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -101,6 +107,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid, jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -111,6 +118,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorUnverifiable, jwt.ValidationErrorUnverifiable,
[]error{jwt.ErrTokenUnverifiable},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -121,6 +129,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid, jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -131,6 +140,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorUnverifiable, jwt.ValidationErrorUnverifiable,
[]error{jwt.ErrTokenUnverifiable, errKeyFuncError},
nil, nil,
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -141,6 +151,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid, jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid},
&jwt.Parser{ValidMethods: []string{"HS256"}}, &jwt.Parser{ValidMethods: []string{"HS256"}},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -151,6 +162,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
true, true,
0, 0,
nil,
&jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -161,6 +173,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
false, false,
jwt.ValidationErrorSignatureInvalid, jwt.ValidationErrorSignatureInvalid,
[]error{jwt.ErrTokenSignatureInvalid},
&jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}},
jwt.SigningMethodES256, jwt.SigningMethodES256,
}, },
@ -171,6 +184,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"}, jwt.MapClaims{"foo": "bar"},
true, true,
0, 0,
nil,
&jwt.Parser{ValidMethods: []string{"HS256", "ES256"}}, &jwt.Parser{ValidMethods: []string{"HS256", "ES256"}},
jwt.SigningMethodES256, jwt.SigningMethodES256,
}, },
@ -181,6 +195,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": json.Number("123.4")}, jwt.MapClaims{"foo": json.Number("123.4")},
true, true,
0, 0,
nil,
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -193,6 +208,7 @@ var jwtTestData = []struct {
}, },
true, true,
0, 0,
nil,
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -203,6 +219,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
false, false,
jwt.ValidationErrorExpired, jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenExpired},
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -213,6 +230,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
false, false,
jwt.ValidationErrorNotValidYet, jwt.ValidationErrorNotValidYet,
[]error{jwt.ErrTokenNotValidYet},
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -223,6 +241,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
false, false,
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
[]error{jwt.ErrTokenNotValidYet},
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -233,6 +252,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
true, true,
0, 0,
nil,
&jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true}, &jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -245,6 +265,7 @@ var jwtTestData = []struct {
}, },
true, true,
0, 0,
nil,
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -257,6 +278,7 @@ var jwtTestData = []struct {
}, },
true, true,
0, 0,
nil,
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -269,6 +291,7 @@ var jwtTestData = []struct {
}, },
true, true,
0, 0,
nil,
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -281,6 +304,7 @@ var jwtTestData = []struct {
}, },
false, false,
jwt.ValidationErrorMalformed, jwt.ValidationErrorMalformed,
[]error{jwt.ErrTokenMalformed},
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -293,6 +317,7 @@ var jwtTestData = []struct {
}, },
false, false,
jwt.ValidationErrorMalformed, jwt.ValidationErrorMalformed,
[]error{jwt.ErrTokenMalformed},
&jwt.Parser{UseJSONNumber: true}, &jwt.Parser{UseJSONNumber: true},
jwt.SigningMethodRS256, jwt.SigningMethodRS256,
}, },
@ -325,6 +350,7 @@ func TestParser_Parse(t *testing.T) {
// Parse the token // Parse the token
var token *jwt.Token var token *jwt.Token
var ve *jwt.ValidationError
var err error var err error
var parser = data.parser var parser = data.parser
if parser == nil { if parser == nil {
@ -361,8 +387,7 @@ func TestParser_Parse(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("[%v] Expecting error. Didn't get one.", data.name) t.Errorf("[%v] Expecting error. Didn't get one.", data.name)
} else { } else {
if errors.As(err, &ve) {
ve := err.(*jwt.ValidationError)
// compare the bitfield part of the error // compare the bitfield part of the error
if e := ve.Errors; e != data.errors { if e := ve.Errors; e != data.errors {
t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors)
@ -373,6 +398,23 @@ func TestParser_Parse(t *testing.T) {
} }
} }
} }
}
if data.err != nil {
if err == nil {
t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name)
} else {
var all = false
for _, e := range data.err {
all = errors.Is(err, e)
}
if !all {
t.Errorf("[%v] Errors don't match expectation. %v should contain all of %v", data.name, err, data.err)
}
}
}
if data.valid { if data.valid {
if token.Signature == "" { if token.Signature == "" {
t.Errorf("[%v] Signature is left unpopulated after parsing", data.name) t.Errorf("[%v] Signature is left unpopulated after parsing", data.name)