mirror of https://github.com/gin-gonic/gin.git
Updates tree.go + fixes + unit tests
This commit is contained in:
parent
295201dad2
commit
f212ae7728
2
gin.go
2
gin.go
|
@ -233,6 +233,7 @@ func (engine *Engine) serveAutoRedirect(c *Context, root *node, tsr bool) bool {
|
|||
}
|
||||
debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String())
|
||||
http.Redirect(c.Writer, req, req.URL.String(), code)
|
||||
c.writermem.WriteHeaderNow()
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -246,6 +247,7 @@ func (engine *Engine) serveAutoRedirect(c *Context, root *node, tsr bool) bool {
|
|||
req.URL.Path = string(fixedPath)
|
||||
debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String())
|
||||
http.Redirect(c.Writer, req, req.URL.String(), code)
|
||||
c.writermem.WriteHeaderNow()
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,9 @@ func TestCreateEngine(t *testing.T) {
|
|||
assert.Equal(t, "/", router.absolutePath)
|
||||
assert.Equal(t, router.engine, router)
|
||||
assert.Empty(t, router.Handlers)
|
||||
assert.True(t, router.RedirectTrailingSlash)
|
||||
assert.True(t, router.RedirectFixedPath)
|
||||
assert.True(t, router.HandleMethodNotAllowed)
|
||||
|
||||
assert.Panics(t, func() { router.handle("", "/", []HandlerFunc{func(_ *Context) {}}) })
|
||||
assert.Panics(t, func() { router.handle("GET", "", []HandlerFunc{func(_ *Context) {}}) })
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package gin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -110,18 +111,28 @@ func TestRouteNotOK2(t *testing.T) {
|
|||
func TestRouteParamsByName(t *testing.T) {
|
||||
name := ""
|
||||
lastName := ""
|
||||
wild := ""
|
||||
router := New()
|
||||
router.GET("/test/:name/:last_name", func(c *Context) {
|
||||
router.GET("/test/:name/:last_name/*wild", func(c *Context) {
|
||||
name = c.Params.ByName("name")
|
||||
lastName = c.Params.ByName("last_name")
|
||||
wild = c.Params.ByName("wild")
|
||||
|
||||
assert.Equal(t, name, c.ParamValue("name"))
|
||||
assert.Equal(t, lastName, c.ParamValue("last_name"))
|
||||
|
||||
assert.Equal(t, name, c.DefaultParamValue("name", "nothing"))
|
||||
assert.Equal(t, lastName, c.DefaultParamValue("last_name", "nothing"))
|
||||
assert.Equal(t, c.DefaultParamValue("noKey", "default"), "default")
|
||||
})
|
||||
// RUN
|
||||
w := performRequest(router, "GET", "/test/john/smith")
|
||||
w := performRequest(router, "GET", "/test/john/smith/is/super/great")
|
||||
|
||||
// TEST
|
||||
assert.Equal(t, w.Code, 200)
|
||||
assert.Equal(t, name, "john")
|
||||
assert.Equal(t, lastName, "smith")
|
||||
assert.Equal(t, wild, "/is/super/great")
|
||||
}
|
||||
|
||||
// TestHandleStaticFile - ensure the static file handles properly
|
||||
|
@ -183,3 +194,70 @@ func TestRouteHeadToDir(t *testing.T) {
|
|||
assert.Contains(t, bodyAsString, "gin.go")
|
||||
assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8")
|
||||
}
|
||||
|
||||
func TestRouteNotAllowed(t *testing.T) {
|
||||
router := New()
|
||||
|
||||
router.POST("/path", func(c *Context) {})
|
||||
w := performRequest(router, "GET", "/path")
|
||||
assert.Equal(t, w.Code, http.StatusMethodNotAllowed)
|
||||
|
||||
router.NoMethod(func(c *Context) {
|
||||
c.String(http.StatusTeapot, "responseText")
|
||||
})
|
||||
w = performRequest(router, "GET", "/path")
|
||||
assert.Equal(t, w.Body.String(), "responseText")
|
||||
assert.Equal(t, w.Code, http.StatusTeapot)
|
||||
}
|
||||
|
||||
func TestRouterNotFound(t *testing.T) {
|
||||
router := New()
|
||||
router.GET("/path", func(c *Context) {})
|
||||
router.GET("/dir/", func(c *Context) {})
|
||||
router.GET("/", func(c *Context) {})
|
||||
|
||||
testRoutes := []struct {
|
||||
route string
|
||||
code int
|
||||
header string
|
||||
}{
|
||||
{"/path/", 301, "map[Location:[/path]]"}, // TSR -/
|
||||
{"/dir", 301, "map[Location:[/dir/]]"}, // TSR +/
|
||||
{"", 301, "map[Location:[/]]"}, // TSR +/
|
||||
{"/PATH", 301, "map[Location:[/path]]"}, // Fixed Case
|
||||
{"/DIR/", 301, "map[Location:[/dir/]]"}, // Fixed Case
|
||||
{"/PATH/", 301, "map[Location:[/path]]"}, // Fixed Case -/
|
||||
{"/DIR", 301, "map[Location:[/dir/]]"}, // Fixed Case +/
|
||||
{"/../path", 301, "map[Location:[/path]]"}, // CleanPath
|
||||
{"/nope", 404, ""}, // NotFound
|
||||
}
|
||||
for _, tr := range testRoutes {
|
||||
w := performRequest(router, "GET", tr.route)
|
||||
assert.Equal(t, w.Code, tr.code)
|
||||
if w.Code != 404 {
|
||||
assert.Equal(t, fmt.Sprint(w.Header()), tr.header)
|
||||
}
|
||||
}
|
||||
|
||||
// Test custom not found handler
|
||||
var notFound bool
|
||||
router.NoRoute(func(c *Context) {
|
||||
c.AbortWithStatus(404)
|
||||
notFound = true
|
||||
})
|
||||
w := performRequest(router, "GET", "/nope")
|
||||
assert.Equal(t, w.Code, 404)
|
||||
assert.True(t, notFound)
|
||||
|
||||
// Test other method than GET (want 307 instead of 301)
|
||||
router.PATCH("/path", func(c *Context) {})
|
||||
w = performRequest(router, "PATCH", "/path/")
|
||||
assert.Equal(t, w.Code, 307)
|
||||
assert.Equal(t, fmt.Sprint(w.Header()), "map[Location:[/path]]")
|
||||
|
||||
// Test special case where no node for the prefix "/" exists
|
||||
router = New()
|
||||
router.GET("/a", func(c *Context) {})
|
||||
w = performRequest(router, "GET", "/")
|
||||
assert.Equal(t, w.Code, 404)
|
||||
}
|
||||
|
|
41
tree.go
41
tree.go
|
@ -78,6 +78,7 @@ func (n *node) incrementChildPrio(pos int) int {
|
|||
// addRoute adds a node with the given handle to the path.
|
||||
// Not concurrency-safe!
|
||||
func (n *node) addRoute(path string, handlers []HandlerFunc) {
|
||||
fullPath := path
|
||||
n.priority++
|
||||
numParams := countParams(path)
|
||||
|
||||
|
@ -147,7 +148,9 @@ func (n *node) addRoute(path string, handlers []HandlerFunc) {
|
|||
}
|
||||
}
|
||||
|
||||
panic("conflict with wildcard route")
|
||||
panic("path segment '" + path +
|
||||
"' conflicts with existing wildcard '" + n.path +
|
||||
"' in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
c := path[0]
|
||||
|
@ -179,23 +182,23 @@ func (n *node) addRoute(path string, handlers []HandlerFunc) {
|
|||
n.incrementChildPrio(len(n.indices) - 1)
|
||||
n = child
|
||||
}
|
||||
n.insertChild(numParams, path, handlers)
|
||||
n.insertChild(numParams, path, fullPath, handlers)
|
||||
return
|
||||
|
||||
} else if i == len(path) { // Make node a (in-path) leaf
|
||||
if n.handlers != nil {
|
||||
panic("a Handle is already registered for this path")
|
||||
panic("handlers are already registered for path ''" + fullPath + "'")
|
||||
}
|
||||
n.handlers = handlers
|
||||
}
|
||||
return
|
||||
}
|
||||
} else { // Empty tree
|
||||
n.insertChild(numParams, path, handlers)
|
||||
n.insertChild(numParams, path, fullPath, handlers)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc) {
|
||||
func (n *node) insertChild(numParams uint8, path string, fullPath string, handlers []HandlerFunc) {
|
||||
var offset int // already handled bytes of the path
|
||||
|
||||
// find prefix until first wildcard (beginning with ':'' or '*'')
|
||||
|
@ -205,27 +208,29 @@ func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc)
|
|||
continue
|
||||
}
|
||||
|
||||
// check if this Node existing children which would be
|
||||
// unreachable if we insert the wildcard here
|
||||
if len(n.children) > 0 {
|
||||
panic("wildcard route conflicts with existing children")
|
||||
}
|
||||
|
||||
// find wildcard end (either '/' or path end)
|
||||
end := i + 1
|
||||
for end < max && path[end] != '/' {
|
||||
switch path[end] {
|
||||
// the wildcard name must not contain ':' and '*'
|
||||
case ':', '*':
|
||||
panic("only one wildcard per path segment is allowed")
|
||||
panic("only one wildcard per path segment is allowed, has: '" +
|
||||
path[i:] + "' in path '" + fullPath + "'")
|
||||
default:
|
||||
end++
|
||||
}
|
||||
}
|
||||
|
||||
// check if this Node existing children which would be
|
||||
// unreachable if we insert the wildcard here
|
||||
if len(n.children) > 0 {
|
||||
panic("wildcard route '" + path[i:end] +
|
||||
"' conflicts with existing children in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
// check if the wildcard has a name
|
||||
if end-i < 2 {
|
||||
panic("wildcards must be named with a non-empty name")
|
||||
panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
if c == ':' { // param
|
||||
|
@ -261,17 +266,17 @@ func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc)
|
|||
|
||||
} else { // catchAll
|
||||
if end != max || numParams > 1 {
|
||||
panic("catch-all routes are only allowed at the end of the path")
|
||||
panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
|
||||
panic("catch-all conflicts with existing handle for the path segment root")
|
||||
panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
// currently fixed width 1 for '/'
|
||||
i--
|
||||
if path[i] != '/' {
|
||||
panic("no / before catch-all")
|
||||
panic("no / before catch-all in path '" + fullPath + "'")
|
||||
}
|
||||
|
||||
n.path = path[offset:i]
|
||||
|
@ -394,7 +399,7 @@ walk: // Outer loop for walking the tree
|
|||
return
|
||||
|
||||
default:
|
||||
panic("Invalid node type")
|
||||
panic("invalid node type")
|
||||
}
|
||||
}
|
||||
} else if path == n.path {
|
||||
|
@ -505,7 +510,7 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
|
|||
return append(ciPath, path...), true
|
||||
|
||||
default:
|
||||
panic("Invalid node type")
|
||||
panic("invalid node type")
|
||||
}
|
||||
} else {
|
||||
// We should have reached the node containing the handle.
|
||||
|
|
10
tree_test.go
10
tree_test.go
|
@ -357,7 +357,7 @@ func TestTreeDoubleWildcard(t *testing.T) {
|
|||
tree.addRoute(route, nil)
|
||||
})
|
||||
|
||||
if rs, ok := recv.(string); !ok || rs != panicMsg {
|
||||
if rs, ok := recv.(string); !ok || !strings.HasPrefix(rs, panicMsg) {
|
||||
t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv)
|
||||
}
|
||||
}
|
||||
|
@ -594,15 +594,15 @@ func TestTreeInvalidNodeType(t *testing.T) {
|
|||
recv := catchPanic(func() {
|
||||
tree.getValue("/test", nil)
|
||||
})
|
||||
if rs, ok := recv.(string); !ok || rs != "Invalid node type" {
|
||||
t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv)
|
||||
if rs, ok := recv.(string); !ok || rs != "invalid node type" {
|
||||
t.Fatalf(`Expected panic "invalid node type", got "%v"`, recv)
|
||||
}
|
||||
|
||||
// case-insensitive lookup
|
||||
recv = catchPanic(func() {
|
||||
tree.findCaseInsensitivePath("/test", true)
|
||||
})
|
||||
if rs, ok := recv.(string); !ok || rs != "Invalid node type" {
|
||||
t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv)
|
||||
if rs, ok := recv.(string); !ok || rs != "invalid node type" {
|
||||
t.Fatalf(`Expected panic "invalid node type", got "%v"`, recv)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue