Updates tree.go + fixes + unit tests

This commit is contained in:
Manu Mtz-Almeida 2015-05-05 16:37:33 +02:00
parent 295201dad2
commit f212ae7728
5 changed files with 113 additions and 25 deletions

2
gin.go
View File

@ -233,6 +233,7 @@ func (engine *Engine) serveAutoRedirect(c *Context, root *node, tsr bool) bool {
} }
debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String()) debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String())
http.Redirect(c.Writer, req, req.URL.String(), code) http.Redirect(c.Writer, req, req.URL.String(), code)
c.writermem.WriteHeaderNow()
return true return true
} }
@ -246,6 +247,7 @@ func (engine *Engine) serveAutoRedirect(c *Context, root *node, tsr bool) bool {
req.URL.Path = string(fixedPath) req.URL.Path = string(fixedPath)
debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String()) debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String())
http.Redirect(c.Writer, req, req.URL.String(), code) http.Redirect(c.Writer, req, req.URL.String(), code)
c.writermem.WriteHeaderNow()
return true return true
} }
} }

View File

@ -25,6 +25,9 @@ func TestCreateEngine(t *testing.T) {
assert.Equal(t, "/", router.absolutePath) assert.Equal(t, "/", router.absolutePath)
assert.Equal(t, router.engine, router) assert.Equal(t, router.engine, router)
assert.Empty(t, router.Handlers) assert.Empty(t, router.Handlers)
assert.True(t, router.RedirectTrailingSlash)
assert.True(t, router.RedirectFixedPath)
assert.True(t, router.HandleMethodNotAllowed)
assert.Panics(t, func() { router.handle("", "/", []HandlerFunc{func(_ *Context) {}}) }) assert.Panics(t, func() { router.handle("", "/", []HandlerFunc{func(_ *Context) {}}) })
assert.Panics(t, func() { router.handle("GET", "", []HandlerFunc{func(_ *Context) {}}) }) assert.Panics(t, func() { router.handle("GET", "", []HandlerFunc{func(_ *Context) {}}) })

View File

@ -5,6 +5,7 @@
package gin package gin
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -110,18 +111,28 @@ func TestRouteNotOK2(t *testing.T) {
func TestRouteParamsByName(t *testing.T) { func TestRouteParamsByName(t *testing.T) {
name := "" name := ""
lastName := "" lastName := ""
wild := ""
router := New() router := New()
router.GET("/test/:name/:last_name", func(c *Context) { router.GET("/test/:name/:last_name/*wild", func(c *Context) {
name = c.Params.ByName("name") name = c.Params.ByName("name")
lastName = c.Params.ByName("last_name") lastName = c.Params.ByName("last_name")
wild = c.Params.ByName("wild")
assert.Equal(t, name, c.ParamValue("name"))
assert.Equal(t, lastName, c.ParamValue("last_name"))
assert.Equal(t, name, c.DefaultParamValue("name", "nothing"))
assert.Equal(t, lastName, c.DefaultParamValue("last_name", "nothing"))
assert.Equal(t, c.DefaultParamValue("noKey", "default"), "default")
}) })
// RUN // RUN
w := performRequest(router, "GET", "/test/john/smith") w := performRequest(router, "GET", "/test/john/smith/is/super/great")
// TEST // TEST
assert.Equal(t, w.Code, 200) assert.Equal(t, w.Code, 200)
assert.Equal(t, name, "john") assert.Equal(t, name, "john")
assert.Equal(t, lastName, "smith") assert.Equal(t, lastName, "smith")
assert.Equal(t, wild, "/is/super/great")
} }
// TestHandleStaticFile - ensure the static file handles properly // TestHandleStaticFile - ensure the static file handles properly
@ -183,3 +194,70 @@ func TestRouteHeadToDir(t *testing.T) {
assert.Contains(t, bodyAsString, "gin.go") assert.Contains(t, bodyAsString, "gin.go")
assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8") assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8")
} }
func TestRouteNotAllowed(t *testing.T) {
router := New()
router.POST("/path", func(c *Context) {})
w := performRequest(router, "GET", "/path")
assert.Equal(t, w.Code, http.StatusMethodNotAllowed)
router.NoMethod(func(c *Context) {
c.String(http.StatusTeapot, "responseText")
})
w = performRequest(router, "GET", "/path")
assert.Equal(t, w.Body.String(), "responseText")
assert.Equal(t, w.Code, http.StatusTeapot)
}
func TestRouterNotFound(t *testing.T) {
router := New()
router.GET("/path", func(c *Context) {})
router.GET("/dir/", func(c *Context) {})
router.GET("/", func(c *Context) {})
testRoutes := []struct {
route string
code int
header string
}{
{"/path/", 301, "map[Location:[/path]]"}, // TSR -/
{"/dir", 301, "map[Location:[/dir/]]"}, // TSR +/
{"", 301, "map[Location:[/]]"}, // TSR +/
{"/PATH", 301, "map[Location:[/path]]"}, // Fixed Case
{"/DIR/", 301, "map[Location:[/dir/]]"}, // Fixed Case
{"/PATH/", 301, "map[Location:[/path]]"}, // Fixed Case -/
{"/DIR", 301, "map[Location:[/dir/]]"}, // Fixed Case +/
{"/../path", 301, "map[Location:[/path]]"}, // CleanPath
{"/nope", 404, ""}, // NotFound
}
for _, tr := range testRoutes {
w := performRequest(router, "GET", tr.route)
assert.Equal(t, w.Code, tr.code)
if w.Code != 404 {
assert.Equal(t, fmt.Sprint(w.Header()), tr.header)
}
}
// Test custom not found handler
var notFound bool
router.NoRoute(func(c *Context) {
c.AbortWithStatus(404)
notFound = true
})
w := performRequest(router, "GET", "/nope")
assert.Equal(t, w.Code, 404)
assert.True(t, notFound)
// Test other method than GET (want 307 instead of 301)
router.PATCH("/path", func(c *Context) {})
w = performRequest(router, "PATCH", "/path/")
assert.Equal(t, w.Code, 307)
assert.Equal(t, fmt.Sprint(w.Header()), "map[Location:[/path]]")
// Test special case where no node for the prefix "/" exists
router = New()
router.GET("/a", func(c *Context) {})
w = performRequest(router, "GET", "/")
assert.Equal(t, w.Code, 404)
}

