diff --git a/debug.go b/debug.go index a0b99f43..e5b39cc2 100644 --- a/debug.go +++ b/debug.go @@ -16,7 +16,7 @@ func IsDebugging() bool { func debugPrintRoute(httpMethod, absolutePath string, handlers HandlersChain) { if IsDebugging() { nuHandlers := len(handlers) - handlerName := nameOfFunction(handlers[nuHandlers-1]) + handlerName := nameOfFunction(handlers.Last()) debugPrint("%-5s %-25s --> %s (%d handlers)\n", httpMethod, absolutePath, handlerName, nuHandlers) } } diff --git a/gin.go b/gin.go index 8130b88f..580f853a 100644 --- a/gin.go +++ b/gin.go @@ -61,8 +61,22 @@ type ( HandleMethodNotAllowed bool ForwardedByClientIP bool } + + RouteInfo struct { + Method string + Path string + Handler string + } ) +func (c HandlersChain) Last() HandlerFunc { + length := len(c) + if length > 0 { + return c[length-1] + } + return nil +} + // Returns a new blank Engine instance without any middleware attached. // The most basic configuration func New() *Engine { @@ -181,6 +195,28 @@ func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { root.addRoute(path, handlers) } +func (engine *Engine) Routes() (routes []RouteInfo) { + for _, tree := range engine.trees { + routes = iterate("", tree.method, routes, tree.root) + } + return routes +} + +func iterate(path, method string, routes []RouteInfo, root *node) []RouteInfo { + path += root.path + if len(root.handlers) > 0 { + routes = append(routes, RouteInfo{ + Method: method, + Path: path, + Handler: nameOfFunction(root.handlers.Last()), + }) + } + for _, node := range root.children { + routes = iterate(path, method, routes, node) + } + return routes +} + // The router is attached to a http.Server and starts listening and serving HTTP requests. // It is a shortcut for http.ListenAndServe(addr, router) // Note: this method will block the calling goroutine undefinitelly unless an error happens. diff --git a/gin_test.go b/gin_test.go index 28bba734..9d8fffa1 100644 --- a/gin_test.go +++ b/gin_test.go @@ -14,7 +14,6 @@ import ( //TODO // func (engine *Engine) LoadHTMLGlob(pattern string) { // func (engine *Engine) LoadHTMLFiles(files ...string) { -// func (engine *Engine) Run(addr string) error { // func (engine *Engine) RunTLS(addr string, cert string, key string) error { func init() { @@ -180,3 +179,50 @@ func compareFunc(t *testing.T, a, b interface{}) { t.Error("different functions") } } + +func TestListOfRoutes(t *testing.T) { + handler := func(c *Context) {} + router := New() + router.GET("/favicon.ico", handler) + router.GET("/", handler) + group := router.Group("/users") + { + group.GET("/", handler) + group.GET("/:id", handler) + group.POST("/:id", handler) + } + router.Static("/static", ".") + + list := router.Routes() + + assert.Len(t, list, 7) + assert.Contains(t, list, RouteInfo{ + Method: "GET", + Path: "/favicon.ico", + }) + assert.Contains(t, list, RouteInfo{ + Method: "GET", + Path: "/", + }) + assert.Contains(t, list, RouteInfo{ + Method: "GET", + Path: "/users/", + }) + assert.Contains(t, list, RouteInfo{ + Method: "GET", + Path: "/users/:id", + }) + assert.Contains(t, list, RouteInfo{ + Method: "POST", + Path: "/users/:id", + }) + assert.Contains(t, list, RouteInfo{ + Method: "GET", + Path: "/static/*filepath", + }) + assert.Contains(t, list, RouteInfo{ + Method: "HEAD", + Path: "/static/*filepath", + }) + +}