add prefix from X-Forwarded-Prefix in redirectTrailingSlash (#1238)

* add prefix from X-Forwarded-Prefix in redirectTrailingSlash

* added test

* fix path import
This commit is contained in:
Tudor Roman 2019-02-27 13:56:29 +02:00 committed by 田欧
parent e207a3ce65
commit ccb105dbcb
2 changed files with 25 additions and 6 deletions

14
gin.go
View File

@ -10,6 +10,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"path"
"sync" "sync"
"github.com/gin-gonic/gin/render" "github.com/gin-gonic/gin/render"
@ -438,17 +439,20 @@ func serveError(c *Context, code int, defaultMessage []byte) {
func redirectTrailingSlash(c *Context) { func redirectTrailingSlash(c *Context) {
req := c.Request req := c.Request
path := req.URL.Path p := req.URL.Path
if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." {
p = prefix + "/" + req.URL.Path
}
code := http.StatusMovedPermanently // Permanent redirect, request with GET method code := http.StatusMovedPermanently // Permanent redirect, request with GET method
if req.Method != "GET" { if req.Method != "GET" {
code = http.StatusTemporaryRedirect code = http.StatusTemporaryRedirect
} }
req.URL.Path = path + "/" req.URL.Path = p + "/"
if length := len(path); length > 1 && path[length-1] == '/' { if length := len(p); length > 1 && p[length-1] == '/' {
req.URL.Path = path[:length-1] req.URL.Path = p[:length-1]
} }
debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String()) debugPrint("redirecting request %d: %s --> %s", code, p, req.URL.String())
http.Redirect(c.Writer, req, req.URL.String(), code) http.Redirect(c.Writer, req, req.URL.String(), code)
c.writermem.WriteHeaderNow() c.writermem.WriteHeaderNow()
} }

View File

@ -16,8 +16,16 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder { type header struct {
Key string
Value string
}
func performRequest(r http.Handler, method, path string, headers ...header) *httptest.ResponseRecorder {
req, _ := http.NewRequest(method, path, nil) req, _ := http.NewRequest(method, path, nil)
for _, h := range headers {
req.Header.Add(h.Key, h.Value)
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
return w return w
@ -170,6 +178,13 @@ func TestRouteRedirectTrailingSlash(t *testing.T) {
w = performRequest(router, "PUT", "/path4/") w = performRequest(router, "PUT", "/path4/")
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
w = performRequest(router, "GET", "/path2", header{Key: "X-Forwarded-Prefix", Value: "/api"})
assert.Equal(t, "/api/path2/", w.Header().Get("Location"))
assert.Equal(t, 301, w.Code)
w = performRequest(router, "GET", "/path2/", header{Key: "X-Forwarded-Prefix", Value: "/api/"})
assert.Equal(t, 200, w.Code)
router.RedirectTrailingSlash = false router.RedirectTrailingSlash = false
w = performRequest(router, "GET", "/path/") w = performRequest(router, "GET", "/path/")