diff --git a/AUTHORS.md b/AUTHORS.md index a477611b..c634e6be 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -190,6 +190,8 @@ People and companies, who have contributed, in alphabetical order. **@rogierlommers (Rogier Lommers)** - Add updated static serve example +**@rw-access (Ross Wolf)** +- Added support to mix exact and param routes **@se77en (Damon Zhao)** - Improve color logging diff --git a/CHANGELOG.md b/CHANGELOG.md index ddf30e18..dc2c2f55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,44 @@ # Gin ChangeLog +## Gin v1.7.1 + +### BUGFIXES + +* fix: data race with trustedCIDRs from [#2674](https://github.com/gin-gonic/gin/issues/2674)([#2675](https://github.com/gin-gonic/gin/pull/2675)) + +## Gin v1.7.0 + +### BUGFIXES + +* fix compile error from [#2572](https://github.com/gin-gonic/gin/pull/2572) ([#2600](https://github.com/gin-gonic/gin/pull/2600)) +* fix: print headers without Authorization header on broken pipe ([#2528](https://github.com/gin-gonic/gin/pull/2528)) +* fix(tree): reassign fullpath when register new node ([#2366](https://github.com/gin-gonic/gin/pull/2366)) + +### ENHANCEMENTS + +* Support params and exact routes without creating conflicts ([#2663](https://github.com/gin-gonic/gin/pull/2663)) +* chore: improve render string performance ([#2365](https://github.com/gin-gonic/gin/pull/2365)) +* Sync route tree to httprouter latest code ([#2368](https://github.com/gin-gonic/gin/pull/2368)) +* chore: rename getQueryCache/getFormCache to initQueryCache/initFormCa ([#2375](https://github.com/gin-gonic/gin/pull/2375)) +* chore(performance): improve countParams ([#2378](https://github.com/gin-gonic/gin/pull/2378)) +* Remove some functions that have the same effect as the bytes package ([#2387](https://github.com/gin-gonic/gin/pull/2387)) +* update:SetMode function ([#2321](https://github.com/gin-gonic/gin/pull/2321)) +* remove a unused type SecureJSONPrefix ([#2391](https://github.com/gin-gonic/gin/pull/2391)) +* Add a redirect sample for POST method ([#2389](https://github.com/gin-gonic/gin/pull/2389)) +* Add CustomRecovery builtin middleware ([#2322](https://github.com/gin-gonic/gin/pull/2322)) +* binding: avoid 2038 problem on 32-bit architectures ([#2450](https://github.com/gin-gonic/gin/pull/2450)) +* Prevent panic in Context.GetQuery() when there is no Request ([#2412](https://github.com/gin-gonic/gin/pull/2412)) +* Add GetUint and GetUint64 method on gin.context ([#2487](https://github.com/gin-gonic/gin/pull/2487)) +* update content-disposition header to MIME-style ([#2512](https://github.com/gin-gonic/gin/pull/2512)) +* reduce allocs and improve the render `WriteString` ([#2508](https://github.com/gin-gonic/gin/pull/2508)) +* implement ".Unwrap() error" on Error type ([#2525](https://github.com/gin-gonic/gin/pull/2525)) ([#2526](https://github.com/gin-gonic/gin/pull/2526)) +* Allow bind with a map[string]string ([#2484](https://github.com/gin-gonic/gin/pull/2484)) +* chore: update tree ([#2371](https://github.com/gin-gonic/gin/pull/2371)) +* Support binding for slice/array obj [Rewrite] ([#2302](https://github.com/gin-gonic/gin/pull/2302)) +* basic auth: fix timing oracle ([#2609](https://github.com/gin-gonic/gin/pull/2609)) +* Add mixed param and non-param paths (port of httprouter[#329](https://github.com/gin-gonic/gin/pull/329)) ([#2663](https://github.com/gin-gonic/gin/pull/2663)) +* feat(engine): add trustedproxies and remoteIP ([#2632](https://github.com/gin-gonic/gin/pull/2632)) + ## Gin v1.6.3 ### ENHANCEMENTS diff --git a/README.md b/README.md index 119f9452..d4772d76 100644 --- a/README.md +++ b/README.md @@ -243,6 +243,13 @@ func main() { c.FullPath() == "/user/:name/*action" // true }) + // This handler will add a new router for /user/groups. + // Exact routes are resolved before param routes, regardless of the order they were defined. + // Routes starting with /user/groups are never interpreted as /user/:name/... routes + router.GET("/user/groups", func(c *gin.Context) { + c.String(http.StatusOK, "The available groups are [...]", name) + }) + router.Run(":8080") } ``` @@ -2117,6 +2124,39 @@ func main() { } ``` +## Don't trust all proxies + +Gin lets you specify which headers to hold the real client IP (if any), +as well as specifying which proxies (or direct clients) you trust to +specify one of these headers. + +The `TrustedProxies` slice on your `gin.Engine` specifes network addresses or +network CIDRs from where clients which their request headers related to client +IP can be trusted. They can be IPv4 addresses, IPv4 CIDRs, IPv6 addresses or +IPv6 CIDRs. + +```go +import ( + "fmt" + + "github.com/gin-gonic/gin" +) + +func main() { + + router := gin.Default() + router.TrustedProxies = []string{"192.168.1.2"} + + router.GET("/", func(c *gin.Context) { + // If the client is 192.168.1.2, use the X-Forwarded-For + // header to deduce the original client IP from the trust- + // worthy parts of that header. + // Otherwise, simply return the direct client IP + fmt.Printf("ClientIP: %s\n", c.ClientIP()) + }) + router.Run() +} +``` ## Testing diff --git a/context.go b/context.go index 598dda80..e1b557bc 100644 --- a/context.go +++ b/context.go @@ -730,32 +730,80 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e return bb.BindBody(body, obj) } -// ClientIP implements a best effort algorithm to return the real client IP, it parses -// X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. -// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. +// ClientIP implements a best effort algorithm to return the real client IP. +// It called c.RemoteIP() under the hood, to check if the remote IP is a trusted proxy or not. +// If it's it will then try to parse the headers defined in Engine.RemoteIPHeaders (defaulting to [X-Forwarded-For, X-Real-Ip]). +// If the headers are nots syntactically valid OR the remote IP does not correspong to a trusted proxy, +// the remote IP (coming form Request.RemoteAddr) is returned. func (c *Context) ClientIP() string { - if c.engine.ForwardedByClientIP { - clientIP := c.requestHeader("X-Forwarded-For") - clientIP = strings.TrimSpace(strings.Split(clientIP, ",")[0]) - if clientIP == "" { - clientIP = strings.TrimSpace(c.requestHeader("X-Real-Ip")) - } - if clientIP != "" { - return clientIP - } - } - if c.engine.AppEngine { if addr := c.requestHeader("X-Appengine-Remote-Addr"); addr != "" { return addr } } - if ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)); err == nil { - return ip + remoteIP, trusted := c.RemoteIP() + if remoteIP == nil { + return "" } - return "" + if trusted && c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil { + for _, headerName := range c.engine.RemoteIPHeaders { + ip, valid := validateHeader(c.requestHeader(headerName)) + if valid { + return ip + } + } + } + return remoteIP.String() +} + +// RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port). +// It also checks if the remoteIP is a trusted proxy or not. +// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks +// defined in Engine.TrustedProxies +func (c *Context) RemoteIP() (net.IP, bool) { + ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) + if err != nil { + return nil, false + } + remoteIP := net.ParseIP(ip) + if remoteIP == nil { + return nil, false + } + + if c.engine.trustedCIDRs != nil { + for _, cidr := range c.engine.trustedCIDRs { + if cidr.Contains(remoteIP) { + return remoteIP, true + } + } + } + + return remoteIP, false +} + +func validateHeader(header string) (clientIP string, valid bool) { + if header == "" { + return "", false + } + items := strings.Split(header, ",") + for i, ipStr := range items { + ipStr = strings.TrimSpace(ipStr) + ip := net.ParseIP(ipStr) + if ip == nil { + return "", false + } + + // We need to return the first IP in the list, but, + // we should not early return since we need to validate that + // the rest of the header is syntactically valid + if i == 0 { + clientIP = ipStr + valid = true + } + } + return } // ContentType returns the Content-Type header of the request. diff --git a/context_test.go b/context_test.go index 8e1e3b57..cf3f0be9 100644 --- a/context_test.go +++ b/context_test.go @@ -1388,15 +1388,18 @@ func TestContextAbortWithError(t *testing.T) { assert.True(t, c.IsAborted()) } +func resetTrustedCIDRs(c *Context) { + c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs() +} + func TestContextClientIP(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Request, _ = http.NewRequest("POST", "/", nil) + resetTrustedCIDRs(c) + resetContextForClientIPTests(c) - c.Request.Header.Set("X-Real-IP", " 10.10.10.10 ") - c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20, 30.30.30.30") - c.Request.Header.Set("X-Appengine-Remote-Addr", "50.50.50.50") - c.Request.RemoteAddr = " 40.40.40.40:42123 " - + // Legacy tests (validating that the defaults don't break the + // (insecure!) old behaviour) assert.Equal(t, "20.20.20.20", c.ClientIP()) c.Request.Header.Del("X-Forwarded-For") @@ -1416,6 +1419,84 @@ func TestContextClientIP(t *testing.T) { // no port c.Request.RemoteAddr = "50.50.50.50" assert.Empty(t, c.ClientIP()) + + // Tests exercising the TrustedProxies functionality + resetContextForClientIPTests(c) + + // No trusted proxies + c.engine.TrustedProxies = []string{} + resetTrustedCIDRs(c) + c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"} + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + // Last proxy is trusted, but the RemoteAddr is not + c.engine.TrustedProxies = []string{"30.30.30.30"} + resetTrustedCIDRs(c) + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + // Only trust RemoteAddr + c.engine.TrustedProxies = []string{"40.40.40.40"} + resetTrustedCIDRs(c) + assert.Equal(t, "20.20.20.20", c.ClientIP()) + + // All steps are trusted + c.engine.TrustedProxies = []string{"40.40.40.40", "30.30.30.30", "20.20.20.20"} + resetTrustedCIDRs(c) + assert.Equal(t, "20.20.20.20", c.ClientIP()) + + // Use CIDR + c.engine.TrustedProxies = []string{"40.40.25.25/16", "30.30.30.30"} + resetTrustedCIDRs(c) + assert.Equal(t, "20.20.20.20", c.ClientIP()) + + // Use hostname that resolves to all the proxies + c.engine.TrustedProxies = []string{"foo"} + resetTrustedCIDRs(c) + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + // Use hostname that returns an error + c.engine.TrustedProxies = []string{"bar"} + resetTrustedCIDRs(c) + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + // X-Forwarded-For has a non-IP element + c.engine.TrustedProxies = []string{"40.40.40.40"} + resetTrustedCIDRs(c) + c.Request.Header.Set("X-Forwarded-For", " blah ") + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + // Result from LookupHost has non-IP element. This should never + // happen, but we should test it to make sure we handle it + // gracefully. + c.engine.TrustedProxies = []string{"baz"} + resetTrustedCIDRs(c) + c.Request.Header.Set("X-Forwarded-For", " 30.30.30.30 ") + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + c.engine.TrustedProxies = []string{"40.40.40.40"} + resetTrustedCIDRs(c) + c.Request.Header.Del("X-Forwarded-For") + c.engine.RemoteIPHeaders = []string{"X-Forwarded-For", "X-Real-IP"} + assert.Equal(t, "10.10.10.10", c.ClientIP()) + + c.engine.RemoteIPHeaders = []string{} + c.engine.AppEngine = true + assert.Equal(t, "50.50.50.50", c.ClientIP()) + + c.Request.Header.Del("X-Appengine-Remote-Addr") + assert.Equal(t, "40.40.40.40", c.ClientIP()) + + // no port + c.Request.RemoteAddr = "50.50.50.50" + assert.Empty(t, c.ClientIP()) +} + +func resetContextForClientIPTests(c *Context) { + c.Request.Header.Set("X-Real-IP", " 10.10.10.10 ") + c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20, 30.30.30.30") + c.Request.Header.Set("X-Appengine-Remote-Addr", "50.50.50.50") + c.Request.RemoteAddr = " 40.40.40.40:42123 " + c.engine.AppEngine = false } func TestContextContentType(t *testing.T) { @@ -1960,3 +2041,12 @@ func TestContextWithKeysMutex(t *testing.T) { assert.Nil(t, value) assert.False(t, err) } + +func TestRemoteIPFail(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", nil) + c.Request.RemoteAddr = "[:::]:80" + ip, trust := c.RemoteIP() + assert.Nil(t, ip) + assert.False(t, trust) +} diff --git a/gin.go b/gin.go index 1e126179..03a0e127 100644 --- a/gin.go +++ b/gin.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "path" + "strings" "sync" "github.com/gin-gonic/gin/internal/bytesconv" @@ -81,9 +82,26 @@ type Engine struct { // If no other Method is allowed, the request is delegated to the NotFound // handler. HandleMethodNotAllowed bool - ForwardedByClientIP bool - // #726 #755 If enabled, it will thrust some headers starting with + // If enabled, client IP will be parsed from the request's headers that + // match those stored at `(*gin.Engine).RemoteIPHeaders`. If no IP was + // fetched, it falls back to the IP obtained from + // `(*gin.Context).Request.RemoteAddr`. + ForwardedByClientIP bool + + // List of headers used to obtain the client IP when + // `(*gin.Engine).ForwardedByClientIP` is `true` and + // `(*gin.Context).Request.RemoteAddr` is matched by at least one of the + // network origins of `(*gin.Engine).TrustedProxies`. + RemoteIPHeaders []string + + // List of network origins (IPv4 addresses, IPv4 CIDRs, IPv6 addresses or + // IPv6 CIDRs) from which to trust request's headers that contain + // alternative client IP when `(*gin.Engine).ForwardedByClientIP` is + // `true`. + TrustedProxies []string + + // #726 #755 If enabled, it will trust some headers starting with // 'X-AppEngine...' for better integration with that PaaS. AppEngine bool @@ -114,6 +132,7 @@ type Engine struct { pool sync.Pool trees methodTrees maxParams uint16 + trustedCIDRs []*net.IPNet } var _ IRouter = &Engine{} @@ -139,6 +158,8 @@ func New() *Engine { RedirectFixedPath: false, HandleMethodNotAllowed: false, ForwardedByClientIP: true, + RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, + TrustedProxies: []string{"0.0.0.0/0"}, AppEngine: defaultAppEngine, UseRawPath: false, RemoveExtraSlash: false, @@ -305,12 +326,60 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo { func (engine *Engine) Run(addr ...string) (err error) { defer func() { debugPrintError(err) }() + trustedCIDRs, err := engine.prepareTrustedCIDRs() + if err != nil { + return err + } + engine.trustedCIDRs = trustedCIDRs address := resolveAddress(addr) debugPrint("Listening and serving HTTP on %s\n", address) err = http.ListenAndServe(address, engine) return } +func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { + if engine.TrustedProxies == nil { + return nil, nil + } + + cidr := make([]*net.IPNet, 0, len(engine.TrustedProxies)) + for _, trustedProxy := range engine.TrustedProxies { + if !strings.Contains(trustedProxy, "/") { + ip := parseIP(trustedProxy) + if ip == nil { + return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy} + } + + switch len(ip) { + case net.IPv4len: + trustedProxy += "/32" + case net.IPv6len: + trustedProxy += "/128" + } + } + _, cidrNet, err := net.ParseCIDR(trustedProxy) + if err != nil { + return cidr, err + } + cidr = append(cidr, cidrNet) + } + return cidr, nil +} + +// parseIP parse a string representation of an IP and returns a net.IP with the +// minimum byte representation or nil if input is invalid. +func parseIP(ip string) net.IP { + parsedIP := net.ParseIP(ip) + + if ipv4 := parsedIP.To4(); ipv4 != nil { + // return ip in a 4-byte representation + return ipv4 + } + + // return ip in a 16-byte representation or nil + return parsedIP +} + // RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests. // It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router) // Note: this method will block the calling goroutine indefinitely unless an error happens. diff --git a/gin_integration_test.go b/gin_integration_test.go index 41ad9874..fd972657 100644 --- a/gin_integration_test.go +++ b/gin_integration_test.go @@ -55,6 +55,13 @@ func TestRunEmpty(t *testing.T) { testRequest(t, "http://localhost:8080/example") } +func TestTrustedCIDRsForRun(t *testing.T) { + os.Setenv("PORT", "") + router := New() + router.TrustedProxies = []string{"hello/world"} + assert.Error(t, router.Run(":8080")) +} + func TestRunTLS(t *testing.T) { router := New() go func() { diff --git a/gin_test.go b/gin_test.go index 11bdd79c..678d95f2 100644 --- a/gin_test.go +++ b/gin_test.go @@ -9,6 +9,7 @@ import ( "fmt" "html/template" "io/ioutil" + "net" "net/http" "net/http/httptest" "reflect" @@ -532,6 +533,139 @@ func TestEngineHandleContextManyReEntries(t *testing.T) { assert.Equal(t, int64(expectValue), middlewareCounter) } +func TestPrepareTrustedCIRDsWith(t *testing.T) { + r := New() + + // valid ipv4 cidr + { + expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")} + r.TrustedProxies = []string{"0.0.0.0/0"} + + trustedCIDRs, err := r.prepareTrustedCIDRs() + + assert.NoError(t, err) + assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + } + + // invalid ipv4 cidr + { + r.TrustedProxies = []string{"192.168.1.33/33"} + + _, err := r.prepareTrustedCIDRs() + + assert.Error(t, err) + } + + // valid ipv4 address + { + expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")} + r.TrustedProxies = []string{"192.168.1.33"} + + trustedCIDRs, err := r.prepareTrustedCIDRs() + + assert.NoError(t, err) + assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + } + + // invalid ipv4 address + { + r.TrustedProxies = []string{"192.168.1.256"} + + _, err := r.prepareTrustedCIDRs() + + assert.Error(t, err) + } + + // valid ipv6 address + { + expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")} + r.TrustedProxies = []string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"} + + trustedCIDRs, err := r.prepareTrustedCIDRs() + + assert.NoError(t, err) + assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + } + + // invalid ipv6 address + { + r.TrustedProxies = []string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"} + + _, err := r.prepareTrustedCIDRs() + + assert.Error(t, err) + } + + // valid ipv6 cidr + { + expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")} + r.TrustedProxies = []string{"::/0"} + + trustedCIDRs, err := r.prepareTrustedCIDRs() + + assert.NoError(t, err) + assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + } + + // invalid ipv6 cidr + { + r.TrustedProxies = []string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"} + + _, err := r.prepareTrustedCIDRs() + + assert.Error(t, err) + } + + // valid combination + { + expectedTrustedCIDRs := []*net.IPNet{ + parseCIDR("::/0"), + parseCIDR("192.168.0.0/16"), + parseCIDR("172.16.0.1/32"), + } + r.TrustedProxies = []string{ + "::/0", + "192.168.0.0/16", + "172.16.0.1", + } + + trustedCIDRs, err := r.prepareTrustedCIDRs() + + assert.NoError(t, err) + assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) + } + + // invalid combination + { + r.TrustedProxies = []string{ + "::/0", + "192.168.0.0/16", + "172.16.0.256", + } + _, err := r.prepareTrustedCIDRs() + + assert.Error(t, err) + } + + // nil value + { + r.TrustedProxies = nil + trustedCIDRs, err := r.prepareTrustedCIDRs() + + assert.Nil(t, trustedCIDRs) + assert.Nil(t, err) + } + +} + +func parseCIDR(cidr string) *net.IPNet { + _, parsedCIDR, err := net.ParseCIDR(cidr) + if err != nil { + fmt.Println(err) + } + return parsedCIDR +} + func assertRoutePresent(t *testing.T, gotRoutes RoutesInfo, wantRoute RouteInfo) { for _, gotRoute := range gotRoutes { if gotRoute.Path == wantRoute.Path && gotRoute.Method == wantRoute.Method { diff --git a/logger_test.go b/logger_test.go index 0d40666e..80961ce1 100644 --- a/logger_test.go +++ b/logger_test.go @@ -185,6 +185,8 @@ func TestLoggerWithConfigFormatting(t *testing.T) { buffer := new(bytes.Buffer) router := New() + router.engine.trustedCIDRs, _ = router.engine.prepareTrustedCIDRs() + router.Use(LoggerWithConfig(LoggerConfig{ Output: buffer, Formatter: func(param LogFormatterParams) string { diff --git a/tree.go b/tree.go index 74e07e84..ca753e6d 100644 --- a/tree.go +++ b/tree.go @@ -80,6 +80,16 @@ func longestCommonPrefix(a, b string) int { return i } +// addChild will add a child node, keeping wildcards at the end +func (n *node) addChild(child *node) { + if n.wildChild && len(n.children) > 0 { + wildcardChild := n.children[len(n.children)-1] + n.children = append(n.children[:len(n.children)-1], child, wildcardChild) + } else { + n.children = append(n.children, child) + } +} + func countParams(path string) uint16 { var n uint16 s := bytesconv.StringToBytes(path) @@ -103,7 +113,7 @@ type node struct { wildChild bool nType nodeType priority uint32 - children []*node + children []*node // child nodes, at most 1 :param style node at the end of the array handlers HandlersChain fullPath string } @@ -177,36 +187,9 @@ walk: // Make new node a child of this node if i < len(path) { path = path[i:] - - if n.wildChild { - parentFullPathIndex += len(n.path) - n = n.children[0] - n.priority++ - - // Check if the wildcard matches - if len(path) >= len(n.path) && n.path == path[:len(n.path)] && - // Adding a child to a catchAll is not possible - n.nType != catchAll && - // Check for longer wildcard, e.g. :name and :names - (len(n.path) >= len(path) || path[len(n.path)] == '/') { - continue walk - } - - pathSeg := path - if n.nType != catchAll { - pathSeg = strings.SplitN(path, "/", 2)[0] - } - prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path - panic("'" + pathSeg + - "' in new path '" + fullPath + - "' conflicts with existing wildcard '" + n.path + - "' in existing prefix '" + prefix + - "'") - } - c := path[0] - // slash after param + // '/' after param if n.nType == param && c == '/' && len(n.children) == 1 { parentFullPathIndex += len(n.path) n = n.children[0] @@ -225,21 +208,47 @@ walk: } // Otherwise insert it - if c != ':' && c != '*' { + if c != ':' && c != '*' && n.nType != catchAll { // []byte for proper unicode char conversion, see #65 n.indices += bytesconv.BytesToString([]byte{c}) child := &node{ fullPath: fullPath, } - n.children = append(n.children, child) + n.addChild(child) n.incrementChildPrio(len(n.indices) - 1) n = child + } else if n.wildChild { + // inserting a wildcard node, need to check if it conflicts with the existing wildcard + n = n.children[len(n.children)-1] + n.priority++ + + // Check if the wildcard matches + if len(path) >= len(n.path) && n.path == path[:len(n.path)] && + // Adding a child to a catchAll is not possible + n.nType != catchAll && + // Check for longer wildcard, e.g. :name and :names + (len(n.path) >= len(path) || path[len(n.path)] == '/') { + continue walk + } + + // Wildcard conflict + pathSeg := path + if n.nType != catchAll { + pathSeg = strings.SplitN(pathSeg, "/", 2)[0] + } + prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path + panic("'" + pathSeg + + "' in new path '" + fullPath + + "' conflicts with existing wildcard '" + n.path + + "' in existing prefix '" + prefix + + "'") } + n.insertChild(path, fullPath, handlers) return } - // Otherwise and handle to current node + // Otherwise add handle to current node if n.handlers != nil { panic("handlers are already registered for path '" + fullPath + "'") } @@ -293,13 +302,6 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") } - // Check if this node has existing children which would be - // unreachable if we insert the wildcard here - if len(n.children) > 0 { - panic("wildcard segment '" + wildcard + - "' conflicts with existing children in path '" + fullPath + "'") - } - if wildcard[0] == ':' { // param if i > 0 { // Insert prefix before the current wildcard @@ -307,13 +309,13 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) path = path[i:] } - n.wildChild = true child := &node{ nType: param, path: wildcard, fullPath: fullPath, } - n.children = []*node{child} + n.addChild(child) + n.wildChild = true n = child n.priority++ @@ -326,7 +328,7 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) priority: 1, fullPath: fullPath, } - n.children = []*node{child} + n.addChild(child) n = child continue } @@ -360,7 +362,7 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain) fullPath: fullPath, } - n.children = []*node{child} + n.addChild(child) n.indices = string('/') n = child n.priority++ @@ -404,18 +406,18 @@ walk: // Outer loop for walking the tree if len(path) > len(prefix) { if path[:len(prefix)] == prefix { path = path[len(prefix):] - // If this node does not have a wildcard (param or catchAll) - // child, we can just look up the next child node and continue - // to walk down the tree - if !n.wildChild { - idxc := path[0] - for i, c := range []byte(n.indices) { - if c == idxc { - n = n.children[i] - continue walk - } - } + // Try all the non-wildcard children first by matching the indices + idxc := path[0] + for i, c := range []byte(n.indices) { + if c == idxc { + n = n.children[i] + continue walk + } + } + + // If there is no wildcard pattern, recommend a redirection + if !n.wildChild { // Nothing found. // We can recommend to redirect to the same URL without a // trailing slash if a leaf exists for that path. @@ -423,8 +425,9 @@ walk: // Outer loop for walking the tree return } - // Handle wildcard child - n = n.children[0] + // Handle wildcard child, which is always at the end of the array + n = n.children[len(n.children)-1] + switch n.nType { case param: // Find param end (either '/' or path end) diff --git a/tree_test.go b/tree_test.go index 1cb4f559..d7c4fb0b 100644 --- a/tree_test.go +++ b/tree_test.go @@ -137,6 +137,8 @@ func TestTreeWildcard(t *testing.T) { "/", "/cmd/:tool/:sub", "/cmd/:tool/", + "/cmd/whoami", + "/cmd/whoami/root/", "/src/*filepath", "/search/", "/search/:query", @@ -155,8 +157,12 @@ func TestTreeWildcard(t *testing.T) { checkRequests(t, tree, testRequests{ {"/", false, "/", nil}, - {"/cmd/test/", false, "/cmd/:tool/", Params{Param{Key: "tool", Value: "test"}}}, - {"/cmd/test", true, "", Params{Param{Key: "tool", Value: "test"}}}, + {"/cmd/test", true, "/cmd/:tool/", Params{Param{"tool", "test"}}}, + {"/cmd/test/", false, "/cmd/:tool/", Params{Param{"tool", "test"}}}, + {"/cmd/whoami", false, "/cmd/whoami", nil}, + {"/cmd/whoami/", true, "/cmd/whoami", nil}, + {"/cmd/whoami/root/", false, "/cmd/whoami/root/", nil}, + {"/cmd/whoami/root", true, "/cmd/whoami/root/", nil}, {"/cmd/test/3", false, "/cmd/:tool/:sub", Params{Param{Key: "tool", Value: "test"}, Param{Key: "sub", Value: "3"}}}, {"/src/", false, "/src/*filepath", Params{Param{Key: "filepath", Value: "/"}}}, {"/src/some/file.png", false, "/src/*filepath", Params{Param{Key: "filepath", Value: "/some/file.png"}}}, @@ -245,20 +251,38 @@ func testRoutes(t *testing.T, routes []testRoute) { func TestTreeWildcardConflict(t *testing.T) { routes := []testRoute{ {"/cmd/:tool/:sub", false}, - {"/cmd/vet", true}, + {"/cmd/vet", false}, + {"/foo/bar", false}, + {"/foo/:name", false}, + {"/foo/:names", true}, + {"/cmd/*path", true}, + {"/cmd/:badvar", true}, + {"/cmd/:tool/names", false}, + {"/cmd/:tool/:badsub/details", true}, {"/src/*filepath", false}, + {"/src/:file", true}, + {"/src/static.json", true}, {"/src/*filepathx", true}, {"/src/", true}, + {"/src/foo/bar", true}, {"/src1/", false}, {"/src1/*filepath", true}, {"/src2*filepath", true}, + {"/src2/*filepath", false}, {"/search/:query", false}, - {"/search/invalid", true}, + {"/search/valid", false}, {"/user_:name", false}, - {"/user_x", true}, + {"/user_x", false}, {"/user_:name", false}, {"/id:id", false}, - {"/id/:id", true}, + {"/id/:id", false}, + } + testRoutes(t, routes) +} + +func TestCatchAllAfterSlash(t *testing.T) { + routes := []testRoute{ + {"/non-leading-*catchall", true}, } testRoutes(t, routes) } @@ -266,14 +290,17 @@ func TestTreeWildcardConflict(t *testing.T) { func TestTreeChildConflict(t *testing.T) { routes := []testRoute{ {"/cmd/vet", false}, - {"/cmd/:tool/:sub", true}, + {"/cmd/:tool", false}, + {"/cmd/:tool/:sub", false}, + {"/cmd/:tool/misc", false}, + {"/cmd/:tool/:othersub", true}, {"/src/AUTHORS", false}, {"/src/*filepath", true}, {"/user_x", false}, - {"/user_:name", true}, + {"/user_:name", false}, {"/id/:id", false}, - {"/id:id", true}, - {"/:id", true}, + {"/id:id", false}, + {"/:id", false}, {"/*filepath", true}, } testRoutes(t, routes) @@ -688,8 +715,7 @@ func TestTreeWildcardConflictEx(t *testing.T) { {"/who/are/foo", "/foo", `/who/are/\*you`, `/\*you`}, {"/who/are/foo/", "/foo/", `/who/are/\*you`, `/\*you`}, {"/who/are/foo/bar", "/foo/bar", `/who/are/\*you`, `/\*you`}, - {"/conxxx", "xxx", `/con:tact`, `:tact`}, - {"/conooo/xxx", "ooo", `/con:tact`, `:tact`}, + {"/con:nection", ":nection", `/con:tact`, `:tact`}, } for _, conflict := range conflicts { diff --git a/version.go b/version.go index 3e9687dc..3647461b 100644 --- a/version.go +++ b/version.go @@ -5,4 +5,4 @@ package gin // Version is the current gin framework's version. -const Version = "v1.6.3" +const Version = "v1.7.1"