diff --git a/.travis.yml b/.travis.yml
index d7086b38..81662315 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -3,13 +3,9 @@ language: go
matrix:
fast_finish: true
include:
- - go: 1.11.x
- env: GO111MODULE=on
- - go: 1.12.x
- env: GO111MODULE=on
- go: 1.13.x
- go: 1.13.x
- env:
+ env:
- TESTTAGS=nomsgpack
- go: 1.14.x
- go: 1.14.x
@@ -19,6 +15,10 @@ matrix:
- go: 1.15.x
env:
- TESTTAGS=nomsgpack
+ - go: 1.16.x
+ - go: 1.16.x
+ env:
+ - TESTTAGS=nomsgpack
- go: master
git:
diff --git a/AUTHORS.md b/AUTHORS.md
index dda19bcf..c634e6be 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -156,7 +156,7 @@ People and companies, who have contributed, in alphabetical order.
- Fix variadic parameter in the flexible render API
- Fix Corrupted plain render
- Add Pluggable View Renderer Example
-
+
**@msemenistyi (Mykyta Semenistyi)**
- update Readme.md. Add code to String method
@@ -190,6 +190,8 @@ People and companies, who have contributed, in alphabetical order.
**@rogierlommers (Rogier Lommers)**
- Add updated static serve example
+**@rw-access (Ross Wolf)**
+- Added support to mix exact and param routes
**@se77en (Damon Zhao)**
- Improve color logging
diff --git a/BENCHMARKS.md b/BENCHMARKS.md
index 0f59b509..c11ee99a 100644
--- a/BENCHMARKS.md
+++ b/BENCHMARKS.md
@@ -1,11 +1,11 @@
# Benchmark System
-**VM HOST:** Travis
-**Machine:** Ubuntu 16.04.6 LTS x64
-**Date:** May 04th, 2020
+**VM HOST:** Travis
+**Machine:** Ubuntu 16.04.6 LTS x64
+**Date:** May 04th, 2020
**Version:** Gin v1.6.3
-**Go Version:** 1.14.2 linux/amd64
+**Go Version:** 1.14.2 linux/amd64
**Source:** [Go HTTP Router Benchmark](https://github.com/gin-gonic/go-http-routing-benchmark)
**Result:** [See the gist](https://gist.github.com/appleboy/b5f2ecfaf50824ae9c64dcfb9165ae5e) or [Travis result](https://travis-ci.org/github/gin-gonic/go-http-routing-benchmark/jobs/682947061)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3ac51ad3..dc2c2f55 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,44 @@
# Gin ChangeLog
+## Gin v1.7.1
+
+### BUGFIXES
+
+* fix: data race with trustedCIDRs from [#2674](https://github.com/gin-gonic/gin/issues/2674)([#2675](https://github.com/gin-gonic/gin/pull/2675))
+
+## Gin v1.7.0
+
+### BUGFIXES
+
+* fix compile error from [#2572](https://github.com/gin-gonic/gin/pull/2572) ([#2600](https://github.com/gin-gonic/gin/pull/2600))
+* fix: print headers without Authorization header on broken pipe ([#2528](https://github.com/gin-gonic/gin/pull/2528))
+* fix(tree): reassign fullpath when register new node ([#2366](https://github.com/gin-gonic/gin/pull/2366))
+
+### ENHANCEMENTS
+
+* Support params and exact routes without creating conflicts ([#2663](https://github.com/gin-gonic/gin/pull/2663))
+* chore: improve render string performance ([#2365](https://github.com/gin-gonic/gin/pull/2365))
+* Sync route tree to httprouter latest code ([#2368](https://github.com/gin-gonic/gin/pull/2368))
+* chore: rename getQueryCache/getFormCache to initQueryCache/initFormCa ([#2375](https://github.com/gin-gonic/gin/pull/2375))
+* chore(performance): improve countParams ([#2378](https://github.com/gin-gonic/gin/pull/2378))
+* Remove some functions that have the same effect as the bytes package ([#2387](https://github.com/gin-gonic/gin/pull/2387))
+* update:SetMode function ([#2321](https://github.com/gin-gonic/gin/pull/2321))
+* remove a unused type SecureJSONPrefix ([#2391](https://github.com/gin-gonic/gin/pull/2391))
+* Add a redirect sample for POST method ([#2389](https://github.com/gin-gonic/gin/pull/2389))
+* Add CustomRecovery builtin middleware ([#2322](https://github.com/gin-gonic/gin/pull/2322))
+* binding: avoid 2038 problem on 32-bit architectures ([#2450](https://github.com/gin-gonic/gin/pull/2450))
+* Prevent panic in Context.GetQuery() when there is no Request ([#2412](https://github.com/gin-gonic/gin/pull/2412))
+* Add GetUint and GetUint64 method on gin.context ([#2487](https://github.com/gin-gonic/gin/pull/2487))
+* update content-disposition header to MIME-style ([#2512](https://github.com/gin-gonic/gin/pull/2512))
+* reduce allocs and improve the render `WriteString` ([#2508](https://github.com/gin-gonic/gin/pull/2508))
+* implement ".Unwrap() error" on Error type ([#2525](https://github.com/gin-gonic/gin/pull/2525)) ([#2526](https://github.com/gin-gonic/gin/pull/2526))
+* Allow bind with a map[string]string ([#2484](https://github.com/gin-gonic/gin/pull/2484))
+* chore: update tree ([#2371](https://github.com/gin-gonic/gin/pull/2371))
+* Support binding for slice/array obj [Rewrite] ([#2302](https://github.com/gin-gonic/gin/pull/2302))
+* basic auth: fix timing oracle ([#2609](https://github.com/gin-gonic/gin/pull/2609))
+* Add mixed param and non-param paths (port of httprouter[#329](https://github.com/gin-gonic/gin/pull/329)) ([#2663](https://github.com/gin-gonic/gin/pull/2663))
+* feat(engine): add trustedproxies and remoteIP ([#2632](https://github.com/gin-gonic/gin/pull/2632))
+
## Gin v1.6.3
### ENHANCEMENTS
@@ -215,12 +254,12 @@
## Gin 1.1
-- [NEW] Implement QueryArray and PostArray methods
-- [NEW] Refactor GetQuery and GetPostForm
-- [NEW] Add contribution guide
+- [NEW] Implement QueryArray and PostArray methods
+- [NEW] Refactor GetQuery and GetPostForm
+- [NEW] Add contribution guide
- [FIX] Corrected typos in README
-- [FIX] Removed additional Iota
-- [FIX] Changed imports to gopkg instead of github in README (#733)
+- [FIX] Removed additional Iota
+- [FIX] Changed imports to gopkg instead of github in README (#733)
- [FIX] Logger: skip ANSI color commands if output is not a tty
## Gin 1.0rc2 (...)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 98d758ef..97daa808 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,4 +1,4 @@
-## Contributing
+## Contributing
- With issues:
- Use the search tool before opening a new issue.
diff --git a/README.md b/README.md
index 3be8df7b..289ac959 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ Gin is a web framework written in Go (Golang). It features a martini-like API wi
- [Quick start](#quick-start)
- [Benchmarks](#benchmarks)
- [Gin v1. stable](#gin-v1-stable)
- - [Build with jsoniter](#build-with-jsoniter)
+ - [Build with jsoniter/go-json](#build-with-json-replacement)
- [API Examples](#api-examples)
- [Using GET, POST, PUT, PATCH, DELETE and OPTIONS](#using-get-post-put-patch-delete-and-options)
- [Parameters in path](#parameters-in-path)
@@ -84,7 +84,7 @@ Gin is a web framework written in Go (Golang). It features a martini-like API wi
To install Gin package, you need to install Go and set your Go workspace first.
-1. The first need [Go](https://golang.org/) installed (**version 1.11+ is required**), then you can use the below Go command to install Gin.
+1. The first need [Go](https://golang.org/) installed (**version 1.13+ is required**), then you can use the below Go command to install Gin.
```sh
$ go get -u github.com/gin-gonic/gin
@@ -103,7 +103,7 @@ import "net/http"
```
## Quick start
-
+
```sh
# assume the following codes in example.go file
$ cat example.go
@@ -182,13 +182,18 @@ Gin uses a custom version of [HttpRouter](https://github.com/julienschmidt/httpr
- [x] Battle tested.
- [x] API frozen, new releases will not break your code.
-## Build with [jsoniter](https://github.com/json-iterator/go)
+## Build with json replacement
-Gin uses `encoding/json` as default json package but you can change to [jsoniter](https://github.com/json-iterator/go) by build from other tags.
+Gin uses `encoding/json` as default json package but you can change it by build from other tags.
+[jsoniter](https://github.com/json-iterator/go)
```sh
$ go build -tags=jsoniter .
```
+[go-json](https://github.com/goccy/go-json)
+```sh
+$ go build -tags=go_json .
+```
## API Examples
@@ -243,6 +248,13 @@ func main() {
c.FullPath() == "/user/:name/*action" // true
})
+ // This handler will add a new router for /user/groups.
+ // Exact routes are resolved before param routes, regardless of the order they were defined.
+ // Routes starting with /user/groups are never interpreted as /user/:name/... routes
+ router.GET("/user/groups", func(c *gin.Context) {
+ c.String(http.StatusOK, "The available groups are [...]", name)
+ })
+
router.Run(":8080")
}
```
@@ -588,44 +600,44 @@ func main() {
::1 - [Fri, 07 Dec 2018 17:04:38 JST] "GET /ping HTTP/1.1 200 122.767µs "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.80 Safari/537.36" "
```
-### Controlling Log output coloring
+### Controlling Log output coloring
By default, logs output on console should be colorized depending on the detected TTY.
-Never colorize logs:
+Never colorize logs:
```go
func main() {
// Disable log's color
gin.DisableConsoleColor()
-
+
// Creates a gin router with default middleware:
// logger and recovery (crash-free) middleware
router := gin.Default()
-
+
router.GET("/ping", func(c *gin.Context) {
c.String(200, "pong")
})
-
+
router.Run(":8080")
}
```
-Always colorize logs:
+Always colorize logs:
```go
func main() {
// Force log's color
gin.ForceConsoleColor()
-
+
// Creates a gin router with default middleware:
// logger and recovery (crash-free) middleware
router := gin.Default()
-
+
router.GET("/ping", func(c *gin.Context) {
c.String(200, "pong")
})
-
+
router.Run(":8080")
}
```
@@ -667,12 +679,12 @@ func main() {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
-
+
if json.User != "manu" || json.Password != "123" {
c.JSON(http.StatusUnauthorized, gin.H{"status": "unauthorized"})
return
- }
-
+ }
+
c.JSON(http.StatusOK, gin.H{"status": "you are logged in"})
})
@@ -688,12 +700,12 @@ func main() {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
-
+
if xml.User != "manu" || xml.Password != "123" {
c.JSON(http.StatusUnauthorized, gin.H{"status": "unauthorized"})
return
- }
-
+ }
+
c.JSON(http.StatusOK, gin.H{"status": "you are logged in"})
})
@@ -705,12 +717,12 @@ func main() {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
-
+
if form.User != "manu" || form.Password != "123" {
c.JSON(http.StatusUnauthorized, gin.H{"status": "unauthorized"})
return
- }
-
+ }
+
c.JSON(http.StatusOK, gin.H{"status": "you are logged in"})
})
@@ -807,7 +819,7 @@ $ curl "localhost:8085/bookable?check_in=2030-03-10&check_out=2030-03-09"
{"error":"Key: 'Booking.CheckOut' Error:Field validation for 'CheckOut' failed on the 'gtfield' tag"}
$ curl "localhost:8085/bookable?check_in=2000-03-09&check_out=2000-03-10"
-{"error":"Key: 'Booking.CheckIn' Error:Field validation for 'CheckIn' failed on the 'bookabledate' tag"}%
+{"error":"Key: 'Booking.CheckIn' Error:Field validation for 'CheckIn' failed on the 'bookabledate' tag"}%
```
[Struct level validations](https://github.com/go-playground/validator/releases/tag/v8.7) can also be registered this way.
@@ -1145,7 +1157,7 @@ func main() {
data := gin.H{
"foo": "bar",
}
-
+
//callback is x
// Will output : x({\"foo\":\"bar\"})
c.JSONP(http.StatusOK, data)
@@ -1190,21 +1202,21 @@ This feature is unavailable in Go 1.6 and lower.
```go
func main() {
r := gin.Default()
-
+
// Serves unicode entities
r.GET("/json", func(c *gin.Context) {
c.JSON(200, gin.H{
"html": "Hello, world!",
})
})
-
+
// Serves literal characters
r.GET("/purejson", func(c *gin.Context) {
c.PureJSON(200, gin.H{
"html": "Hello, world!",
})
})
-
+
// listen and serve on 0.0.0.0:8080
r.Run(":8080")
}
@@ -1797,8 +1809,8 @@ func main() {
// Initializing the server in a goroutine so that
// it won't block the graceful shutdown handling below
go func() {
- if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- log.Fatalf("listen: %s\n", err)
+ if err := srv.ListenAndServe(); err != nil && errors.Is(err, http.ErrServerClosed) {
+ log.Printf("listen: %s\n", err)
}
}()
@@ -1816,10 +1828,11 @@ func main() {
// the request it is currently handling
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
+
if err := srv.Shutdown(ctx); err != nil {
log.Fatal("Server forced to shutdown:", err)
}
-
+
log.Println("Server exiting")
}
```
@@ -2120,6 +2133,39 @@ func main() {
}
```
+## Don't trust all proxies
+
+Gin lets you specify which headers to hold the real client IP (if any),
+as well as specifying which proxies (or direct clients) you trust to
+specify one of these headers.
+
+The `TrustedProxies` slice on your `gin.Engine` specifes network addresses or
+network CIDRs from where clients which their request headers related to client
+IP can be trusted. They can be IPv4 addresses, IPv4 CIDRs, IPv6 addresses or
+IPv6 CIDRs.
+
+```go
+import (
+ "fmt"
+
+ "github.com/gin-gonic/gin"
+)
+
+func main() {
+
+ router := gin.Default()
+ router.TrustedProxies = []string{"192.168.1.2"}
+
+ router.GET("/", func(c *gin.Context) {
+ // If the client is 192.168.1.2, use the X-Forwarded-For
+ // header to deduce the original client IP from the trust-
+ // worthy parts of that header.
+ // Otherwise, simply return the direct client IP
+ fmt.Printf("ClientIP: %s\n", c.ClientIP())
+ })
+ router.Run()
+}
+```
## Testing
@@ -2178,3 +2224,4 @@ Awesome project lists using [Gin](https://github.com/gin-gonic/gin) web framewor
* [picfit](https://github.com/thoas/picfit): An image resizing server written in Go.
* [brigade](https://github.com/brigadecore/brigade): Event-based Scripting for Kubernetes.
* [dkron](https://github.com/distribworks/dkron): Distributed, fault tolerant job scheduling system.
+
diff --git a/auth.go b/auth.go
index 43ad36f5..4d8a6ce4 100644
--- a/auth.go
+++ b/auth.go
@@ -5,6 +5,7 @@
package gin
import (
+ "crypto/subtle"
"encoding/base64"
"net/http"
"strconv"
@@ -30,7 +31,7 @@ func (a authPairs) searchCredential(authValue string) (string, bool) {
return "", false
}
for _, pair := range a {
- if pair.value == authValue {
+ if subtle.ConstantTimeCompare([]byte(pair.value), []byte(authValue)) == 1 {
return pair.user, true
}
}
diff --git a/binding/binding.go b/binding/binding.go
index 57562845..5caeb581 100644
--- a/binding/binding.go
+++ b/binding/binding.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build !nomsgpack
// +build !nomsgpack
package binding
@@ -51,7 +52,8 @@ type BindingUri interface {
// https://github.com/go-playground/validator/tree/v8.18.2.
type StructValidator interface {
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
- // If the received type is not a struct, any validation should be skipped and nil must be returned.
+ // If the received type is a slice|array, the validation should be performed travel on every element.
+ // If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned.
// If the received type is a struct or pointer to a struct, the validation should be performed.
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
// Otherwise nil must be returned.
diff --git a/binding/binding_msgpack_test.go b/binding/binding_msgpack_test.go
index 9791a607..04d94079 100644
--- a/binding/binding_msgpack_test.go
+++ b/binding/binding_msgpack_test.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build !nomsgpack
// +build !nomsgpack
package binding
diff --git a/binding/binding_nomsgpack.go b/binding/binding_nomsgpack.go
index fd227b11..9afa3dcf 100644
--- a/binding/binding_nomsgpack.go
+++ b/binding/binding_nomsgpack.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build nomsgpack
// +build nomsgpack
package binding
diff --git a/binding/binding_test.go b/binding/binding_test.go
index 4424bab9..17336177 100644
--- a/binding/binding_test.go
+++ b/binding/binding_test.go
@@ -13,6 +13,7 @@ import (
"mime/multipart"
"net/http"
"os"
+ "reflect"
"strconv"
"strings"
"testing"
@@ -34,7 +35,7 @@ type QueryTest struct {
}
type FooStruct struct {
- Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required"`
+ Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" binding:"required,max=32"`
}
type FooBarStruct struct {
@@ -180,6 +181,20 @@ func TestBindingJSON(t *testing.T) {
`{"foo": "bar"}`, `{"bar": "foo"}`)
}
+func TestBindingJSONSlice(t *testing.T) {
+ EnableDecoderDisallowUnknownFields = true
+ defer func() {
+ EnableDecoderDisallowUnknownFields = false
+ }()
+
+ testBodyBindingSlice(t, JSON, "json", "/", "/", `[]`, ``)
+ testBodyBindingSlice(t, JSON, "json", "/", "/", `[{"foo": "123"}]`, `[{}]`)
+ testBodyBindingSlice(t, JSON, "json", "/", "/", `[{"foo": "123"}]`, `[{"foo": ""}]`)
+ testBodyBindingSlice(t, JSON, "json", "/", "/", `[{"foo": "123"}]`, `[{"foo": 123}]`)
+ testBodyBindingSlice(t, JSON, "json", "/", "/", `[{"foo": "123"}]`, `[{"bar": 123}]`)
+ testBodyBindingSlice(t, JSON, "json", "/", "/", `[{"foo": "123"}]`, `[{"foo": "123456789012345678901234567890123"}]`)
+}
+
func TestBindingJSONUseNumber(t *testing.T) {
testBodyBindingUseNumber(t,
JSON, "json",
@@ -200,6 +215,12 @@ func TestBindingJSONDisallowUnknownFields(t *testing.T) {
`{"foo": "bar"}`, `{"foo": "bar", "what": "this"}`)
}
+func TestBindingJSONStringMap(t *testing.T) {
+ testBodyBindingStringMap(t, JSON,
+ "/", "/",
+ `{"foo": "bar", "hello": "world"}`, `{"num": 2}`)
+}
+
func TestBindingForm(t *testing.T) {
testFormBinding(t, "POST",
"/", "/",
@@ -336,6 +357,37 @@ func TestBindingFormForType(t *testing.T) {
"", "", "StructPointer")
}
+func TestBindingFormStringMap(t *testing.T) {
+ testBodyBindingStringMap(t, Form,
+ "/", "",
+ `foo=bar&hello=world`, "")
+ // Should pick the last value
+ testBodyBindingStringMap(t, Form,
+ "/", "",
+ `foo=something&foo=bar&hello=world`, "")
+}
+
+func TestBindingFormStringSliceMap(t *testing.T) {
+ obj := make(map[string][]string)
+ req := requestWithBody("POST", "/", "foo=something&foo=bar&hello=world")
+ req.Header.Add("Content-Type", MIMEPOSTForm)
+ err := Form.Bind(req, &obj)
+ assert.NoError(t, err)
+ assert.NotNil(t, obj)
+ assert.Len(t, obj, 2)
+ target := map[string][]string{
+ "foo": {"something", "bar"},
+ "hello": {"world"},
+ }
+ assert.True(t, reflect.DeepEqual(obj, target))
+
+ objInvalid := make(map[string][]int)
+ req = requestWithBody("POST", "/", "foo=something&foo=bar&hello=world")
+ req.Header.Add("Content-Type", MIMEPOSTForm)
+ err = Form.Bind(req, &objInvalid)
+ assert.Error(t, err)
+}
+
func TestBindingQuery(t *testing.T) {
testQueryBinding(t, "POST",
"/?foo=bar&bar=foo", "/",
@@ -366,6 +418,28 @@ func TestBindingQueryBoolFail(t *testing.T) {
"bool_foo=unused", "")
}
+func TestBindingQueryStringMap(t *testing.T) {
+ b := Query
+
+ obj := make(map[string]string)
+ req := requestWithBody("GET", "/?foo=bar&hello=world", "")
+ err := b.Bind(req, &obj)
+ assert.NoError(t, err)
+ assert.NotNil(t, obj)
+ assert.Len(t, obj, 2)
+ assert.Equal(t, "bar", obj["foo"])
+ assert.Equal(t, "world", obj["hello"])
+
+ obj = make(map[string]string)
+ req = requestWithBody("GET", "/?foo=bar&foo=2&hello=world", "") // should pick last
+ err = b.Bind(req, &obj)
+ assert.NoError(t, err)
+ assert.NotNil(t, obj)
+ assert.Len(t, obj, 2)
+ assert.Equal(t, "2", obj["foo"])
+ assert.Equal(t, "world", obj["hello"])
+}
+
func TestBindingXML(t *testing.T) {
testBodyBinding(t,
XML, "xml",
@@ -387,6 +461,13 @@ func TestBindingYAML(t *testing.T) {
`foo: bar`, `bar: foo`)
}
+func TestBindingYAMLStringMap(t *testing.T) {
+ // YAML is a superset of JSON, so the test below is JSON (to avoid newlines)
+ testBodyBindingStringMap(t, YAML,
+ "/", "/",
+ `{"foo": "bar", "hello": "world"}`, `{"nested": {"foo": "bar"}}`)
+}
+
func TestBindingYAMLFail(t *testing.T) {
testBodyBindingFail(t,
YAML, "yaml",
@@ -1114,6 +1195,46 @@ func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody
assert.Error(t, err)
}
+func testBodyBindingSlice(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
+ assert.Equal(t, name, b.Name())
+
+ var obj1 []FooStruct
+ req := requestWithBody("POST", path, body)
+ err := b.Bind(req, &obj1)
+ assert.NoError(t, err)
+
+ var obj2 []FooStruct
+ req = requestWithBody("POST", badPath, badBody)
+ err = JSON.Bind(req, &obj2)
+ assert.Error(t, err)
+}
+
+func testBodyBindingStringMap(t *testing.T, b Binding, path, badPath, body, badBody string) {
+ obj := make(map[string]string)
+ req := requestWithBody("POST", path, body)
+ if b.Name() == "form" {
+ req.Header.Add("Content-Type", MIMEPOSTForm)
+ }
+ err := b.Bind(req, &obj)
+ assert.NoError(t, err)
+ assert.NotNil(t, obj)
+ assert.Len(t, obj, 2)
+ assert.Equal(t, "bar", obj["foo"])
+ assert.Equal(t, "world", obj["hello"])
+
+ if badPath != "" && badBody != "" {
+ obj = make(map[string]string)
+ req = requestWithBody("POST", badPath, badBody)
+ err = b.Bind(req, &obj)
+ assert.Error(t, err)
+ }
+
+ objInt := make(map[string]int)
+ req = requestWithBody("POST", path, body)
+ err = b.Bind(req, &objInt)
+ assert.Error(t, err)
+}
+
func testBodyBindingUseNumber(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
assert.Equal(t, name, b.Name())
diff --git a/binding/default_validator.go b/binding/default_validator.go
index a4c1a7f6..c57a120f 100644
--- a/binding/default_validator.go
+++ b/binding/default_validator.go
@@ -5,7 +5,9 @@
package binding
import (
+ "fmt"
"reflect"
+ "strings"
"sync"
"github.com/go-playground/validator/v10"
@@ -16,22 +18,54 @@ type defaultValidator struct {
validate *validator.Validate
}
+type sliceValidateError []error
+
+func (err sliceValidateError) Error() string {
+ var errMsgs []string
+ for i, e := range err {
+ if e == nil {
+ continue
+ }
+ errMsgs = append(errMsgs, fmt.Sprintf("[%d]: %s", i, e.Error()))
+ }
+ return strings.Join(errMsgs, "\n")
+}
+
var _ StructValidator = &defaultValidator{}
// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type.
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
+ if obj == nil {
+ return nil
+ }
+
value := reflect.ValueOf(obj)
- valueType := value.Kind()
- if valueType == reflect.Ptr {
- valueType = value.Elem().Kind()
- }
- if valueType == reflect.Struct {
- v.lazyinit()
- if err := v.validate.Struct(obj); err != nil {
- return err
+ switch value.Kind() {
+ case reflect.Ptr:
+ return v.ValidateStruct(value.Elem().Interface())
+ case reflect.Struct:
+ return v.validateStruct(obj)
+ case reflect.Slice, reflect.Array:
+ count := value.Len()
+ validateRet := make(sliceValidateError, 0)
+ for i := 0; i < count; i++ {
+ if err := v.ValidateStruct(value.Index(i).Interface()); err != nil {
+ validateRet = append(validateRet, err)
+ }
}
+ if len(validateRet) == 0 {
+ return nil
+ }
+ return validateRet
+ default:
+ return nil
}
- return nil
+}
+
+// validateStruct receives struct type
+func (v *defaultValidator) validateStruct(obj interface{}) error {
+ v.lazyinit()
+ return v.validate.Struct(obj)
}
// Engine returns the underlying validator engine which powers the default
diff --git a/binding/default_validator_test.go b/binding/default_validator_test.go
new file mode 100644
index 00000000..e9c6de44
--- /dev/null
+++ b/binding/default_validator_test.go
@@ -0,0 +1,68 @@
+// Copyright 2020 Gin Core Team. All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package binding
+
+import (
+ "errors"
+ "testing"
+)
+
+func TestSliceValidateError(t *testing.T) {
+ tests := []struct {
+ name string
+ err sliceValidateError
+ want string
+ }{
+ {"has nil elements", sliceValidateError{errors.New("test error"), nil}, "[0]: test error"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.err.Error(); got != tt.want {
+ t.Errorf("sliceValidateError.Error() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestDefaultValidator(t *testing.T) {
+ type exampleStruct struct {
+ A string `binding:"max=8"`
+ B int `binding:"gt=0"`
+ }
+ tests := []struct {
+ name string
+ v *defaultValidator
+ obj interface{}
+ wantErr bool
+ }{
+ {"validate nil obj", &defaultValidator{}, nil, false},
+ {"validate int obj", &defaultValidator{}, 3, false},
+ {"validate struct failed-1", &defaultValidator{}, exampleStruct{A: "123456789", B: 1}, true},
+ {"validate struct failed-2", &defaultValidator{}, exampleStruct{A: "12345678", B: 0}, true},
+ {"validate struct passed", &defaultValidator{}, exampleStruct{A: "12345678", B: 1}, false},
+ {"validate *struct failed-1", &defaultValidator{}, &exampleStruct{A: "123456789", B: 1}, true},
+ {"validate *struct failed-2", &defaultValidator{}, &exampleStruct{A: "12345678", B: 0}, true},
+ {"validate *struct passed", &defaultValidator{}, &exampleStruct{A: "12345678", B: 1}, false},
+ {"validate []struct failed-1", &defaultValidator{}, []exampleStruct{{A: "123456789", B: 1}}, true},
+ {"validate []struct failed-2", &defaultValidator{}, []exampleStruct{{A: "12345678", B: 0}}, true},
+ {"validate []struct passed", &defaultValidator{}, []exampleStruct{{A: "12345678", B: 1}}, false},
+ {"validate []*struct failed-1", &defaultValidator{}, []*exampleStruct{{A: "123456789", B: 1}}, true},
+ {"validate []*struct failed-2", &defaultValidator{}, []*exampleStruct{{A: "12345678", B: 0}}, true},
+ {"validate []*struct passed", &defaultValidator{}, []*exampleStruct{{A: "12345678", B: 1}}, false},
+ {"validate *[]struct failed-1", &defaultValidator{}, &[]exampleStruct{{A: "123456789", B: 1}}, true},
+ {"validate *[]struct failed-2", &defaultValidator{}, &[]exampleStruct{{A: "12345678", B: 0}}, true},
+ {"validate *[]struct passed", &defaultValidator{}, &[]exampleStruct{{A: "12345678", B: 1}}, false},
+ {"validate *[]*struct failed-1", &defaultValidator{}, &[]*exampleStruct{{A: "123456789", B: 1}}, true},
+ {"validate *[]*struct failed-2", &defaultValidator{}, &[]*exampleStruct{{A: "12345678", B: 0}}, true},
+ {"validate *[]*struct passed", &defaultValidator{}, &[]*exampleStruct{{A: "12345678", B: 1}}, false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if err := tt.v.ValidateStruct(tt.obj); (err != nil) != tt.wantErr {
+ t.Errorf("defaultValidator.Validate() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
diff --git a/binding/form_mapping.go b/binding/form_mapping.go
index f0913ea5..2f4e45b4 100644
--- a/binding/form_mapping.go
+++ b/binding/form_mapping.go
@@ -29,6 +29,21 @@ func mapForm(ptr interface{}, form map[string][]string) error {
var emptyField = reflect.StructField{}
func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error {
+ // Check if ptr is a map
+ ptrVal := reflect.ValueOf(ptr)
+ var pointed interface{}
+ if ptrVal.Kind() == reflect.Ptr {
+ ptrVal = ptrVal.Elem()
+ pointed = ptrVal.Interface()
+ }
+ if ptrVal.Kind() == reflect.Map &&
+ ptrVal.Type().Key().Kind() == reflect.String {
+ if pointed != nil {
+ ptr = pointed
+ }
+ return setFormMap(ptr, form)
+ }
+
return mappingByPtr(ptr, formSource(form), tag)
}
@@ -349,3 +364,29 @@ func head(str, sep string) (head string, tail string) {
}
return str[:idx], str[idx+len(sep):]
}
+
+func setFormMap(ptr interface{}, form map[string][]string) error {
+ el := reflect.TypeOf(ptr).Elem()
+
+ if el.Kind() == reflect.Slice {
+ ptrMap, ok := ptr.(map[string][]string)
+ if !ok {
+ return errors.New("cannot convert to map slices of strings")
+ }
+ for k, v := range form {
+ ptrMap[k] = v
+ }
+
+ return nil
+ }
+
+ ptrMap, ok := ptr.(map[string]string)
+ if !ok {
+ return errors.New("cannot convert to map of strings")
+ }
+ for k, v := range form {
+ ptrMap[k] = v[len(v)-1] // pick last
+ }
+
+ return nil
+}
diff --git a/binding/json.go b/binding/json.go
index d62e0705..45aaa494 100644
--- a/binding/json.go
+++ b/binding/json.go
@@ -6,7 +6,7 @@ package binding
import (
"bytes"
- "fmt"
+ "errors"
"io"
"net/http"
@@ -32,7 +32,7 @@ func (jsonBinding) Name() string {
func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
if req == nil || req.Body == nil {
- return fmt.Errorf("invalid request")
+ return errors.New("invalid request")
}
return decodeJSON(req.Body, obj)
}
diff --git a/binding/json_test.go b/binding/json_test.go
index cae4cccc..fbd5c527 100644
--- a/binding/json_test.go
+++ b/binding/json_test.go
@@ -19,3 +19,12 @@ func TestJSONBindingBindBody(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "FOO", s.Foo)
}
+
+func TestJSONBindingBindBodyMap(t *testing.T) {
+ s := make(map[string]string)
+ err := jsonBinding{}.BindBody([]byte(`{"foo": "FOO","hello":"world"}`), &s)
+ require.NoError(t, err)
+ assert.Len(t, s, 2)
+ assert.Equal(t, "FOO", s["foo"])
+ assert.Equal(t, "world", s["hello"])
+}
diff --git a/binding/msgpack.go b/binding/msgpack.go
index a5bc2ad2..2a442996 100644
--- a/binding/msgpack.go
+++ b/binding/msgpack.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build !nomsgpack
// +build !nomsgpack
package binding
diff --git a/binding/msgpack_test.go b/binding/msgpack_test.go
index 296d3eb1..75600ba8 100644
--- a/binding/msgpack_test.go
+++ b/binding/msgpack_test.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build !nomsgpack
// +build !nomsgpack
package binding
diff --git a/context.go b/context.go
index 71fb5937..dc03c358 100644
--- a/context.go
+++ b/context.go
@@ -725,32 +725,80 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e
return bb.BindBody(body, obj)
}
-// ClientIP implements a best effort algorithm to return the real client IP, it parses
-// X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
-// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
+// ClientIP implements a best effort algorithm to return the real client IP.
+// It called c.RemoteIP() under the hood, to check if the remote IP is a trusted proxy or not.
+// If it's it will then try to parse the headers defined in Engine.RemoteIPHeaders (defaulting to [X-Forwarded-For, X-Real-Ip]).
+// If the headers are nots syntactically valid OR the remote IP does not correspong to a trusted proxy,
+// the remote IP (coming form Request.RemoteAddr) is returned.
func (c *Context) ClientIP() string {
- if c.engine.ForwardedByClientIP {
- clientIP := c.requestHeader("X-Forwarded-For")
- clientIP = strings.TrimSpace(strings.Split(clientIP, ",")[0])
- if clientIP == "" {
- clientIP = strings.TrimSpace(c.requestHeader("X-Real-Ip"))
- }
- if clientIP != "" {
- return clientIP
- }
- }
-
if c.engine.AppEngine {
if addr := c.requestHeader("X-Appengine-Remote-Addr"); addr != "" {
return addr
}
}
- if ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)); err == nil {
- return ip
+ remoteIP, trusted := c.RemoteIP()
+ if remoteIP == nil {
+ return ""
}
- return ""
+ if trusted && c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil {
+ for _, headerName := range c.engine.RemoteIPHeaders {
+ ip, valid := validateHeader(c.requestHeader(headerName))
+ if valid {
+ return ip
+ }
+ }
+ }
+ return remoteIP.String()
+}
+
+// RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port).
+// It also checks if the remoteIP is a trusted proxy or not.
+// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks
+// defined in Engine.TrustedProxies
+func (c *Context) RemoteIP() (net.IP, bool) {
+ ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
+ if err != nil {
+ return nil, false
+ }
+ remoteIP := net.ParseIP(ip)
+ if remoteIP == nil {
+ return nil, false
+ }
+
+ if c.engine.trustedCIDRs != nil {
+ for _, cidr := range c.engine.trustedCIDRs {
+ if cidr.Contains(remoteIP) {
+ return remoteIP, true
+ }
+ }
+ }
+
+ return remoteIP, false
+}
+
+func validateHeader(header string) (clientIP string, valid bool) {
+ if header == "" {
+ return "", false
+ }
+ items := strings.Split(header, ",")
+ for i, ipStr := range items {
+ ipStr = strings.TrimSpace(ipStr)
+ ip := net.ParseIP(ipStr)
+ if ip == nil {
+ return "", false
+ }
+
+ // We need to return the first IP in the list, but,
+ // we should not early return since we need to validate that
+ // the rest of the header is syntactically valid
+ if i == 0 {
+ clientIP = ipStr
+ valid = true
+ }
+ }
+ return
}
// ContentType returns the Content-Type header of the request.
@@ -891,7 +939,7 @@ func (c *Context) SecureJSON(code int, obj interface{}) {
}
// JSONP serializes the given struct as JSON into the response body.
-// It add padding to response body to request data from a server residing in a different domain than the client.
+// It adds padding to response body to request data from a server residing in a different domain than the client.
// It also sets the Content-Type as "application/javascript".
func (c *Context) JSONP(code int, obj interface{}) {
callback := c.DefaultQuery("callback", "")
@@ -968,7 +1016,7 @@ func (c *Context) DataFromReader(code int, contentLength int64, contentType stri
})
}
-// File writes the specified file into the body stream in a efficient way.
+// File writes the specified file into the body stream in an efficient way.
func (c *Context) File(filepath string) {
http.ServeFile(c.Writer, c.Request, filepath)
}
diff --git a/context_appengine.go b/context_appengine.go
index 38c189a0..d5658434 100644
--- a/context_appengine.go
+++ b/context_appengine.go
@@ -1,9 +1,10 @@
-// +build appengine
-
// Copyright 2017 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build appengine
+// +build appengine
+
package gin
func init() {
diff --git a/context_test.go b/context_test.go
index 8e1e3b57..993c632f 100644
--- a/context_test.go
+++ b/context_test.go
@@ -1018,7 +1018,9 @@ func TestContextRenderFile(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "func New() *Engine {")
- assert.Equal(t, "text/plain; charset=utf-8", w.Header().Get("Content-Type"))
+ // Content-Type='text/plain; charset=utf-8' when go version <= 1.16,
+ // else, Content-Type='text/x-go; charset=utf-8'
+ assert.NotEqual(t, "", w.Header().Get("Content-Type"))
}
func TestContextRenderFileFromFS(t *testing.T) {
@@ -1030,7 +1032,9 @@ func TestContextRenderFileFromFS(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "func New() *Engine {")
- assert.Equal(t, "text/plain; charset=utf-8", w.Header().Get("Content-Type"))
+ // Content-Type='text/plain; charset=utf-8' when go version <= 1.16,
+ // else, Content-Type='text/x-go; charset=utf-8'
+ assert.NotEqual(t, "", w.Header().Get("Content-Type"))
assert.Equal(t, "/some/path", c.Request.URL.Path)
}
@@ -1044,7 +1048,7 @@ func TestContextRenderAttachment(t *testing.T) {
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "func New() *Engine {")
- assert.Equal(t, fmt.Sprintf("attachment; filename=\"%s\"", newFilename), w.HeaderMap.Get("Content-Disposition"))
+ assert.Equal(t, fmt.Sprintf("attachment; filename=\"%s\"", newFilename), w.Header().Get("Content-Disposition"))
}
// TestContextRenderYAML tests that the response is serialized as YAML
@@ -1388,15 +1392,18 @@ func TestContextAbortWithError(t *testing.T) {
assert.True(t, c.IsAborted())
}
+func resetTrustedCIDRs(c *Context) {
+ c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
+}
+
func TestContextClientIP(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil)
+ resetTrustedCIDRs(c)
+ resetContextForClientIPTests(c)
- c.Request.Header.Set("X-Real-IP", " 10.10.10.10 ")
- c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20, 30.30.30.30")
- c.Request.Header.Set("X-Appengine-Remote-Addr", "50.50.50.50")
- c.Request.RemoteAddr = " 40.40.40.40:42123 "
-
+ // Legacy tests (validating that the defaults don't break the
+ // (insecure!) old behaviour)
assert.Equal(t, "20.20.20.20", c.ClientIP())
c.Request.Header.Del("X-Forwarded-For")
@@ -1416,6 +1423,84 @@ func TestContextClientIP(t *testing.T) {
// no port
c.Request.RemoteAddr = "50.50.50.50"
assert.Empty(t, c.ClientIP())
+
+ // Tests exercising the TrustedProxies functionality
+ resetContextForClientIPTests(c)
+
+ // No trusted proxies
+ c.engine.TrustedProxies = []string{}
+ resetTrustedCIDRs(c)
+ c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
+ assert.Equal(t, "40.40.40.40", c.ClientIP())
+
+ // Last proxy is trusted, but the RemoteAddr is not
+ c.engine.TrustedProxies = []string{"30.30.30.30"}
+ resetTrustedCIDRs(c)
+ assert.Equal(t, "40.40.40.40", c.ClientIP())
+
+ // Only trust RemoteAddr
+ c.engine.TrustedProxies = []string{"40.40.40.40"}
+ resetTrustedCIDRs(c)
+ assert.Equal(t, "20.20.20.20", c.ClientIP())
+
+ // All steps are trusted
+ c.engine.TrustedProxies = []string{"40.40.40.40", "30.30.30.30", "20.20.20.20"}
+ resetTrustedCIDRs(c)
+ assert.Equal(t, "20.20.20.20", c.ClientIP())
+
+ // Use CIDR
+ c.engine.TrustedProxies = []string{"40.40.25.25/16", "30.30.30.30"}
+ resetTrustedCIDRs(c)
+ assert.Equal(t, "20.20.20.20", c.ClientIP())
+
+ // Use hostname that resolves to all the proxies
+ c.engine.TrustedProxies = []string{"foo"}
+ resetTrustedCIDRs(c)
+ assert.Equal(t, "40.40.40.40", c.ClientIP())
+
+ // Use hostname that returns an error
+ c.engine.TrustedProxies = []string{"bar"}
+ resetTrustedCIDRs(c)
+ assert.Equal(t, "40.40.40.40", c.ClientIP())
+
+ // X-Forwarded-For has a non-IP element
+ c.engine.TrustedProxies = []string{"40.40.40.40"}
+ resetTrustedCIDRs(c)
+ c.Request.Header.Set("X-Forwarded-For", " blah ")
+ assert.Equal(t, "40.40.40.40", c.ClientIP())
+
+ // Result from LookupHost has non-IP element. This should never
+ // happen, but we should test it to make sure we handle it
+ // gracefully.
+ c.engine.TrustedProxies = []string{"baz"}
+ resetTrustedCIDRs(c)
+ c.Request.Header.Set("X-Forwarded-For", " 30.30.30.30 ")
+ assert.Equal(t, "40.40.40.40", c.ClientIP())
+
+ c.engine.TrustedProxies = []string{"40.40.40.40"}
+ resetTrustedCIDRs(c)
+ c.Request.Header.Del("X-Forwarded-For")
+ c.engine.RemoteIPHeaders = []string{"X-Forwarded-For", "X-Real-IP"}
+ assert.Equal(t, "10.10.10.10", c.ClientIP())
+
+ c.engine.RemoteIPHeaders = []string{}
+ c.engine.AppEngine = true
+ assert.Equal(t, "50.50.50.50", c.ClientIP())
+
+ c.Request.Header.Del("X-Appengine-Remote-Addr")
+ assert.Equal(t, "40.40.40.40", c.ClientIP())
+
+ // no port
+ c.Request.RemoteAddr = "50.50.50.50"
+ assert.Empty(t, c.ClientIP())
+}
+
+func resetContextForClientIPTests(c *Context) {
+ c.Request.Header.Set("X-Real-IP", " 10.10.10.10 ")
+ c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20, 30.30.30.30")
+ c.Request.Header.Set("X-Appengine-Remote-Addr", "50.50.50.50")
+ c.Request.RemoteAddr = " 40.40.40.40:42123 "
+ c.engine.AppEngine = false
}
func TestContextContentType(t *testing.T) {
@@ -1960,3 +2045,12 @@ func TestContextWithKeysMutex(t *testing.T) {
assert.Nil(t, value)
assert.False(t, err)
}
+
+func TestRemoteIPFail(t *testing.T) {
+ c, _ := CreateTestContext(httptest.NewRecorder())
+ c.Request, _ = http.NewRequest("POST", "/", nil)
+ c.Request.RemoteAddr = "[:::]:80"
+ ip, trust := c.RemoteIP()
+ assert.Nil(t, ip)
+ assert.False(t, trust)
+}
diff --git a/debug.go b/debug.go
index c66ca440..9bacc685 100644
--- a/debug.go
+++ b/debug.go
@@ -12,7 +12,7 @@ import (
"strings"
)
-const ginSupportMinGoVer = 10
+const ginSupportMinGoVer = 13
// IsDebugging returns true if the framework is running in debug mode.
// Use SetMode(gin.ReleaseMode) to disable debug mode.
@@ -67,7 +67,7 @@ func getMinVer(v string) (uint64, error) {
func debugPrintWARNINGDefault() {
if v, e := getMinVer(runtime.Version()); e == nil && v <= ginSupportMinGoVer {
- debugPrint(`[WARNING] Now Gin requires Go 1.11 or later and Go 1.12 will be required soon.
+ debugPrint(`[WARNING] Now Gin requires Go 1.13+.
`)
}
diff --git a/debug_test.go b/debug_test.go
index d8cd5d1a..05509992 100644
--- a/debug_test.go
+++ b/debug_test.go
@@ -104,7 +104,7 @@ func TestDebugPrintWARNINGDefault(t *testing.T) {
})
m, e := getMinVer(runtime.Version())
if e == nil && m <= ginSupportMinGoVer {
- assert.Equal(t, "[GIN-debug] [WARNING] Now Gin requires Go 1.11 or later and Go 1.12 will be required soon.\n\n[GIN-debug] [WARNING] Creating an Engine instance with the Logger and Recovery middleware already attached.\n\n", re)
+ assert.Equal(t, "[GIN-debug] [WARNING] Now Gin requires Go 1.13+.\n\n[GIN-debug] [WARNING] Creating an Engine instance with the Logger and Recovery middleware already attached.\n\n", re)
} else {
assert.Equal(t, "[GIN-debug] [WARNING] Creating an Engine instance with the Logger and Recovery middleware already attached.\n\n", re)
}
diff --git a/errors_1.13_test.go b/errors_1.13_test.go
deleted file mode 100644
index a8f9a94e..00000000
--- a/errors_1.13_test.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// +build go1.13
-
-package gin
-
-import (
- "errors"
- "fmt"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-type TestErr string
-
-func (e TestErr) Error() string { return string(e) }
-
-// TestErrorUnwrap tests the behavior of gin.Error with "errors.Is()" and "errors.As()".
-// "errors.Is()" and "errors.As()" have been added to the standard library in go 1.13,
-// hence the "// +build go1.13" directive at the beginning of this file.
-func TestErrorUnwrap(t *testing.T) {
- innerErr := TestErr("somme error")
-
- // 2 layers of wrapping : use 'fmt.Errorf("%w")' to wrap a gin.Error{}, which itself wraps innerErr
- err := fmt.Errorf("wrapped: %w", &Error{
- Err: innerErr,
- Type: ErrorTypeAny,
- })
-
- // check that 'errors.Is()' and 'errors.As()' behave as expected :
- assert.True(t, errors.Is(err, innerErr))
- var testErr TestErr
- assert.True(t, errors.As(err, &testErr))
-}
diff --git a/errors_test.go b/errors_test.go
index 6aae1c10..ee95ab31 100644
--- a/errors_test.go
+++ b/errors_test.go
@@ -6,6 +6,7 @@ package gin
import (
"errors"
+ "fmt"
"testing"
"github.com/gin-gonic/gin/internal/json"
@@ -104,3 +105,24 @@ Error #03: third
assert.Nil(t, errs.JSON())
assert.Empty(t, errs.String())
}
+
+type TestErr string
+
+func (e TestErr) Error() string { return string(e) }
+
+// TestErrorUnwrap tests the behavior of gin.Error with "errors.Is()" and "errors.As()".
+// "errors.Is()" and "errors.As()" have been added to the standard library in go 1.13.
+func TestErrorUnwrap(t *testing.T) {
+ innerErr := TestErr("somme error")
+
+ // 2 layers of wrapping : use 'fmt.Errorf("%w")' to wrap a gin.Error{}, which itself wraps innerErr
+ err := fmt.Errorf("wrapped: %w", &Error{
+ Err: innerErr,
+ Type: ErrorTypeAny,
+ })
+
+ // check that 'errors.Is()' and 'errors.As()' behave as expected :
+ assert.True(t, errors.Is(err, innerErr))
+ var testErr TestErr
+ assert.True(t, errors.As(err, &testErr))
+}
diff --git a/gin.go b/gin.go
index f370fa57..4297ce04 100644
--- a/gin.go
+++ b/gin.go
@@ -84,9 +84,26 @@ type Engine struct {
// If no other Method is allowed, the request is delegated to the NotFound
// handler.
HandleMethodNotAllowed bool
- ForwardedByClientIP bool
- // #726 #755 If enabled, it will thrust some headers starting with
+ // If enabled, client IP will be parsed from the request's headers that
+ // match those stored at `(*gin.Engine).RemoteIPHeaders`. If no IP was
+ // fetched, it falls back to the IP obtained from
+ // `(*gin.Context).Request.RemoteAddr`.
+ ForwardedByClientIP bool
+
+ // List of headers used to obtain the client IP when
+ // `(*gin.Engine).ForwardedByClientIP` is `true` and
+ // `(*gin.Context).Request.RemoteAddr` is matched by at least one of the
+ // network origins of `(*gin.Engine).TrustedProxies`.
+ RemoteIPHeaders []string
+
+ // List of network origins (IPv4 addresses, IPv4 CIDRs, IPv6 addresses or
+ // IPv6 CIDRs) from which to trust request's headers that contain
+ // alternative client IP when `(*gin.Engine).ForwardedByClientIP` is
+ // `true`.
+ TrustedProxies []string
+
+ // #726 #755 If enabled, it will trust some headers starting with
// 'X-AppEngine...' for better integration with that PaaS.
AppEngine bool
@@ -117,6 +134,7 @@ type Engine struct {
pool sync.Pool
trees methodTrees
maxParams uint16
+ trustedCIDRs []*net.IPNet
}
var _ IRouter = &Engine{}
@@ -142,6 +160,8 @@ func New() *Engine {
RedirectFixedPath: false,
HandleMethodNotAllowed: false,
ForwardedByClientIP: true,
+ RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"},
+ TrustedProxies: []string{"0.0.0.0/0"},
AppEngine: defaultAppEngine,
UseRawPath: false,
RemoveExtraSlash: false,
@@ -327,12 +347,60 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo {
func (engine *Engine) Run(addr ...string) (err error) {
defer func() { debugPrintError(err) }()
+ trustedCIDRs, err := engine.prepareTrustedCIDRs()
+ if err != nil {
+ return err
+ }
+ engine.trustedCIDRs = trustedCIDRs
address := resolveAddress(addr)
debugPrint("Listening and serving HTTP on %s\n", address)
err = http.ListenAndServe(address, engine)
return
}
+func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
+ if engine.TrustedProxies == nil {
+ return nil, nil
+ }
+
+ cidr := make([]*net.IPNet, 0, len(engine.TrustedProxies))
+ for _, trustedProxy := range engine.TrustedProxies {
+ if !strings.Contains(trustedProxy, "/") {
+ ip := parseIP(trustedProxy)
+ if ip == nil {
+ return cidr, &net.ParseError{Type: "IP address", Text: trustedProxy}
+ }
+
+ switch len(ip) {
+ case net.IPv4len:
+ trustedProxy += "/32"
+ case net.IPv6len:
+ trustedProxy += "/128"
+ }
+ }
+ _, cidrNet, err := net.ParseCIDR(trustedProxy)
+ if err != nil {
+ return cidr, err
+ }
+ cidr = append(cidr, cidrNet)
+ }
+ return cidr, nil
+}
+
+// parseIP parse a string representation of an IP and returns a net.IP with the
+// minimum byte representation or nil if input is invalid.
+func parseIP(ip string) net.IP {
+ parsedIP := net.ParseIP(ip)
+
+ if ipv4 := parsedIP.To4(); ipv4 != nil {
+ // return ip in a 4-byte representation
+ return ipv4
+ }
+
+ // return ip in a 16-byte representation or nil
+ return parsedIP
+}
+
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
diff --git a/gin_integration_test.go b/gin_integration_test.go
index 5f508c70..fd972657 100644
--- a/gin_integration_test.go
+++ b/gin_integration_test.go
@@ -14,6 +14,7 @@ import (
"net/http"
"net/http/httptest"
"os"
+ "path/filepath"
"sync"
"testing"
"time"
@@ -54,6 +55,13 @@ func TestRunEmpty(t *testing.T) {
testRequest(t, "http://localhost:8080/example")
}
+func TestTrustedCIDRsForRun(t *testing.T) {
+ os.Setenv("PORT", "")
+ router := New()
+ router.TrustedProxies = []string{"hello/world"}
+ assert.Error(t, router.Run(":8080"))
+}
+
func TestRunTLS(t *testing.T) {
router := New()
go func() {
@@ -146,7 +154,7 @@ func TestRunWithPort(t *testing.T) {
func TestUnixSocket(t *testing.T) {
router := New()
- unixTestSocket := "/tmp/unix_unit_test"
+ unixTestSocket := filepath.Join(os.TempDir(), "unix_unit_test")
defer os.Remove(unixTestSocket)
diff --git a/gin_test.go b/gin_test.go
index 56da0e4e..43f2293c 100644
--- a/gin_test.go
+++ b/gin_test.go
@@ -9,6 +9,7 @@ import (
"fmt"
"html/template"
"io/ioutil"
+ "net"
"net/http"
"net/http/httptest"
"reflect"
@@ -654,6 +655,139 @@ func TestEngineHandleContextManyReEntries(t *testing.T) {
assert.Equal(t, int64(expectValue), middlewareCounter)
}
+func TestPrepareTrustedCIRDsWith(t *testing.T) {
+ r := New()
+
+ // valid ipv4 cidr
+ {
+ expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")}
+ r.TrustedProxies = []string{"0.0.0.0/0"}
+
+ trustedCIDRs, err := r.prepareTrustedCIDRs()
+
+ assert.NoError(t, err)
+ assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs)
+ }
+
+ // invalid ipv4 cidr
+ {
+ r.TrustedProxies = []string{"192.168.1.33/33"}
+
+ _, err := r.prepareTrustedCIDRs()
+
+ assert.Error(t, err)
+ }
+
+ // valid ipv4 address
+ {
+ expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")}
+ r.TrustedProxies = []string{"192.168.1.33"}
+
+ trustedCIDRs, err := r.prepareTrustedCIDRs()
+
+ assert.NoError(t, err)
+ assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs)
+ }
+
+ // invalid ipv4 address
+ {
+ r.TrustedProxies = []string{"192.168.1.256"}
+
+ _, err := r.prepareTrustedCIDRs()
+
+ assert.Error(t, err)
+ }
+
+ // valid ipv6 address
+ {
+ expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")}
+ r.TrustedProxies = []string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"}
+
+ trustedCIDRs, err := r.prepareTrustedCIDRs()
+
+ assert.NoError(t, err)
+ assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs)
+ }
+
+ // invalid ipv6 address
+ {
+ r.TrustedProxies = []string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"}
+
+ _, err := r.prepareTrustedCIDRs()
+
+ assert.Error(t, err)
+ }
+
+ // valid ipv6 cidr
+ {
+ expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")}
+ r.TrustedProxies = []string{"::/0"}
+
+ trustedCIDRs, err := r.prepareTrustedCIDRs()
+
+ assert.NoError(t, err)
+ assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs)
+ }
+
+ // invalid ipv6 cidr
+ {
+ r.TrustedProxies = []string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"}
+
+ _, err := r.prepareTrustedCIDRs()
+
+ assert.Error(t, err)
+ }
+
+ // valid combination
+ {
+ expectedTrustedCIDRs := []*net.IPNet{
+ parseCIDR("::/0"),
+ parseCIDR("192.168.0.0/16"),
+ parseCIDR("172.16.0.1/32"),
+ }
+ r.TrustedProxies = []string{
+ "::/0",
+ "192.168.0.0/16",
+ "172.16.0.1",
+ }
+
+ trustedCIDRs, err := r.prepareTrustedCIDRs()
+
+ assert.NoError(t, err)
+ assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs)
+ }
+
+ // invalid combination
+ {
+ r.TrustedProxies = []string{
+ "::/0",
+ "192.168.0.0/16",
+ "172.16.0.256",
+ }
+ _, err := r.prepareTrustedCIDRs()
+
+ assert.Error(t, err)
+ }
+
+ // nil value
+ {
+ r.TrustedProxies = nil
+ trustedCIDRs, err := r.prepareTrustedCIDRs()
+
+ assert.Nil(t, trustedCIDRs)
+ assert.Nil(t, err)
+ }
+
+}
+
+func parseCIDR(cidr string) *net.IPNet {
+ _, parsedCIDR, err := net.ParseCIDR(cidr)
+ if err != nil {
+ fmt.Println(err)
+ }
+ return parsedCIDR
+}
+
func assertRoutePresent(t *testing.T, gotRoutes RoutesInfo, wantRoute RouteInfo) {
for _, gotRoute := range gotRoutes {
if gotRoute.Path == wantRoute.Path && gotRoute.Method == wantRoute.Method {
diff --git a/go.sum b/go.sum
index a64b3319..ac92ada3 100644
--- a/go.sum
+++ b/go.sum
@@ -11,6 +11,8 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE=
github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4=
+github.com/goccy/go-json v0.4.11 h1:92nyX606ZN/cUFwctfxwDWm8YWSA38Zlv9s7taFeLyo=
+github.com/goccy/go-json v0.4.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
diff --git a/internal/bytesconv/bytesconv.go b/internal/bytesconv/bytesconv.go
index 7b80e335..86e4c4d4 100644
--- a/internal/bytesconv/bytesconv.go
+++ b/internal/bytesconv/bytesconv.go
@@ -5,16 +5,17 @@
package bytesconv
import (
- "reflect"
"unsafe"
)
// StringToBytes converts string to byte slice without a memory allocation.
-func StringToBytes(s string) (b []byte) {
- sh := *(*reflect.StringHeader)(unsafe.Pointer(&s))
- bh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
- bh.Data, bh.Len, bh.Cap = sh.Data, sh.Len, sh.Len
- return b
+func StringToBytes(s string) []byte {
+ return *(*[]byte)(unsafe.Pointer(
+ &struct {
+ string
+ Cap int
+ }{s, len(s)},
+ ))
}
// BytesToString converts byte slice to string without a memory allocation.
diff --git a/internal/json/go_json.go b/internal/json/go_json.go
new file mode 100644
index 00000000..da960571
--- /dev/null
+++ b/internal/json/go_json.go
@@ -0,0 +1,23 @@
+// Copyright 2017 Bo-Yi Wu. All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+//go:build go_json
+// +build go_json
+
+package json
+
+import json "github.com/goccy/go-json"
+
+var (
+ // Marshal is exported by gin/json package.
+ Marshal = json.Marshal
+ // Unmarshal is exported by gin/json package.
+ Unmarshal = json.Unmarshal
+ // MarshalIndent is exported by gin/json package.
+ MarshalIndent = json.MarshalIndent
+ // NewDecoder is exported by gin/json package.
+ NewDecoder = json.NewDecoder
+ // NewEncoder is exported by gin/json package.
+ NewEncoder = json.NewEncoder
+)
diff --git a/internal/json/json.go b/internal/json/json.go
index 480e8bff..75b60224 100644
--- a/internal/json/json.go
+++ b/internal/json/json.go
@@ -2,7 +2,8 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
-// +build !jsoniter
+//go:build !jsoniter && !go_json
+// +build !jsoniter,!go_json
package json
diff --git a/internal/json/jsoniter.go b/internal/json/jsoniter.go
index 649a3cdb..232f8dca 100644
--- a/internal/json/jsoniter.go
+++ b/internal/json/jsoniter.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build jsoniter
// +build jsoniter
package json
diff --git a/logger_test.go b/logger_test.go
index 0d40666e..80961ce1 100644
--- a/logger_test.go
+++ b/logger_test.go
@@ -185,6 +185,8 @@ func TestLoggerWithConfigFormatting(t *testing.T) {
buffer := new(bytes.Buffer)
router := New()
+ router.engine.trustedCIDRs, _ = router.engine.prepareTrustedCIDRs()
+
router.Use(LoggerWithConfig(LoggerConfig{
Output: buffer,
Formatter: func(param LogFormatterParams) string {
diff --git a/mode.go b/mode.go
index 11f833e9..c8813aff 100644
--- a/mode.go
+++ b/mode.go
@@ -63,7 +63,7 @@ func SetMode(value string) {
case TestMode:
ginMode = testCode
default:
- panic("gin mode unknown: " + value)
+ panic("gin mode unknown: " + value + " (available mode: debug release test)")
}
modeName = value
diff --git a/render/msgpack.go b/render/msgpack.go
index be2d45c5..6ef5b6e5 100644
--- a/render/msgpack.go
+++ b/render/msgpack.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build !nomsgpack
// +build !nomsgpack
package render
diff --git a/render/render_msgpack_test.go b/render/render_msgpack_test.go
index e439ac48..8170fbe8 100644
--- a/render/render_msgpack_test.go
+++ b/render/render_msgpack_test.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
+//go:build !nomsgpack
// +build !nomsgpack
package render
diff --git a/routergroup.go b/routergroup.go
index 15d9930d..6f14bf59 100644
--- a/routergroup.go
+++ b/routergroup.go
@@ -11,6 +11,11 @@ import (
"strings"
)
+var (
+ // reg match english letters for http method name
+ regEnLetter = regexp.MustCompile("^[A-Z]+$")
+)
+
// IRouter defines all router handle interface includes single and group router.
type IRouter interface {
IRoutes
@@ -87,7 +92,7 @@ func (group *RouterGroup) handle(httpMethod, relativePath string, handlers Handl
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
- if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil {
+ if matched := regEnLetter.MatchString(httpMethod); !matched {
panic("http method " + httpMethod + " is not valid")
}
return group.handle(httpMethod, relativePath, handlers)
diff --git a/routes_test.go b/routes_test.go
index 11ff71a6..036fa1c3 100644
--- a/routes_test.go
+++ b/routes_test.go
@@ -238,7 +238,6 @@ func TestRouteParamsByName(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, name, c.Param("name"))
- assert.Equal(t, name, c.Param("name"))
assert.Equal(t, lastName, c.Param("last_name"))
assert.Empty(t, c.Param("wtf"))
@@ -272,7 +271,6 @@ func TestRouteParamsByNameWithExtraSlash(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, name, c.Param("name"))
- assert.Equal(t, name, c.Param("name"))
assert.Equal(t, lastName, c.Param("last_name"))
assert.Empty(t, c.Param("wtf"))
@@ -362,7 +360,9 @@ func TestRouterMiddlewareAndStatic(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "package gin")
- assert.Equal(t, "text/plain; charset=utf-8", w.Header().Get("Content-Type"))
+ // Content-Type='text/plain; charset=utf-8' when go version <= 1.16,
+ // else, Content-Type='text/x-go; charset=utf-8'
+ assert.NotEqual(t, "", w.Header().Get("Content-Type"))
assert.NotEqual(t, w.Header().Get("Last-Modified"), "Mon, 02 Jan 2006 15:04:05 MST")
assert.Equal(t, "Mon, 02 Jan 2006 15:04:05 MST", w.Header().Get("Expires"))
assert.Equal(t, "Gin Framework", w.Header().Get("x-GIN"))
diff --git a/tree.go b/tree.go
index 7a80af9e..ca753e6d 100644
--- a/tree.go
+++ b/tree.go
@@ -80,6 +80,16 @@ func longestCommonPrefix(a, b string) int {
return i
}
+// addChild will add a child node, keeping wildcards at the end
+func (n *node) addChild(child *node) {
+ if n.wildChild && len(n.children) > 0 {
+ wildcardChild := n.children[len(n.children)-1]
+ n.children = append(n.children[:len(n.children)-1], child, wildcardChild)
+ } else {
+ n.children = append(n.children, child)
+ }
+}
+
func countParams(path string) uint16 {
var n uint16
s := bytesconv.StringToBytes(path)
@@ -103,7 +113,7 @@ type node struct {
wildChild bool
nType nodeType
priority uint32
- children []*node
+ children []*node // child nodes, at most 1 :param style node at the end of the array
handlers HandlersChain
fullPath string
}
@@ -119,7 +129,6 @@ func (n *node) incrementChildPrio(pos int) int {
for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- {
// Swap node positions
cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1]
-
}
// Build new index char string
@@ -178,36 +187,9 @@ walk:
// Make new node a child of this node
if i < len(path) {
path = path[i:]
-
- if n.wildChild {
- parentFullPathIndex += len(n.path)
- n = n.children[0]
- n.priority++
-
- // Check if the wildcard matches
- if len(path) >= len(n.path) && n.path == path[:len(n.path)] &&
- // Adding a child to a catchAll is not possible
- n.nType != catchAll &&
- // Check for longer wildcard, e.g. :name and :names
- (len(n.path) >= len(path) || path[len(n.path)] == '/') {
- continue walk
- }
-
- pathSeg := path
- if n.nType != catchAll {
- pathSeg = strings.SplitN(path, "/", 2)[0]
- }
- prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
- panic("'" + pathSeg +
- "' in new path '" + fullPath +
- "' conflicts with existing wildcard '" + n.path +
- "' in existing prefix '" + prefix +
- "'")
- }
-
c := path[0]
- // slash after param
+ // '/' after param
if n.nType == param && c == '/' && len(n.children) == 1 {
parentFullPathIndex += len(n.path)
n = n.children[0]
@@ -226,21 +208,47 @@ walk:
}
// Otherwise insert it
- if c != ':' && c != '*' {
+ if c != ':' && c != '*' && n.nType != catchAll {
// []byte for proper unicode char conversion, see #65
n.indices += bytesconv.BytesToString([]byte{c})
child := &node{
fullPath: fullPath,
}
- n.children = append(n.children, child)
+ n.addChild(child)
n.incrementChildPrio(len(n.indices) - 1)
n = child
+ } else if n.wildChild {
+ // inserting a wildcard node, need to check if it conflicts with the existing wildcard
+ n = n.children[len(n.children)-1]
+ n.priority++
+
+ // Check if the wildcard matches
+ if len(path) >= len(n.path) && n.path == path[:len(n.path)] &&
+ // Adding a child to a catchAll is not possible
+ n.nType != catchAll &&
+ // Check for longer wildcard, e.g. :name and :names
+ (len(n.path) >= len(path) || path[len(n.path)] == '/') {
+ continue walk
+ }
+
+ // Wildcard conflict
+ pathSeg := path
+ if n.nType != catchAll {
+ pathSeg = strings.SplitN(pathSeg, "/", 2)[0]
+ }
+ prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
+ panic("'" + pathSeg +
+ "' in new path '" + fullPath +
+ "' conflicts with existing wildcard '" + n.path +
+ "' in existing prefix '" + prefix +
+ "'")
}
+
n.insertChild(path, fullPath, handlers)
return
}
- // Otherwise and handle to current node
+ // Otherwise add handle to current node
if n.handlers != nil {
panic("handlers are already registered for path '" + fullPath + "'")
}
@@ -294,13 +302,6 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain)
panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
}
- // Check if this node has existing children which would be
- // unreachable if we insert the wildcard here
- if len(n.children) > 0 {
- panic("wildcard segment '" + wildcard +
- "' conflicts with existing children in path '" + fullPath + "'")
- }
-
if wildcard[0] == ':' { // param
if i > 0 {
// Insert prefix before the current wildcard
@@ -308,13 +309,13 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain)
path = path[i:]
}
- n.wildChild = true
child := &node{
nType: param,
path: wildcard,
fullPath: fullPath,
}
- n.children = []*node{child}
+ n.addChild(child)
+ n.wildChild = true
n = child
n.priority++
@@ -327,7 +328,7 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain)
priority: 1,
fullPath: fullPath,
}
- n.children = []*node{child}
+ n.addChild(child)
n = child
continue
}
@@ -361,7 +362,7 @@ func (n *node) insertChild(path string, fullPath string, handlers HandlersChain)
fullPath: fullPath,
}
- n.children = []*node{child}
+ n.addChild(child)
n.indices = string('/')
n = child
n.priority++
@@ -405,18 +406,18 @@ walk: // Outer loop for walking the tree
if len(path) > len(prefix) {
if path[:len(prefix)] == prefix {
path = path[len(prefix):]
- // If this node does not have a wildcard (param or catchAll)
- // child, we can just look up the next child node and continue
- // to walk down the tree
- if !n.wildChild {
- idxc := path[0]
- for i, c := range []byte(n.indices) {
- if c == idxc {
- n = n.children[i]
- continue walk
- }
- }
+ // Try all the non-wildcard children first by matching the indices
+ idxc := path[0]
+ for i, c := range []byte(n.indices) {
+ if c == idxc {
+ n = n.children[i]
+ continue walk
+ }
+ }
+
+ // If there is no wildcard pattern, recommend a redirection
+ if !n.wildChild {
// Nothing found.
// We can recommend to redirect to the same URL without a
// trailing slash if a leaf exists for that path.
@@ -424,8 +425,9 @@ walk: // Outer loop for walking the tree
return
}
- // Handle wildcard child
- n = n.children[0]
+ // Handle wildcard child, which is always at the end of the array
+ n = n.children[len(n.children)-1]
+
switch n.nType {
case param:
// Find param end (either '/' or path end)
@@ -559,8 +561,8 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]by
// Use a static sized buffer on the stack in the common case.
// If the path is too long, allocate a buffer on the heap instead.
buf := make([]byte, 0, stackBufSize)
- if l := len(path) + 1; l > stackBufSize {
- buf = make([]byte, 0, l)
+ if length := len(path) + 1; length > stackBufSize {
+ buf = make([]byte, 0, length)
}
ciPath := n.findCaseInsensitivePathRec(
@@ -600,142 +602,7 @@ walk: // Outer loop for walking the tree
path = path[npLen:]
ciPath = append(ciPath, n.path...)
- if len(path) > 0 {
- // If this node does not have a wildcard (param or catchAll) child,
- // we can just look up the next child node and continue to walk down
- // the tree
- if !n.wildChild {
- // Skip rune bytes already processed
- rb = shiftNRuneBytes(rb, npLen)
-
- if rb[0] != 0 {
- // Old rune not finished
- idxc := rb[0]
- for i, c := range []byte(n.indices) {
- if c == idxc {
- // continue with child node
- n = n.children[i]
- npLen = len(n.path)
- continue walk
- }
- }
- } else {
- // Process a new rune
- var rv rune
-
- // Find rune start.
- // Runes are up to 4 byte long,
- // -4 would definitely be another rune.
- var off int
- for max := min(npLen, 3); off < max; off++ {
- if i := npLen - off; utf8.RuneStart(oldPath[i]) {
- // read rune from cached path
- rv, _ = utf8.DecodeRuneInString(oldPath[i:])
- break
- }
- }
-
- // Calculate lowercase bytes of current rune
- lo := unicode.ToLower(rv)
- utf8.EncodeRune(rb[:], lo)
-
- // Skip already processed bytes
- rb = shiftNRuneBytes(rb, off)
-
- idxc := rb[0]
- for i, c := range []byte(n.indices) {
- // Lowercase matches
- if c == idxc {
- // must use a recursive approach since both the
- // uppercase byte and the lowercase byte might exist
- // as an index
- if out := n.children[i].findCaseInsensitivePathRec(
- path, ciPath, rb, fixTrailingSlash,
- ); out != nil {
- return out
- }
- break
- }
- }
-
- // If we found no match, the same for the uppercase rune,
- // if it differs
- if up := unicode.ToUpper(rv); up != lo {
- utf8.EncodeRune(rb[:], up)
- rb = shiftNRuneBytes(rb, off)
-
- idxc := rb[0]
- for i, c := range []byte(n.indices) {
- // Uppercase matches
- if c == idxc {
- // Continue with child node
- n = n.children[i]
- npLen = len(n.path)
- continue walk
- }
- }
- }
- }
-
- // Nothing found. We can recommend to redirect to the same URL
- // without a trailing slash if a leaf exists for that path
- if fixTrailingSlash && path == "/" && n.handlers != nil {
- return ciPath
- }
- return nil
- }
-
- n = n.children[0]
- switch n.nType {
- case param:
- // Find param end (either '/' or path end)
- end := 0
- for end < len(path) && path[end] != '/' {
- end++
- }
-
- // Add param value to case insensitive path
- ciPath = append(ciPath, path[:end]...)
-
- // We need to go deeper!
- if end < len(path) {
- if len(n.children) > 0 {
- // Continue with child node
- n = n.children[0]
- npLen = len(n.path)
- path = path[end:]
- continue
- }
-
- // ... but we can't
- if fixTrailingSlash && len(path) == end+1 {
- return ciPath
- }
- return nil
- }
-
- if n.handlers != nil {
- return ciPath
- }
-
- if fixTrailingSlash && len(n.children) == 1 {
- // No handle found. Check if a handle for this path + a
- // trailing slash exists
- n = n.children[0]
- if n.path == "/" && n.handlers != nil {
- return append(ciPath, '/')
- }
- }
-
- return nil
-
- case catchAll:
- return append(ciPath, path...)
-
- default:
- panic("invalid node type")
- }
- } else {
+ if len(path) == 0 {
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if n.handlers != nil {
@@ -758,6 +625,141 @@ walk: // Outer loop for walking the tree
}
return nil
}
+
+ // If this node does not have a wildcard (param or catchAll) child,
+ // we can just look up the next child node and continue to walk down
+ // the tree
+ if !n.wildChild {
+ // Skip rune bytes already processed
+ rb = shiftNRuneBytes(rb, npLen)
+
+ if rb[0] != 0 {
+ // Old rune not finished
+ idxc := rb[0]
+ for i, c := range []byte(n.indices) {
+ if c == idxc {
+ // continue with child node
+ n = n.children[i]
+ npLen = len(n.path)
+ continue walk
+ }
+ }
+ } else {
+ // Process a new rune
+ var rv rune
+
+ // Find rune start.
+ // Runes are up to 4 byte long,
+ // -4 would definitely be another rune.
+ var off int
+ for max := min(npLen, 3); off < max; off++ {
+ if i := npLen - off; utf8.RuneStart(oldPath[i]) {
+ // read rune from cached path
+ rv, _ = utf8.DecodeRuneInString(oldPath[i:])
+ break
+ }
+ }
+
+ // Calculate lowercase bytes of current rune
+ lo := unicode.ToLower(rv)
+ utf8.EncodeRune(rb[:], lo)
+
+ // Skip already processed bytes
+ rb = shiftNRuneBytes(rb, off)
+
+ idxc := rb[0]
+ for i, c := range []byte(n.indices) {
+ // Lowercase matches
+ if c == idxc {
+ // must use a recursive approach since both the
+ // uppercase byte and the lowercase byte might exist
+ // as an index
+ if out := n.children[i].findCaseInsensitivePathRec(
+ path, ciPath, rb, fixTrailingSlash,
+ ); out != nil {
+ return out
+ }
+ break
+ }
+ }
+
+ // If we found no match, the same for the uppercase rune,
+ // if it differs
+ if up := unicode.ToUpper(rv); up != lo {
+ utf8.EncodeRune(rb[:], up)
+ rb = shiftNRuneBytes(rb, off)
+
+ idxc := rb[0]
+ for i, c := range []byte(n.indices) {
+ // Uppercase matches
+ if c == idxc {
+ // Continue with child node
+ n = n.children[i]
+ npLen = len(n.path)
+ continue walk
+ }
+ }
+ }
+ }
+
+ // Nothing found. We can recommend to redirect to the same URL
+ // without a trailing slash if a leaf exists for that path
+ if fixTrailingSlash && path == "/" && n.handlers != nil {
+ return ciPath
+ }
+ return nil
+ }
+
+ n = n.children[0]
+ switch n.nType {
+ case param:
+ // Find param end (either '/' or path end)
+ end := 0
+ for end < len(path) && path[end] != '/' {
+ end++
+ }
+
+ // Add param value to case insensitive path
+ ciPath = append(ciPath, path[:end]...)
+
+ // We need to go deeper!
+ if end < len(path) {
+ if len(n.children) > 0 {
+ // Continue with child node
+ n = n.children[0]
+ npLen = len(n.path)
+ path = path[end:]
+ continue
+ }
+
+ // ... but we can't
+ if fixTrailingSlash && len(path) == end+1 {
+ return ciPath
+ }
+ return nil
+ }
+
+ if n.handlers != nil {
+ return ciPath
+ }
+
+ if fixTrailingSlash && len(n.children) == 1 {
+ // No handle found. Check if a handle for this path + a
+ // trailing slash exists
+ n = n.children[0]
+ if n.path == "/" && n.handlers != nil {
+ return append(ciPath, '/')
+ }
+ }
+
+ return nil
+
+ case catchAll:
+ return append(ciPath, path...)
+
+ default:
+ panic("invalid node type")
+ }
}
// Nothing found.
diff --git a/tree_test.go b/tree_test.go
index 1cb4f559..d7c4fb0b 100644
--- a/tree_test.go
+++ b/tree_test.go
@@ -137,6 +137,8 @@ func TestTreeWildcard(t *testing.T) {
"/",
"/cmd/:tool/:sub",
"/cmd/:tool/",
+ "/cmd/whoami",
+ "/cmd/whoami/root/",
"/src/*filepath",
"/search/",
"/search/:query",
@@ -155,8 +157,12 @@ func TestTreeWildcard(t *testing.T) {
checkRequests(t, tree, testRequests{
{"/", false, "/", nil},
- {"/cmd/test/", false, "/cmd/:tool/", Params{Param{Key: "tool", Value: "test"}}},
- {"/cmd/test", true, "", Params{Param{Key: "tool", Value: "test"}}},
+ {"/cmd/test", true, "/cmd/:tool/", Params{Param{"tool", "test"}}},
+ {"/cmd/test/", false, "/cmd/:tool/", Params{Param{"tool", "test"}}},
+ {"/cmd/whoami", false, "/cmd/whoami", nil},
+ {"/cmd/whoami/", true, "/cmd/whoami", nil},
+ {"/cmd/whoami/root/", false, "/cmd/whoami/root/", nil},
+ {"/cmd/whoami/root", true, "/cmd/whoami/root/", nil},
{"/cmd/test/3", false, "/cmd/:tool/:sub", Params{Param{Key: "tool", Value: "test"}, Param{Key: "sub", Value: "3"}}},
{"/src/", false, "/src/*filepath", Params{Param{Key: "filepath", Value: "/"}}},
{"/src/some/file.png", false, "/src/*filepath", Params{Param{Key: "filepath", Value: "/some/file.png"}}},
@@ -245,20 +251,38 @@ func testRoutes(t *testing.T, routes []testRoute) {
func TestTreeWildcardConflict(t *testing.T) {
routes := []testRoute{
{"/cmd/:tool/:sub", false},
- {"/cmd/vet", true},
+ {"/cmd/vet", false},
+ {"/foo/bar", false},
+ {"/foo/:name", false},
+ {"/foo/:names", true},
+ {"/cmd/*path", true},
+ {"/cmd/:badvar", true},
+ {"/cmd/:tool/names", false},
+ {"/cmd/:tool/:badsub/details", true},
{"/src/*filepath", false},
+ {"/src/:file", true},
+ {"/src/static.json", true},
{"/src/*filepathx", true},
{"/src/", true},
+ {"/src/foo/bar", true},
{"/src1/", false},
{"/src1/*filepath", true},
{"/src2*filepath", true},
+ {"/src2/*filepath", false},
{"/search/:query", false},
- {"/search/invalid", true},
+ {"/search/valid", false},
{"/user_:name", false},
- {"/user_x", true},
+ {"/user_x", false},
{"/user_:name", false},
{"/id:id", false},
- {"/id/:id", true},
+ {"/id/:id", false},
+ }
+ testRoutes(t, routes)
+}
+
+func TestCatchAllAfterSlash(t *testing.T) {
+ routes := []testRoute{
+ {"/non-leading-*catchall", true},
}
testRoutes(t, routes)
}
@@ -266,14 +290,17 @@ func TestTreeWildcardConflict(t *testing.T) {
func TestTreeChildConflict(t *testing.T) {
routes := []testRoute{
{"/cmd/vet", false},
- {"/cmd/:tool/:sub", true},
+ {"/cmd/:tool", false},
+ {"/cmd/:tool/:sub", false},
+ {"/cmd/:tool/misc", false},
+ {"/cmd/:tool/:othersub", true},
{"/src/AUTHORS", false},
{"/src/*filepath", true},
{"/user_x", false},
- {"/user_:name", true},
+ {"/user_:name", false},
{"/id/:id", false},
- {"/id:id", true},
- {"/:id", true},
+ {"/id:id", false},
+ {"/:id", false},
{"/*filepath", true},
}
testRoutes(t, routes)
@@ -688,8 +715,7 @@ func TestTreeWildcardConflictEx(t *testing.T) {
{"/who/are/foo", "/foo", `/who/are/\*you`, `/\*you`},
{"/who/are/foo/", "/foo/", `/who/are/\*you`, `/\*you`},
{"/who/are/foo/bar", "/foo/bar", `/who/are/\*you`, `/\*you`},
- {"/conxxx", "xxx", `/con:tact`, `:tact`},
- {"/conooo/xxx", "ooo", `/con:tact`, `:tact`},
+ {"/con:nection", ":nection", `/con:tact`, `:tact`},
}
for _, conflict := range conflicts {
diff --git a/version.go b/version.go
index 3e9687dc..3647461b 100644
--- a/version.go
+++ b/version.go
@@ -5,4 +5,4 @@
package gin
// Version is the current gin framework's version.
-const Version = "v1.6.3"
+const Version = "v1.7.1"