AutoRedirect API to handle before auto redirection

This commit is contained in:
Helios 2020-08-23 00:41:04 +08:00
parent b94d23d1b4
commit 22b88e0ed1
3 changed files with 93 additions and 3 deletions

22
gin.go
View File

@ -109,8 +109,10 @@ type Engine struct {
FuncMap template.FuncMap FuncMap template.FuncMap
allNoRoute HandlersChain allNoRoute HandlersChain
allNoMethod HandlersChain allNoMethod HandlersChain
allAutoRedirect HandlersChain
noRoute HandlersChain noRoute HandlersChain
noMethod HandlersChain noMethod HandlersChain
autoRedirect HandlersChain
pool sync.Pool pool sync.Pool
trees methodTrees trees methodTrees
maxParams uint16 maxParams uint16
@ -234,6 +236,13 @@ func (engine *Engine) NoMethod(handlers ...HandlerFunc) {
engine.rebuild405Handlers() engine.rebuild405Handlers()
} }
// AutoRedirect sets the handlers called when auto redirected
// (RedirectTrailingSlash and RedirectFixedPath)
func (engine *Engine) AutoRedirect(handlers ...HandlerFunc) {
engine.autoRedirect = handlers
engine.rebuildAutoRedirectHandlers()
}
// Use attaches a global middleware to the router. ie. the middleware attached though Use() will be // Use attaches a global middleware to the router. ie. the middleware attached though Use() will be
// included in the handlers chain for every single request. Even 404, 405, static files... // included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware. // For example, this is the right place for a logger or error management middleware.
@ -241,6 +250,7 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...) engine.RouterGroup.Use(middleware...)
engine.rebuild404Handlers() engine.rebuild404Handlers()
engine.rebuild405Handlers() engine.rebuild405Handlers()
engine.rebuildAutoRedirectHandlers()
return engine return engine
} }
@ -252,6 +262,10 @@ func (engine *Engine) rebuild405Handlers() {
engine.allNoMethod = engine.combineHandlers(engine.noMethod) engine.allNoMethod = engine.combineHandlers(engine.noMethod)
} }
func (engine *Engine) rebuildAutoRedirectHandlers() {
engine.allAutoRedirect = engine.combineHandlers(engine.autoRedirect)
}
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
assert1(path[0] == '/', "path must begin with '/'") assert1(path[0] == '/', "path must begin with '/'")
assert1(method != "", "HTTP method can not be empty") assert1(method != "", "HTTP method can not be empty")
@ -422,6 +436,7 @@ func (engine *Engine) handleHTTPRequest(c *Context) {
return return
} }
if httpMethod != "CONNECT" && rPath != "/" { if httpMethod != "CONNECT" && rPath != "/" {
c.handlers = engine.allAutoRedirect
if value.tsr && engine.RedirectTrailingSlash { if value.tsr && engine.RedirectTrailingSlash {
redirectTrailingSlash(c) redirectTrailingSlash(c)
return return
@ -495,13 +510,14 @@ func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool {
func redirectRequest(c *Context) { func redirectRequest(c *Context) {
req := c.Request req := c.Request
rPath := req.URL.Path
rURL := req.URL.String()
code := http.StatusMovedPermanently // Permanent redirect, request with GET method code := http.StatusMovedPermanently // Permanent redirect, request with GET method
if req.Method != http.MethodGet { if req.Method != http.MethodGet {
code = http.StatusTemporaryRedirect code = http.StatusTemporaryRedirect
} }
c.Next()
rPath := req.URL.Path
rURL := req.URL.String()
debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL) debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL)
http.Redirect(c.Writer, req, rURL, code) http.Redirect(c.Writer, req, rURL, code)
c.writermem.WriteHeaderNow() c.writermem.WriteHeaderNow()

View File

@ -428,6 +428,59 @@ func TestNoMethodWithGlobalHandlers(t *testing.T) {
compareFunc(t, router.allNoMethod[2], middleware0) compareFunc(t, router.allNoMethod[2], middleware0)
} }
func TestAutoRedirectWithoutGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
router := New()
router.AutoRedirect(middleware0)
assert.Nil(t, router.Handlers)
assert.Len(t, router.autoRedirect, 1)
assert.Len(t, router.allAutoRedirect, 1)
compareFunc(t, router.autoRedirect[0], middleware0)
compareFunc(t, router.allAutoRedirect[0], middleware0)
router.AutoRedirect(middleware1, middleware0)
assert.Len(t, router.autoRedirect, 2)
assert.Len(t, router.allAutoRedirect, 2)
compareFunc(t, router.autoRedirect[0], middleware1)
compareFunc(t, router.allAutoRedirect[0], middleware1)
compareFunc(t, router.autoRedirect[1], middleware0)
compareFunc(t, router.allAutoRedirect[1], middleware0)
}
func TestAutoRedirectWithGlobalHandlers(t *testing.T) {
var middleware0 HandlerFunc = func(c *Context) {}
var middleware1 HandlerFunc = func(c *Context) {}
var middleware2 HandlerFunc = func(c *Context) {}
router := New()
router.Use(middleware2)
router.AutoRedirect(middleware0)
assert.Len(t, router.allAutoRedirect, 2)
assert.Len(t, router.Handlers, 1)
assert.Len(t, router.autoRedirect, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.autoRedirect[0], middleware0)
compareFunc(t, router.allAutoRedirect[0], middleware2)
compareFunc(t, router.allAutoRedirect[1], middleware0)
router.Use(middleware1)
assert.Len(t, router.allAutoRedirect, 3)
assert.Len(t, router.Handlers, 2)
assert.Len(t, router.autoRedirect, 1)
compareFunc(t, router.Handlers[0], middleware2)
compareFunc(t, router.Handlers[1], middleware1)
compareFunc(t, router.autoRedirect[0], middleware0)
compareFunc(t, router.allAutoRedirect[0], middleware2)
compareFunc(t, router.allAutoRedirect[1], middleware1)
compareFunc(t, router.allAutoRedirect[2], middleware0)
}
func compareFunc(t *testing.T, a, b interface{}) { func compareFunc(t *testing.T, a, b interface{}) {
sf1 := reflect.ValueOf(a) sf1 := reflect.ValueOf(a)
sf2 := reflect.ValueOf(b) sf2 := reflect.ValueOf(b)

View File

@ -224,6 +224,27 @@ func TestRouteRedirectFixedPath(t *testing.T) {
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
} }
func TestRouteRedirectWithHandler(t *testing.T) {
router := New()
router.RedirectTrailingSlash = true
router.GET("/path", func(c *Context) {})
passed := []bool{false, false}
router.Use(func(c *Context) {
passed[0] = true
c.Next()
})
router.AutoRedirect(func(c *Context) {
passed[1] = true
c.Next()
})
w := performRequest(router, http.MethodGet, "/path/")
assert.Equal(t, "/path", w.Header().Get("Location"))
assert.Equal(t, http.StatusMovedPermanently, w.Code)
assert.True(t, passed[0])
assert.True(t, passed[1])
}
// TestContextParamsGet tests that a parameter can be parsed from the URL. // TestContextParamsGet tests that a parameter can be parsed from the URL.
func TestRouteParamsByName(t *testing.T) { func TestRouteParamsByName(t *testing.T) {
name := "" name := ""