diff --git a/context.go b/context.go index 4259a6bc..c78e8c30 100644 --- a/context.go +++ b/context.go @@ -126,6 +126,11 @@ func (c *Context) AbortWithStatus(code int) { c.Abort() } +func (c *Context) AbortWithError(code int, err error) *errorMsg { + c.AbortWithStatus(code) + return c.Error(err) +} + func (c *Context) IsAborted() bool { return c.index == AbortIndex } @@ -134,30 +139,16 @@ func (c *Context) IsAborted() bool { /********* ERROR MANAGEMENT *********/ /************************************/ -// Fail is the same as Abort plus an error message. -// Calling `context.Fail(500, err)` is equivalent to: -// ``` -// context.Error("Operation aborted", err) -// context.AbortWithStatus(500) -// ``` -func (c *Context) Fail(code int, err error) { - c.Error(err, "Operation aborted") - c.AbortWithStatus(code) -} - -func (c *Context) ErrorTyped(err error, typ int, meta interface{}) { - c.Errors = append(c.Errors, errorMsg{ - Error: err, - Flags: typ, - Meta: meta, - }) -} - // Attaches an error to the current context. The error is pushed to a list of errors. // It's a good idea to call Error for each error that occurred during the resolution of a request. // A middleware can be used to collect all the errors and push them to a database together, print a log, or append it in the HTTP response. -func (c *Context) Error(err error, meta interface{}) { - c.ErrorTyped(err, ErrorTypeExternal, meta) +func (c *Context) Error(err error) *errorMsg { + newError := &errorMsg{ + Error: err, + Flags: ErrorTypePrivate, + } + c.Errors = append(c.Errors, newError) + return newError } func (c *Context) LastError() error { @@ -168,6 +159,35 @@ func (c *Context) LastError() error { return nil } +/************************************/ +/******** METADATA MANAGEMENT********/ +/************************************/ + +// Sets a new pair key/value just for the specified context. +// It also lazy initializes the hashmap. +func (c *Context) Set(key string, value interface{}) { + if c.Keys == nil { + c.Keys = make(map[string]interface{}) + } + c.Keys[key] = value +} + +// Get returns the value for the given key or an error if the key does not exist. +func (c *Context) Get(key string) (value interface{}, exists bool) { + if c.Keys != nil { + value, exists = c.Keys[key] + } + return +} + +// MustGet returns the value for the given key or panics if the value doesn't exist. +func (c *Context) MustGet(key string) interface{} { + if value, exists := c.Get(key); exists { + return value + } + panic("Key \"" + key + "\" does not exist") +} + /************************************/ /************ INPUT DATA ************/ /************************************/ @@ -233,40 +253,29 @@ func (c *Context) postFormValue(key string) (string, bool) { return "", false } -/************************************/ -/******** METADATA MANAGEMENT********/ -/************************************/ - -// Sets a new pair key/value just for the specified context. -// It also lazy initializes the hashmap. -func (c *Context) Set(key string, value interface{}) { - if c.Keys == nil { - c.Keys = make(map[string]interface{}) - } - c.Keys[key] = value +// This function checks the Content-Type to select a binding engine automatically, +// Depending the "Content-Type" header different bindings are used: +// "application/json" --> JSON binding +// "application/xml" --> XML binding +// else --> returns an error +// if Parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. It decodes the json payload into the struct specified as a pointer.Like ParseBody() but this method also writes a 400 error if the json is not valid. +func (c *Context) Bind(obj interface{}) error { + b := binding.Default(c.Request.Method, c.ContentType()) + return c.BindWith(obj, b) } -// Get returns the value for the given key or an error if the key does not exist. -func (c *Context) Get(key string) (value interface{}, exists bool) { - if c.Keys != nil { - value, exists = c.Keys[key] - } - return +func (c *Context) BindJSON(obj interface{}) error { + return c.BindWith(obj, binding.JSON) } -// MustGet returns the value for the given key or panics if the value doesn't exist. -func (c *Context) MustGet(key string) interface{} { - if value, exists := c.Get(key); exists { - return value - } else { - panic("Key \"" + key + "\" does not exist") +func (c *Context) BindWith(obj interface{}, b binding.Binding) error { + if err := b.Bind(c.Request, obj); err != nil { + c.AbortWithError(400, err).Type(ErrorTypeBind) + return err } + return nil } -/************************************/ -/********* PARSING REQUEST **********/ -/************************************/ - func (c *Context) ClientIP() string { clientIP := c.Request.Header.Get("X-Real-IP") if len(clientIP) > 0 { @@ -284,25 +293,6 @@ func (c *Context) ContentType() string { return filterFlags(c.Request.Header.Get("Content-Type")) } -// This function checks the Content-Type to select a binding engine automatically, -// Depending the "Content-Type" header different bindings are used: -// "application/json" --> JSON binding -// "application/xml" --> XML binding -// else --> returns an error -// if Parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. It decodes the json payload into the struct specified as a pointer.Like ParseBody() but this method also writes a 400 error if the json is not valid. -func (c *Context) Bind(obj interface{}) bool { - b := binding.Default(c.Request.Method, c.ContentType()) - return c.BindWith(obj, b) -} - -func (c *Context) BindWith(obj interface{}, b binding.Binding) bool { - if err := b.Bind(c.Request, obj); err != nil { - c.Fail(400, err) - return false - } - return true -} - /************************************/ /******** RESPONSE RENDERING ********/ /************************************/ @@ -319,8 +309,7 @@ func (c *Context) Render(code int, r render.Render) { c.Writer.WriteHeader(code) if err := r.Write(c.Writer); err != nil { debugPrintError(err) - c.ErrorTyped(err, ErrorTypeInternal, nil) - c.AbortWithStatus(500) + c.AbortWithError(500, err).Type(ErrorTypeRender) } } @@ -430,7 +419,7 @@ func (c *Context) Negotiate(code int, config Negotiate) { c.XML(code, data) default: - c.Fail(http.StatusNotAcceptable, errors.New("the accepted formats are not offered by the server")) + c.AbortWithError(http.StatusNotAcceptable, errors.New("the accepted formats are not offered by the server")) } } @@ -458,6 +447,10 @@ func (c *Context) SetAccepted(formats ...string) { c.Accepted = formats } +/************************************/ +/******** CONTENT NEGOTIATION *******/ +/************************************/ + func (c *Context) Deadline() (deadline time.Time, ok bool) { return } diff --git a/context_test.go b/context_test.go index eacca3d8..cdf549fe 100644 --- a/context_test.go +++ b/context_test.go @@ -41,7 +41,7 @@ func TestContextReset(t *testing.T) { c.index = 2 c.Writer = &responseWriter{ResponseWriter: httptest.NewRecorder()} c.Params = Params{Param{}} - c.Error(errors.New("test"), nil) + c.Error(errors.New("test")) c.Set("foo", "bar") c.reset() @@ -352,41 +352,41 @@ func TestContextError(t *testing.T) { assert.Nil(t, c.LastError()) assert.Empty(t, c.Errors.String()) - c.Error(errors.New("first error"), "some data") + c.Error(errors.New("first error")).Meta("some data") assert.Equal(t, c.LastError().Error(), "first error") assert.Len(t, c.Errors, 1) assert.Equal(t, c.Errors.String(), "Error #01: first error\n Meta: some data\n") - c.Error(errors.New("second error"), "some data 2") + c.Error(errors.New("second error")).Meta("some data 2") assert.Equal(t, c.LastError().Error(), "second error") assert.Len(t, c.Errors, 2) assert.Equal(t, c.Errors.String(), "Error #01: first error\n Meta: some data\n"+ "Error #02: second error\n Meta: some data 2\n") assert.Equal(t, c.Errors[0].Error, errors.New("first error")) - assert.Equal(t, c.Errors[0].Meta, "some data") - assert.Equal(t, c.Errors[0].Flags, ErrorTypeExternal) + assert.Equal(t, c.Errors[0].Metadata, "some data") + assert.Equal(t, c.Errors[0].Flags, ErrorTypePrivate) assert.Equal(t, c.Errors[1].Error, errors.New("second error")) - assert.Equal(t, c.Errors[1].Meta, "some data 2") - assert.Equal(t, c.Errors[1].Flags, ErrorTypeExternal) + assert.Equal(t, c.Errors[1].Metadata, "some data 2") + assert.Equal(t, c.Errors[1].Flags, ErrorTypePrivate) } func TestContextTypedError(t *testing.T) { c, _, _ := createTestContext() - c.ErrorTyped(errors.New("externo 0"), ErrorTypeExternal, nil) - c.ErrorTyped(errors.New("externo 1"), ErrorTypeExternal, nil) - c.ErrorTyped(errors.New("interno 0"), ErrorTypeInternal, nil) - c.ErrorTyped(errors.New("externo 2"), ErrorTypeExternal, nil) - c.ErrorTyped(errors.New("interno 1"), ErrorTypeInternal, nil) - c.ErrorTyped(errors.New("interno 2"), ErrorTypeInternal, nil) + c.Error(errors.New("externo 0")).Type(ErrorTypePublic) + c.Error(errors.New("externo 1")).Type(ErrorTypePublic) + c.Error(errors.New("interno 0")).Type(ErrorTypePrivate) + c.Error(errors.New("externo 2")).Type(ErrorTypePublic) + c.Error(errors.New("interno 1")).Type(ErrorTypePrivate) + c.Error(errors.New("interno 2")).Type(ErrorTypePrivate) - for _, err := range c.Errors.ByType(ErrorTypeExternal) { - assert.Equal(t, err.Flags, ErrorTypeExternal) + for _, err := range c.Errors.ByType(ErrorTypePublic) { + assert.Equal(t, err.Flags, ErrorTypePublic) } - for _, err := range c.Errors.ByType(ErrorTypeInternal) { - assert.Equal(t, err.Flags, ErrorTypeInternal) + for _, err := range c.Errors.ByType(ErrorTypePrivate) { + assert.Equal(t, err.Flags, ErrorTypePrivate) } assert.Equal(t, c.Errors.Errors(), []string{"externo 0", "externo 1", "interno 0", "externo 2", "interno 1", "interno 2"}) @@ -394,7 +394,7 @@ func TestContextTypedError(t *testing.T) { func TestContextFail(t *testing.T) { c, w, _ := createTestContext() - c.Fail(401, errors.New("bad input")) + c.AbortWithError(401, errors.New("bad input")) c.Writer.WriteHeaderNow() assert.Equal(t, w.Code, 401) @@ -434,7 +434,7 @@ func TestContextAutoBind(t *testing.T) { Foo string `json:"foo"` Bar string `json:"bar"` } - assert.True(t, c.Bind(&obj)) + assert.NoError(t, c.Bind(&obj)) assert.Equal(t, obj.Bar, "foo") assert.Equal(t, obj.Foo, "bar") assert.Equal(t, w.Body.Len(), 0) @@ -450,7 +450,7 @@ func TestContextBadAutoBind(t *testing.T) { } assert.False(t, c.IsAborted()) - assert.False(t, c.Bind(&obj)) + assert.Error(t, c.Bind(&obj)) c.Writer.WriteHeaderNow() assert.Empty(t, obj.Bar) @@ -467,7 +467,7 @@ func TestContextBindWith(t *testing.T) { Foo string `json:"foo"` Bar string `json:"bar"` } - assert.True(t, c.BindWith(&obj, binding.JSON)) + assert.NoError(t, c.BindWith(&obj, binding.JSON)) assert.Equal(t, obj.Bar, "foo") assert.Equal(t, obj.Foo, "bar") assert.Equal(t, w.Body.Len(), 0) diff --git a/errors.go b/errors.go index 73179aa1..e26d50f7 100644 --- a/errors.go +++ b/errors.go @@ -10,19 +10,33 @@ import ( ) const ( - ErrorTypeInternal = 1 << iota - ErrorTypeExternal = 1 << iota - ErrorTypeAny = 0xffffffff + ErrorTypeBind = 1 << 31 + ErrorTypeRender = 1 << 30 + ErrorTypePrivate = 1 << 0 + ErrorTypePublic = 1 << 1 + + ErrorTypeAny = 0xffffffff + ErrorTypeNu = 2 ) // Used internally to collect errors that occurred during an http request. type errorMsg struct { - Error error `json:"error"` - Flags int `json:"-"` - Meta interface{} `json:"meta"` + Error error `json:"error"` + Flags int `json:"-"` + Metadata interface{} `json:"meta"` } -type errorMsgs []errorMsg +func (msg *errorMsg) Type(flags int) *errorMsg { + msg.Flags = flags + return msg +} + +func (msg *errorMsg) Meta(data interface{}) *errorMsg { + msg.Metadata = data + return msg +} + +type errorMsgs []*errorMsg func (a errorMsgs) ByType(typ int) errorMsgs { if len(a) == 0 { @@ -54,7 +68,7 @@ func (a errorMsgs) String() string { } var buffer bytes.Buffer for i, msg := range a { - fmt.Fprintf(&buffer, "Error #%02d: %s\n Meta: %v\n", (i + 1), msg.Error, msg.Meta) + fmt.Fprintf(&buffer, "Error #%02d: %s\n Meta: %v\n", (i + 1), msg.Error, msg.Metadata) } return buffer.String() } diff --git a/middleware_test.go b/middleware_test.go index a0fac56d..10031cf9 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -169,7 +169,7 @@ func TestMiddlewareFailHandlersChain(t *testing.T) { router := New() router.Use(func(context *Context) { signature += "A" - context.Fail(500, errors.New("foo")) + context.AbortWithError(500, errors.New("foo")) }) router.Use(func(context *Context) { signature += "B"