From f212ae77289674a64f725bf650841e14b8f98613 Mon Sep 17 00:00:00 2001 From: Manu Mtz-Almeida Date: Tue, 5 May 2015 16:37:33 +0200 Subject: [PATCH] Updates tree.go + fixes + unit tests --- gin.go | 2 ++ gin_test.go | 3 ++ routes_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++++-- tree.go | 41 ++++++++++++++----------- tree_test.go | 10 +++--- 5 files changed, 113 insertions(+), 25 deletions(-) diff --git a/gin.go b/gin.go index aa6c7700..4151cae4 100644 --- a/gin.go +++ b/gin.go @@ -233,6 +233,7 @@ func (engine *Engine) serveAutoRedirect(c *Context, root *node, tsr bool) bool { } debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String()) http.Redirect(c.Writer, req, req.URL.String(), code) + c.writermem.WriteHeaderNow() return true } @@ -246,6 +247,7 @@ func (engine *Engine) serveAutoRedirect(c *Context, root *node, tsr bool) bool { req.URL.Path = string(fixedPath) debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String()) http.Redirect(c.Writer, req, req.URL.String(), code) + c.writermem.WriteHeaderNow() return true } } diff --git a/gin_test.go b/gin_test.go index 36877be9..ec0ad6b3 100644 --- a/gin_test.go +++ b/gin_test.go @@ -25,6 +25,9 @@ func TestCreateEngine(t *testing.T) { assert.Equal(t, "/", router.absolutePath) assert.Equal(t, router.engine, router) assert.Empty(t, router.Handlers) + assert.True(t, router.RedirectTrailingSlash) + assert.True(t, router.RedirectFixedPath) + assert.True(t, router.HandleMethodNotAllowed) assert.Panics(t, func() { router.handle("", "/", []HandlerFunc{func(_ *Context) {}}) }) assert.Panics(t, func() { router.handle("GET", "", []HandlerFunc{func(_ *Context) {}}) }) diff --git a/routes_test.go b/routes_test.go index 5c8821ed..2c34c92f 100644 --- a/routes_test.go +++ b/routes_test.go @@ -5,6 +5,7 @@ package gin import ( + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -110,18 +111,28 @@ func TestRouteNotOK2(t *testing.T) { func TestRouteParamsByName(t *testing.T) { name := "" lastName := "" + wild := "" router := New() - router.GET("/test/:name/:last_name", func(c *Context) { + router.GET("/test/:name/:last_name/*wild", func(c *Context) { name = c.Params.ByName("name") lastName = c.Params.ByName("last_name") + wild = c.Params.ByName("wild") + + assert.Equal(t, name, c.ParamValue("name")) + assert.Equal(t, lastName, c.ParamValue("last_name")) + + assert.Equal(t, name, c.DefaultParamValue("name", "nothing")) + assert.Equal(t, lastName, c.DefaultParamValue("last_name", "nothing")) + assert.Equal(t, c.DefaultParamValue("noKey", "default"), "default") }) // RUN - w := performRequest(router, "GET", "/test/john/smith") + w := performRequest(router, "GET", "/test/john/smith/is/super/great") // TEST assert.Equal(t, w.Code, 200) assert.Equal(t, name, "john") assert.Equal(t, lastName, "smith") + assert.Equal(t, wild, "/is/super/great") } // TestHandleStaticFile - ensure the static file handles properly @@ -183,3 +194,70 @@ func TestRouteHeadToDir(t *testing.T) { assert.Contains(t, bodyAsString, "gin.go") assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8") } + +func TestRouteNotAllowed(t *testing.T) { + router := New() + + router.POST("/path", func(c *Context) {}) + w := performRequest(router, "GET", "/path") + assert.Equal(t, w.Code, http.StatusMethodNotAllowed) + + router.NoMethod(func(c *Context) { + c.String(http.StatusTeapot, "responseText") + }) + w = performRequest(router, "GET", "/path") + assert.Equal(t, w.Body.String(), "responseText") + assert.Equal(t, w.Code, http.StatusTeapot) +} + +func TestRouterNotFound(t *testing.T) { + router := New() + router.GET("/path", func(c *Context) {}) + router.GET("/dir/", func(c *Context) {}) + router.GET("/", func(c *Context) {}) + + testRoutes := []struct { + route string + code int + header string + }{ + {"/path/", 301, "map[Location:[/path]]"}, // TSR -/ + {"/dir", 301, "map[Location:[/dir/]]"}, // TSR +/ + {"", 301, "map[Location:[/]]"}, // TSR +/ + {"/PATH", 301, "map[Location:[/path]]"}, // Fixed Case + {"/DIR/", 301, "map[Location:[/dir/]]"}, // Fixed Case + {"/PATH/", 301, "map[Location:[/path]]"}, // Fixed Case -/ + {"/DIR", 301, "map[Location:[/dir/]]"}, // Fixed Case +/ + {"/../path", 301, "map[Location:[/path]]"}, // CleanPath + {"/nope", 404, ""}, // NotFound + } + for _, tr := range testRoutes { + w := performRequest(router, "GET", tr.route) + assert.Equal(t, w.Code, tr.code) + if w.Code != 404 { + assert.Equal(t, fmt.Sprint(w.Header()), tr.header) + } + } + + // Test custom not found handler + var notFound bool + router.NoRoute(func(c *Context) { + c.AbortWithStatus(404) + notFound = true + }) + w := performRequest(router, "GET", "/nope") + assert.Equal(t, w.Code, 404) + assert.True(t, notFound) + + // Test other method than GET (want 307 instead of 301) + router.PATCH("/path", func(c *Context) {}) + w = performRequest(router, "PATCH", "/path/") + assert.Equal(t, w.Code, 307) + assert.Equal(t, fmt.Sprint(w.Header()), "map[Location:[/path]]") + + // Test special case where no node for the prefix "/" exists + router = New() + router.GET("/a", func(c *Context) {}) + w = performRequest(router, "GET", "/") + assert.Equal(t, w.Code, 404) +} diff --git a/tree.go b/tree.go index 9cd04fe8..8cd67e70 100644 --- a/tree.go +++ b/tree.go @@ -78,6 +78,7 @@ func (n *node) incrementChildPrio(pos int) int { // addRoute adds a node with the given handle to the path. // Not concurrency-safe! func (n *node) addRoute(path string, handlers []HandlerFunc) { + fullPath := path n.priority++ numParams := countParams(path) @@ -147,7 +148,9 @@ func (n *node) addRoute(path string, handlers []HandlerFunc) { } } - panic("conflict with wildcard route") + panic("path segment '" + path + + "' conflicts with existing wildcard '" + n.path + + "' in path '" + fullPath + "'") } c := path[0] @@ -179,23 +182,23 @@ func (n *node) addRoute(path string, handlers []HandlerFunc) { n.incrementChildPrio(len(n.indices) - 1) n = child } - n.insertChild(numParams, path, handlers) + n.insertChild(numParams, path, fullPath, handlers) return } else if i == len(path) { // Make node a (in-path) leaf if n.handlers != nil { - panic("a Handle is already registered for this path") + panic("handlers are already registered for path ''" + fullPath + "'") } n.handlers = handlers } return } } else { // Empty tree - n.insertChild(numParams, path, handlers) + n.insertChild(numParams, path, fullPath, handlers) } } -func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc) { +func (n *node) insertChild(numParams uint8, path string, fullPath string, handlers []HandlerFunc) { var offset int // already handled bytes of the path // find prefix until first wildcard (beginning with ':'' or '*'') @@ -205,27 +208,29 @@ func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc) continue } - // check if this Node existing children which would be - // unreachable if we insert the wildcard here - if len(n.children) > 0 { - panic("wildcard route conflicts with existing children") - } - // find wildcard end (either '/' or path end) end := i + 1 for end < max && path[end] != '/' { switch path[end] { // the wildcard name must not contain ':' and '*' case ':', '*': - panic("only one wildcard per path segment is allowed") + panic("only one wildcard per path segment is allowed, has: '" + + path[i:] + "' in path '" + fullPath + "'") default: end++ } } + // check if this Node existing children which would be + // unreachable if we insert the wildcard here + if len(n.children) > 0 { + panic("wildcard route '" + path[i:end] + + "' conflicts with existing children in path '" + fullPath + "'") + } + // check if the wildcard has a name if end-i < 2 { - panic("wildcards must be named with a non-empty name") + panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") } if c == ':' { // param @@ -261,17 +266,17 @@ func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc) } else { // catchAll if end != max || numParams > 1 { - panic("catch-all routes are only allowed at the end of the path") + panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") } if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { - panic("catch-all conflicts with existing handle for the path segment root") + panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'") } // currently fixed width 1 for '/' i-- if path[i] != '/' { - panic("no / before catch-all") + panic("no / before catch-all in path '" + fullPath + "'") } n.path = path[offset:i] @@ -394,7 +399,7 @@ walk: // Outer loop for walking the tree return default: - panic("Invalid node type") + panic("invalid node type") } } } else if path == n.path { @@ -505,7 +510,7 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa return append(ciPath, path...), true default: - panic("Invalid node type") + panic("invalid node type") } } else { // We should have reached the node containing the handle. diff --git a/tree_test.go b/tree_test.go index 50f2fc44..800e7512 100644 --- a/tree_test.go +++ b/tree_test.go @@ -357,7 +357,7 @@ func TestTreeDoubleWildcard(t *testing.T) { tree.addRoute(route, nil) }) - if rs, ok := recv.(string); !ok || rs != panicMsg { + if rs, ok := recv.(string); !ok || !strings.HasPrefix(rs, panicMsg) { t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv) } } @@ -594,15 +594,15 @@ func TestTreeInvalidNodeType(t *testing.T) { recv := catchPanic(func() { tree.getValue("/test", nil) }) - if rs, ok := recv.(string); !ok || rs != "Invalid node type" { - t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv) + if rs, ok := recv.(string); !ok || rs != "invalid node type" { + t.Fatalf(`Expected panic "invalid node type", got "%v"`, recv) } // case-insensitive lookup recv = catchPanic(func() { tree.findCaseInsensitivePath("/test", true) }) - if rs, ok := recv.(string); !ok || rs != "Invalid node type" { - t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv) + if rs, ok := recv.(string); !ok || rs != "invalid node type" { + t.Fatalf(`Expected panic "invalid node type", got "%v"`, recv) } }