Add CustomRecovery builtin middleware (#2322)

* Add CustomRecovery and CustomRecoveryWithWriter methods

* add CustomRecovery example to README

* add test for CustomRecovery

* support RecoveryWithWriter(io.Writer, ...RecoveryFunc)
This commit is contained in:
Johnny Dallas 2020-07-08 18:40:00 -07:00 committed by GitHub
parent 5e40c1d49c
commit 4cabdd303f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 161 additions and 5 deletions

View File

@ -496,6 +496,39 @@ func main() {
} }
``` ```
### Custom Recovery behavior
```go
func main() {
// Creates a router without any middleware by default
r := gin.New()
// Global middleware
// Logger middleware will write the logs to gin.DefaultWriter even if you set with GIN_MODE=release.
// By default gin.DefaultWriter = os.Stdout
r.Use(gin.Logger())
// Recovery middleware recovers from any panics and writes a 500 if there was one.
r.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
if err, ok := recovered.(string); ok {
c.String(http.StatusInternalServerError, fmt.Sprintf("error: %s", err))
}
c.AbortWithStatus(http.StatusInternalServerError)
}))
r.GET("/panic", func(c *gin.Context) {
// panic with a string -- the custom middleware could save this to a database or report it to the user
panic("foo")
})
r.GET("/", func(c *gin.Context) {
c.String(http.StatusOK, "ohai")
})
// Listen and serve on 0.0.0.0:8080
r.Run(":8080")
}
```
### How to write log file ### How to write log file
```go ```go
func main() { func main() {

View File

@ -26,13 +26,29 @@ var (
slash = []byte("/") slash = []byte("/")
) )
// RecoveryFunc defines the function passable to CustomRecovery.
type RecoveryFunc func(c *Context, err interface{})
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one. // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
func Recovery() HandlerFunc { func Recovery() HandlerFunc {
return RecoveryWithWriter(DefaultErrorWriter) return RecoveryWithWriter(DefaultErrorWriter)
} }
//CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
func CustomRecovery(handle RecoveryFunc) HandlerFunc {
return RecoveryWithWriter(DefaultErrorWriter, handle)
}
// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
func RecoveryWithWriter(out io.Writer) HandlerFunc { func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
if len(recovery) > 0 {
return CustomRecoveryWithWriter(out, recovery[0])
}
return CustomRecoveryWithWriter(out, defaultHandleRecovery)
}
// CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
var logger *log.Logger var logger *log.Logger
if out != nil { if out != nil {
logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags) logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
@ -70,13 +86,12 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc {
timeFormat(time.Now()), err, stack, reset) timeFormat(time.Now()), err, stack, reset)
} }
} }
// If the connection is dead, we can't write a status to it.
if brokenPipe { if brokenPipe {
// If the connection is dead, we can't write a status to it.
c.Error(err.(error)) // nolint: errcheck c.Error(err.(error)) // nolint: errcheck
c.Abort() c.Abort()
} else { } else {
c.AbortWithStatus(http.StatusInternalServerError) handle(c, err)
} }
} }
}() }()
@ -84,6 +99,10 @@ func RecoveryWithWriter(out io.Writer) HandlerFunc {
} }
} }
func defaultHandleRecovery(c *Context, err interface{}) {
c.AbortWithStatus(http.StatusInternalServerError)
}
// stack returns a nicely formatted stack frame, skipping skip frames. // stack returns a nicely formatted stack frame, skipping skip frames.
func stack(skip int) []byte { func stack(skip int) []byte {
buf := new(bytes.Buffer) // the returned data buf := new(bytes.Buffer) // the returned data

View File

@ -62,7 +62,7 @@ func TestPanicInHandler(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, buffer.String(), "panic recovered") assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem") assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), "TestPanicInHandler") assert.Contains(t, buffer.String(), t.Name())
assert.NotContains(t, buffer.String(), "GET /recovery") assert.NotContains(t, buffer.String(), "GET /recovery")
// Debug mode prints the request // Debug mode prints the request
@ -144,3 +144,107 @@ func TestPanicWithBrokenPipe(t *testing.T) {
}) })
} }
} }
func TestCustomRecoveryWithWriter(t *testing.T) {
errBuffer := new(bytes.Buffer)
buffer := new(bytes.Buffer)
router := New()
handleRecovery := func(c *Context, err interface{}) {
errBuffer.WriteString(err.(string))
c.AbortWithStatus(http.StatusBadRequest)
}
router.Use(CustomRecoveryWithWriter(buffer, handleRecovery))
router.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), t.Name())
assert.NotContains(t, buffer.String(), "GET /recovery")
// Debug mode prints the request
SetMode(DebugMode)
// RUN
w = performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "GET /recovery")
assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())
SetMode(TestMode)
}
func TestCustomRecovery(t *testing.T) {
errBuffer := new(bytes.Buffer)
buffer := new(bytes.Buffer)
router := New()
DefaultErrorWriter = buffer
handleRecovery := func(c *Context, err interface{}) {
errBuffer.WriteString(err.(string))
c.AbortWithStatus(http.StatusBadRequest)
}
router.Use(CustomRecovery(handleRecovery))
router.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), t.Name())
assert.NotContains(t, buffer.String(), "GET /recovery")
// Debug mode prints the request
SetMode(DebugMode)
// RUN
w = performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "GET /recovery")
assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())
SetMode(TestMode)
}
func TestRecoveryWithWriterWithCustomRecovery(t *testing.T) {
errBuffer := new(bytes.Buffer)
buffer := new(bytes.Buffer)
router := New()
DefaultErrorWriter = buffer
handleRecovery := func(c *Context, err interface{}) {
errBuffer.WriteString(err.(string))
c.AbortWithStatus(http.StatusBadRequest)
}
router.Use(RecoveryWithWriter(DefaultErrorWriter, handleRecovery))
router.GET("/recovery", func(_ *Context) {
panic("Oupps, Houston, we have a problem")
})
// RUN
w := performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "panic recovered")
assert.Contains(t, buffer.String(), "Oupps, Houston, we have a problem")
assert.Contains(t, buffer.String(), t.Name())
assert.NotContains(t, buffer.String(), "GET /recovery")
// Debug mode prints the request
SetMode(DebugMode)
// RUN
w = performRequest(router, "GET", "/recovery")
// TEST
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, buffer.String(), "GET /recovery")
assert.Equal(t, strings.Repeat("Oupps, Houston, we have a problem", 2), errBuffer.String())
SetMode(TestMode)
}