diff --git a/gin.go b/gin.go index 1e126179..004ab1f7 100644 --- a/gin.go +++ b/gin.go @@ -109,8 +109,10 @@ type Engine struct { FuncMap template.FuncMap allNoRoute HandlersChain allNoMethod HandlersChain + allAutoRedirect HandlersChain noRoute HandlersChain noMethod HandlersChain + autoRedirect HandlersChain pool sync.Pool trees methodTrees maxParams uint16 @@ -234,6 +236,13 @@ func (engine *Engine) NoMethod(handlers ...HandlerFunc) { engine.rebuild405Handlers() } +// AutoRedirect sets the handlers called when auto redirected +// (RedirectTrailingSlash and RedirectFixedPath) +func (engine *Engine) AutoRedirect(handlers ...HandlerFunc) { + engine.autoRedirect = handlers + engine.rebuildAutoRedirectHandlers() +} + // Use attaches a global middleware to the router. ie. the middleware attached though Use() will be // included in the handlers chain for every single request. Even 404, 405, static files... // For example, this is the right place for a logger or error management middleware. @@ -241,6 +250,7 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes { engine.RouterGroup.Use(middleware...) engine.rebuild404Handlers() engine.rebuild405Handlers() + engine.rebuildAutoRedirectHandlers() return engine } @@ -252,6 +262,10 @@ func (engine *Engine) rebuild405Handlers() { engine.allNoMethod = engine.combineHandlers(engine.noMethod) } +func (engine *Engine) rebuildAutoRedirectHandlers() { + engine.allAutoRedirect = engine.combineHandlers(engine.autoRedirect) +} + func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { assert1(path[0] == '/', "path must begin with '/'") assert1(method != "", "HTTP method can not be empty") @@ -422,6 +436,7 @@ func (engine *Engine) handleHTTPRequest(c *Context) { return } if httpMethod != "CONNECT" && rPath != "/" { + c.handlers = engine.allAutoRedirect if value.tsr && engine.RedirectTrailingSlash { redirectTrailingSlash(c) return @@ -495,13 +510,14 @@ func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool { func redirectRequest(c *Context) { req := c.Request - rPath := req.URL.Path - rURL := req.URL.String() - code := http.StatusMovedPermanently // Permanent redirect, request with GET method if req.Method != http.MethodGet { code = http.StatusTemporaryRedirect } + c.Next() + + rPath := req.URL.Path + rURL := req.URL.String() debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL) http.Redirect(c.Writer, req, rURL, code) c.writermem.WriteHeaderNow() diff --git a/gin_test.go b/gin_test.go index 11bdd79c..f93a4e6b 100644 --- a/gin_test.go +++ b/gin_test.go @@ -428,6 +428,59 @@ func TestNoMethodWithGlobalHandlers(t *testing.T) { compareFunc(t, router.allNoMethod[2], middleware0) } +func TestAutoRedirectWithoutGlobalHandlers(t *testing.T) { + var middleware0 HandlerFunc = func(c *Context) {} + var middleware1 HandlerFunc = func(c *Context) {} + + router := New() + + router.AutoRedirect(middleware0) + assert.Nil(t, router.Handlers) + assert.Len(t, router.autoRedirect, 1) + assert.Len(t, router.allAutoRedirect, 1) + compareFunc(t, router.autoRedirect[0], middleware0) + compareFunc(t, router.allAutoRedirect[0], middleware0) + + router.AutoRedirect(middleware1, middleware0) + assert.Len(t, router.autoRedirect, 2) + assert.Len(t, router.allAutoRedirect, 2) + compareFunc(t, router.autoRedirect[0], middleware1) + compareFunc(t, router.allAutoRedirect[0], middleware1) + compareFunc(t, router.autoRedirect[1], middleware0) + compareFunc(t, router.allAutoRedirect[1], middleware0) +} + +func TestAutoRedirectWithGlobalHandlers(t *testing.T) { + var middleware0 HandlerFunc = func(c *Context) {} + var middleware1 HandlerFunc = func(c *Context) {} + var middleware2 HandlerFunc = func(c *Context) {} + + router := New() + router.Use(middleware2) + + router.AutoRedirect(middleware0) + assert.Len(t, router.allAutoRedirect, 2) + assert.Len(t, router.Handlers, 1) + assert.Len(t, router.autoRedirect, 1) + + compareFunc(t, router.Handlers[0], middleware2) + compareFunc(t, router.autoRedirect[0], middleware0) + compareFunc(t, router.allAutoRedirect[0], middleware2) + compareFunc(t, router.allAutoRedirect[1], middleware0) + + router.Use(middleware1) + assert.Len(t, router.allAutoRedirect, 3) + assert.Len(t, router.Handlers, 2) + assert.Len(t, router.autoRedirect, 1) + + compareFunc(t, router.Handlers[0], middleware2) + compareFunc(t, router.Handlers[1], middleware1) + compareFunc(t, router.autoRedirect[0], middleware0) + compareFunc(t, router.allAutoRedirect[0], middleware2) + compareFunc(t, router.allAutoRedirect[1], middleware1) + compareFunc(t, router.allAutoRedirect[2], middleware0) +} + func compareFunc(t *testing.T, a, b interface{}) { sf1 := reflect.ValueOf(a) sf2 := reflect.ValueOf(b) diff --git a/routes_test.go b/routes_test.go index 11ff71a6..8c919b03 100644 --- a/routes_test.go +++ b/routes_test.go @@ -224,6 +224,27 @@ func TestRouteRedirectFixedPath(t *testing.T) { assert.Equal(t, http.StatusTemporaryRedirect, w.Code) } +func TestRouteRedirectWithHandler(t *testing.T) { + router := New() + router.RedirectTrailingSlash = true + router.GET("/path", func(c *Context) {}) + passed := []bool{false, false} + router.Use(func(c *Context) { + passed[0] = true + c.Next() + }) + router.AutoRedirect(func(c *Context) { + passed[1] = true + c.Next() + }) + + w := performRequest(router, http.MethodGet, "/path/") + assert.Equal(t, "/path", w.Header().Get("Location")) + assert.Equal(t, http.StatusMovedPermanently, w.Code) + assert.True(t, passed[0]) + assert.True(t, passed[1]) +} + // TestContextParamsGet tests that a parameter can be parsed from the URL. func TestRouteParamsByName(t *testing.T) { name := ""