41
tree.go
View File

@ -78,6 +78,7 @@ func (n *node) incrementChildPrio(pos int) int {
// addRoute adds a node with the given handle to the path. // addRoute adds a node with the given handle to the path.
// Not concurrency-safe! // Not concurrency-safe!
func (n *node) addRoute(path string, handlers []HandlerFunc) { func (n *node) addRoute(path string, handlers []HandlerFunc) {
fullPath := path
n.priority++ n.priority++
numParams := countParams(path) numParams := countParams(path)
@ -147,7 +148,9 @@ func (n *node) addRoute(path string, handlers []HandlerFunc) {
} }
} }
panic("conflict with wildcard route") panic("path segment '" + path +
"' conflicts with existing wildcard '" + n.path +
"' in path '" + fullPath + "'")
} }
c := path[0] c := path[0]
@ -179,23 +182,23 @@ func (n *node) addRoute(path string, handlers []HandlerFunc) {
n.incrementChildPrio(len(n.indices) - 1) n.incrementChildPrio(len(n.indices) - 1)
n = child n = child
} }
n.insertChild(numParams, path, handlers) n.insertChild(numParams, path, fullPath, handlers)
return return
} else if i == len(path) { // Make node a (in-path) leaf } else if i == len(path) { // Make node a (in-path) leaf
if n.handlers != nil { if n.handlers != nil {
panic("a Handle is already registered for this path") panic("handlers are already registered for path ''" + fullPath + "'")
} }
n.handlers = handlers n.handlers = handlers
} }
return return
} }
} else { // Empty tree } else { // Empty tree
n.insertChild(numParams, path, handlers) n.insertChild(numParams, path, fullPath, handlers)
} }
} }
func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc) { func (n *node) insertChild(numParams uint8, path string, fullPath string, handlers []HandlerFunc) {
var offset int // already handled bytes of the path var offset int // already handled bytes of the path
// find prefix until first wildcard (beginning with ':'' or '*'') // find prefix until first wildcard (beginning with ':'' or '*'')
@ -205,27 +208,29 @@ func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc)
continue continue
} }
// check if this Node existing children which would be
// unreachable if we insert the wildcard here
if len(n.children) > 0 {
panic("wildcard route conflicts with existing children")
}
// find wildcard end (either '/' or path end) // find wildcard end (either '/' or path end)
end := i + 1 end := i + 1
for end < max && path[end] != '/' { for end < max && path[end] != '/' {
switch path[end] { switch path[end] {
// the wildcard name must not contain ':' and '*' // the wildcard name must not contain ':' and '*'
case ':', '*': case ':', '*':
panic("only one wildcard per path segment is allowed") panic("only one wildcard per path segment is allowed, has: '" +
path[i:] + "' in path '" + fullPath + "'")
default: default:
end++ end++
} }
} }
// check if this Node existing children which would be
// unreachable if we insert the wildcard here
if len(n.children) > 0 {
panic("wildcard route '" + path[i:end] +
"' conflicts with existing children in path '" + fullPath + "'")
}
// check if the wildcard has a name // check if the wildcard has a name
if end-i < 2 { if end-i < 2 {
panic("wildcards must be named with a non-empty name") panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
} }
if c == ':' { // param if c == ':' { // param
@ -261,17 +266,17 @@ func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc)
} else { // catchAll } else { // catchAll
if end != max || numParams > 1 { if end != max || numParams > 1 {
panic("catch-all routes are only allowed at the end of the path") panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
} }
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
panic("catch-all conflicts with existing handle for the path segment root") panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
} }
// currently fixed width 1 for '/' // currently fixed width 1 for '/'
i-- i--
if path[i] != '/' { if path[i] != '/' {
panic("no / before catch-all") panic("no / before catch-all in path '" + fullPath + "'")
} }
n.path = path[offset:i] n.path = path[offset:i]
@ -394,7 +399,7 @@ walk: // Outer loop for walking the tree
return return
default: default:
panic("Invalid node type") panic("invalid node type")
} }
} }
} else if path == n.path { } else if path == n.path {
@ -505,7 +510,7 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
return append(ciPath, path...), true return append(ciPath, path...), true
default: default:
panic("Invalid node type") panic("invalid node type")
} }
} else { } else {
// We should have reached the node containing the handle. // We should have reached the node containing the handle.

View File

@ -357,7 +357,7 @@ func TestTreeDoubleWildcard(t *testing.T) {
tree.addRoute(route, nil) tree.addRoute(route, nil)
}) })
if rs, ok := recv.(string); !ok || rs != panicMsg { if rs, ok := recv.(string); !ok || !strings.HasPrefix(rs, panicMsg) {
t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv) t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv)
} }
} }
@ -594,15 +594,15 @@ func TestTreeInvalidNodeType(t *testing.T) {
recv := catchPanic(func() { recv := catchPanic(func() {
tree.getValue("/test", nil) tree.getValue("/test", nil)
}) })
if rs, ok := recv.(string); !ok || rs != "Invalid node type" { if rs, ok := recv.(string); !ok || rs != "invalid node type" {
t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv) t.Fatalf(`Expected panic "invalid node type", got "%v"`, recv)
} }
// case-insensitive lookup // case-insensitive lookup
recv = catchPanic(func() { recv = catchPanic(func() {
tree.findCaseInsensitivePath("/test", true) tree.findCaseInsensitivePath("/test", true)
}) })
if rs, ok := recv.(string); !ok || rs != "Invalid node type" { if rs, ok := recv.(string); !ok || rs != "invalid node type" {
t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv) t.Fatalf(`Expected panic "invalid node type", got "%v"`, recv)
} }
} }