diff --git a/middleware_test.go b/middleware_test.go index 4ae367a9..a0fac56d 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -9,6 +9,7 @@ import ( "testing" + "github.com/manucorporat/sse" "github.com/stretchr/testify/assert" ) @@ -27,10 +28,10 @@ func TestMiddlewareGeneralCase(t *testing.T) { signature += "D" }) router.NoRoute(func(c *Context) { - signature += "X" + signature += " X " }) router.NoMethod(func(c *Context) { - signature += "X" + signature += " XX " }) // RUN w := performRequest(router, "GET", "/") @@ -40,8 +41,7 @@ func TestMiddlewareGeneralCase(t *testing.T) { assert.Equal(t, signature, "ACDB") } -// TestBadAbortHandlersChain - ensure that Abort after switch context will not interrupt pending handlers -func TestMiddlewareNextOrder(t *testing.T) { +func TestMiddlewareNoRoute(t *testing.T) { signature := "" router := New() router.Use(func(c *Context) { @@ -52,6 +52,9 @@ func TestMiddlewareNextOrder(t *testing.T) { router.Use(func(c *Context) { signature += "C" c.Next() + c.Next() + c.Next() + c.Next() signature += "D" }) router.NoRoute(func(c *Context) { @@ -63,6 +66,9 @@ func TestMiddlewareNextOrder(t *testing.T) { c.Next() signature += "H" }) + router.NoMethod(func(c *Context) { + signature += " X " + }) // RUN w := performRequest(router, "GET", "/") @@ -71,8 +77,43 @@ func TestMiddlewareNextOrder(t *testing.T) { assert.Equal(t, signature, "ACEGHFDB") } -// TestAbortHandlersChain - ensure that Abort interrupt used middlewares in fifo order -func TestMiddlewareAbortHandlersChain(t *testing.T) { +func TestMiddlewareNoMethod(t *testing.T) { + signature := "" + router := New() + router.Use(func(c *Context) { + signature += "A" + c.Next() + signature += "B" + }) + router.Use(func(c *Context) { + signature += "C" + c.Next() + signature += "D" + }) + router.NoMethod(func(c *Context) { + signature += "E" + c.Next() + signature += "F" + }, func(c *Context) { + signature += "G" + c.Next() + signature += "H" + }) + router.NoRoute(func(c *Context) { + signature += " X " + }) + router.POST("/", func(c *Context) { + signature += " XX " + }) + // RUN + w := performRequest(router, "GET", "/") + + // TEST + assert.Equal(t, w.Code, 405) + assert.Equal(t, signature, "ACEGHFDB") +} + +func TestMiddlewareAbort(t *testing.T) { signature := "" router := New() router.Use(func(c *Context) { @@ -80,21 +121,21 @@ func TestMiddlewareAbortHandlersChain(t *testing.T) { }) router.Use(func(c *Context) { signature += "C" - c.AbortWithStatus(409) + c.AbortWithStatus(401) c.Next() signature += "D" }) router.GET("/", func(c *Context) { - signature += "D" + signature += " X " c.Next() - signature += "E" + signature += " XX " }) // RUN w := performRequest(router, "GET", "/") // TEST - assert.Equal(t, w.Code, 409) + assert.Equal(t, w.Code, 401) assert.Equal(t, signature, "ACD") } @@ -103,8 +144,8 @@ func TestMiddlewareAbortHandlersChainAndNext(t *testing.T) { router := New() router.Use(func(c *Context) { signature += "A" - c.AbortWithStatus(410) c.Next() + c.AbortWithStatus(410) signature += "B" }) @@ -117,7 +158,7 @@ func TestMiddlewareAbortHandlersChainAndNext(t *testing.T) { // TEST assert.Equal(t, w.Code, 410) - assert.Equal(t, signature, "AB") + assert.Equal(t, signature, "ACB") } // TestFailHandlersChain - ensure that Fail interrupt used middlewares in fifo order as @@ -142,3 +183,35 @@ func TestMiddlewareFailHandlersChain(t *testing.T) { assert.Equal(t, w.Code, 500) assert.Equal(t, signature, "A") } + +func TestMiddlewareWrite(t *testing.T) { + router := New() + router.Use(func(c *Context) { + c.String(400, "hola\n") + }) + router.Use(func(c *Context) { + c.XML(400, H{"foo": "bar"}) + }) + router.Use(func(c *Context) { + c.JSON(400, H{"foo": "bar"}) + }) + router.GET("/", func(c *Context) { + c.JSON(400, H{"foo": "bar"}) + }, func(c *Context) { + c.Render(400, sse.Event{ + Event: "test", + Data: "message", + }) + }) + + w := performRequest(router, "GET", "/") + + assert.Equal(t, w.Code, 400) + assert.Equal(t, w.Body.String(), `hola +bar{"foo":"bar"} +{"foo":"bar"} +event: test +data: message + +`) +} diff --git a/routergroup.go b/routergroup.go index fe5a2fea..7d0a00fc 100644 --- a/routergroup.go +++ b/routergroup.go @@ -119,11 +119,14 @@ func (group *RouterGroup) StaticFile(relativePath, filepath string) { // use : // router.Static("/static", "/var/www") func (group *RouterGroup) Static(relativePath, root string) { - group.StaticFS(relativePath, http.Dir(root)) + group.StaticFS(relativePath, http.Dir(root), false) } -func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) { - handler := group.createStaticHandler(relativePath, fs) +func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem, listDirectory bool) { + if strings.Contains(relativePath, ":") || strings.Contains(relativePath, "*") { + panic("URL parameters can not be used when serving a static folder") + } + handler := group.createStaticHandler(relativePath, fs, listDirectory) relativePath = path.Join(relativePath, "/*filepath") // Register GET and HEAD handlers @@ -131,10 +134,16 @@ func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) { group.HEAD(relativePath, handler) } -func (group *RouterGroup) createStaticHandler(relativePath string, fs http.FileSystem) func(*Context) { +func (group *RouterGroup) createStaticHandler(relativePath string, fs http.FileSystem, listDirectory bool) HandlerFunc { absolutePath := group.calculateAbsolutePath(relativePath) fileServer := http.StripPrefix(absolutePath, http.FileServer(fs)) - return WrapH(fileServer) + return func(c *Context) { + if !listDirectory && lastChar(c.Request.URL.Path) == '/' { + http.NotFound(c.Writer, c.Request) + return + } + fileServer.ServeHTTP(c.Writer, c.Request) + } } func (group *RouterGroup) combineHandlers(handlers HandlersChain) HandlersChain { diff --git a/routergroup_test.go b/routergroup_test.go index 1ec70ddb..25ddda69 100644 --- a/routergroup_test.go +++ b/routergroup_test.go @@ -88,6 +88,17 @@ func performRequestInGroup(t *testing.T, method string) { assert.Equal(t, w.Body.String(), "the method was "+method+" and index 1") } +func TestRouterGroupInvalidStatic(t *testing.T) { + router := New() + assert.Panics(t, func() { + router.Static("/path/:param", "/") + }) + + assert.Panics(t, func() { + router.Static("/path/*param", "/") + }) +} + func TestRouterGroupInvalidStaticFile(t *testing.T) { router := New() assert.Panics(t, func() { diff --git a/routes_test.go b/routes_test.go index af7e95d3..e547e181 100644 --- a/routes_test.go +++ b/routes_test.go @@ -78,32 +78,41 @@ func testRouteNotOK2(method string, t *testing.T) { } func TestRouterGroupRouteOK(t *testing.T) { + testRouteOK("GET", t) testRouteOK("POST", t) - testRouteOK("DELETE", t) - testRouteOK("PATCH", t) testRouteOK("PUT", t) - testRouteOK("OPTIONS", t) + testRouteOK("PATCH", t) testRouteOK("HEAD", t) + testRouteOK("OPTIONS", t) + testRouteOK("DELETE", t) + testRouteOK("CONNECT", t) + testRouteOK("TRACE", t) } // TestSingleRouteOK tests that POST route is correctly invoked. func TestRouteNotOK(t *testing.T) { + testRouteNotOK("GET", t) testRouteNotOK("POST", t) - testRouteNotOK("DELETE", t) - testRouteNotOK("PATCH", t) testRouteNotOK("PUT", t) - testRouteNotOK("OPTIONS", t) + testRouteNotOK("PATCH", t) testRouteNotOK("HEAD", t) + testRouteNotOK("OPTIONS", t) + testRouteNotOK("DELETE", t) + testRouteNotOK("CONNECT", t) + testRouteNotOK("TRACE", t) } // TestSingleRouteOK tests that POST route is correctly invoked. func TestRouteNotOK2(t *testing.T) { + testRouteNotOK2("GET", t) testRouteNotOK2("POST", t) - testRouteNotOK2("DELETE", t) - testRouteNotOK2("PATCH", t) testRouteNotOK2("PUT", t) - testRouteNotOK2("OPTIONS", t) + testRouteNotOK2("PATCH", t) testRouteNotOK2("HEAD", t) + testRouteNotOK2("OPTIONS", t) + testRouteNotOK2("DELETE", t) + testRouteNotOK2("CONNECT", t) + testRouteNotOK2("TRACE", t) } // TestContextParamsGet tests that a parameter can be parsed from the URL. @@ -142,25 +151,35 @@ func TestRouteStaticFile(t *testing.T) { t.Error(err) } defer os.Remove(f.Name()) - filePath := path.Join("/", path.Base(f.Name())) f.WriteString("Gin Web Framework") f.Close() + dir, filename := path.Split(f.Name()) + // SETUP gin router := New() - router.Static("./", testRoot) + router.Static("/using_static", dir) + router.StaticFile("/result", f.Name()) - w := performRequest(router, "GET", filePath) + w := performRequest(router, "GET", "/using_static/"+filename) + w2 := performRequest(router, "GET", "/result") + assert.Equal(t, w, w2) assert.Equal(t, w.Code, 200) assert.Equal(t, w.Body.String(), "Gin Web Framework") assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8") + + w3 := performRequest(router, "HEAD", "/using_static/"+filename) + w4 := performRequest(router, "HEAD", "/result") + + assert.Equal(t, w3, w4) + assert.Equal(t, w3.Code, 200) } // TestHandleStaticDir - ensure the root/sub dir handles properly -func TestRouteStaticDir(t *testing.T) { +func TestRouteStaticListingDir(t *testing.T) { router := New() - router.Static("/", "./") + router.StaticFS("/", http.Dir("./"), true) w := performRequest(router, "GET", "/") @@ -170,15 +189,14 @@ func TestRouteStaticDir(t *testing.T) { } // TestHandleHeadToDir - ensure the root/sub dir handles properly -func TestRouteHeadToDir(t *testing.T) { +func TestRouteStaticNoListing(t *testing.T) { router := New() router.Static("/", "./") - w := performRequest(router, "HEAD", "/") + w := performRequest(router, "GET", "/") - assert.Equal(t, w.Code, 200) - assert.Contains(t, w.Body.String(), "gin.go") - assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8") + assert.Equal(t, w.Code, 404) + assert.NotContains(t, w.Body.String(), "gin.go") } func TestRouterMiddlewareAndStatic(t *testing.T) { @@ -190,11 +208,11 @@ func TestRouterMiddlewareAndStatic(t *testing.T) { }) static.Static("/", "./") - w := performRequest(router, "GET", "/") + w := performRequest(router, "GET", "/gin.go") assert.Equal(t, w.Code, 200) - assert.Contains(t, w.Body.String(), "gin.go") - assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8") + assert.Contains(t, w.Body.String(), "package gin") + assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8") assert.NotEqual(t, w.HeaderMap.Get("Last-Modified"), "Mon, 02 Jan 2006 15:04:05 MST") assert.Equal(t, w.HeaderMap.Get("Expires"), "Mon, 02 Jan 2006 15:04:05 MST") assert.Equal(t, w.HeaderMap.Get("x-GIN"), "Gin Framework")