diff --git a/gin.go b/gin.go index ad64c35f..e28e9579 100644 --- a/gin.go +++ b/gin.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "os" + "path" "sync" "github.com/gin-gonic/gin/render" @@ -438,17 +439,20 @@ func serveError(c *Context, code int, defaultMessage []byte) { func redirectTrailingSlash(c *Context) { req := c.Request - path := req.URL.Path + p := req.URL.Path + if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." { + p = prefix + "/" + req.URL.Path + } code := http.StatusMovedPermanently // Permanent redirect, request with GET method if req.Method != "GET" { code = http.StatusTemporaryRedirect } - req.URL.Path = path + "/" - if length := len(path); length > 1 && path[length-1] == '/' { - req.URL.Path = path[:length-1] + req.URL.Path = p + "/" + if length := len(p); length > 1 && p[length-1] == '/' { + req.URL.Path = p[:length-1] } - debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String()) + debugPrint("redirecting request %d: %s --> %s", code, p, req.URL.String()) http.Redirect(c.Writer, req, req.URL.String(), code) c.writermem.WriteHeaderNow() } diff --git a/routes_test.go b/routes_test.go index 8d50292d..a842704f 100644 --- a/routes_test.go +++ b/routes_test.go @@ -16,8 +16,16 @@ import ( "github.com/stretchr/testify/assert" ) -func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder { +type header struct { + Key string + Value string +} + +func performRequest(r http.Handler, method, path string, headers ...header) *httptest.ResponseRecorder { req, _ := http.NewRequest(method, path, nil) + for _, h := range headers { + req.Header.Add(h.Key, h.Value) + } w := httptest.NewRecorder() r.ServeHTTP(w, req) return w @@ -170,6 +178,13 @@ func TestRouteRedirectTrailingSlash(t *testing.T) { w = performRequest(router, "PUT", "/path4/") assert.Equal(t, http.StatusOK, w.Code) + w = performRequest(router, "GET", "/path2", header{Key: "X-Forwarded-Prefix", Value: "/api"}) + assert.Equal(t, "/api/path2/", w.Header().Get("Location")) + assert.Equal(t, 301, w.Code) + + w = performRequest(router, "GET", "/path2/", header{Key: "X-Forwarded-Prefix", Value: "/api/"}) + assert.Equal(t, 200, w.Code) + router.RedirectTrailingSlash = false w = performRequest(router, "GET", "/path/")