Perfect TrustedProxies feature

This commit is contained in:
Notealot 2021-09-29 22:56:28 +08:00
parent 6d75aba83f
commit 4d41a2f010
6 changed files with 80 additions and 72 deletions

View File

@ -2202,11 +2202,17 @@ Gin lets you specify which headers to hold the real client IP (if any),
as well as specifying which proxies (or direct clients) you trust to as well as specifying which proxies (or direct clients) you trust to
specify one of these headers. specify one of these headers.
The `TrustedProxies` slice on your `gin.Engine` specifes network addresses or Use function `SetTrustedProxies()` on your `gin.Engine` to specifies network addresses
network CIDRs from where clients which their request headers related to client or network CIDRs from where clients which their request headers related to client
IP can be trusted. They can be IPv4 addresses, IPv4 CIDRs, IPv6 addresses or IP can be trusted. They can be IPv4 addresses, IPv4 CIDRs, IPv6 addresses or
IPv6 CIDRs. IPv6 CIDRs.
**Attention:** Gin trust all proxies by default if you don't specify a trusted
proxy using this function, **this is not safe**. At the same time, if you don't use
any proxy, you can disable this feature by use `Engine.SetTrustedProxies(nil)`,
then `Context.ClientIP()` will return the remote address directly to avoid some
unnecessary computation.
```go ```go
import ( import (
"fmt" "fmt"
@ -2217,7 +2223,7 @@ import (
func main() { func main() {
router := gin.Default() router := gin.Default()
router.TrustedProxies = []string{"192.168.1.2"} router.SetTrustedProxies([]string{"192.168.1.2"})
router.GET("/", func(c *gin.Context) { router.GET("/", func(c *gin.Context) {
// If the client is 192.168.1.2, use the X-Forwarded-For // If the client is 192.168.1.2, use the X-Forwarded-For

View File

@ -778,7 +778,7 @@ func (c *Context) ClientIP() string {
// RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port). // RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port).
// It also checks if the remoteIP is a trusted proxy or not. // It also checks if the remoteIP is a trusted proxy or not.
// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks // In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks
// defined in Engine.TrustedProxies // defined by Engine.SetTrustedProxies()
func (c *Context) RemoteIP() (net.IP, bool) { func (c *Context) RemoteIP() (net.IP, bool) {
ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr))
if err != nil { if err != nil {

View File

@ -1409,6 +1409,10 @@ func TestContextClientIP(t *testing.T) {
c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"} c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
assert.Equal(t, "40.40.40.40", c.ClientIP()) assert.Equal(t, "40.40.40.40", c.ClientIP())
// Disabled TrustedProxies feature
_ = c.engine.SetTrustedProxies(nil)
assert.Equal(t, "40.40.40.40", c.ClientIP())
// Last proxy is trusted, but the RemoteAddr is not // Last proxy is trusted, but the RemoteAddr is not
_ = c.engine.SetTrustedProxies([]string{"30.30.30.30"}) _ = c.engine.SetTrustedProxies([]string{"30.30.30.30"})
assert.Equal(t, "40.40.40.40", c.ClientIP()) assert.Equal(t, "40.40.40.40", c.ClientIP())

70
gin.go
View File

@ -11,6 +11,7 @@ import (
"net/http" "net/http"
"os" "os"
"path" "path"
"reflect"
"strings" "strings"
"sync" "sync"
@ -27,6 +28,8 @@ var (
var defaultPlatform string var defaultPlatform string
var defaultTrustedCIDRs = []*net.IPNet{{IP: net.IP{0x0, 0x0, 0x0, 0x0}, Mask: net.IPMask{0x0, 0x0, 0x0, 0x0}}} // 0.0.0.0/0
// HandlerFunc defines the handler used by gin middleware as return value. // HandlerFunc defines the handler used by gin middleware as return value.
type HandlerFunc func(*Context) type HandlerFunc func(*Context)
@ -119,15 +122,9 @@ type Engine struct {
// List of headers used to obtain the client IP when // List of headers used to obtain the client IP when
// `(*gin.Engine).ForwardedByClientIP` is `true` and // `(*gin.Engine).ForwardedByClientIP` is `true` and
// `(*gin.Context).Request.RemoteAddr` is matched by at least one of the // `(*gin.Context).Request.RemoteAddr` is matched by at least one of the
// network origins of `(*gin.Engine).TrustedProxies`. // network origins of list set by `(*gin.Engine).SetTrustedProxies()`.
RemoteIPHeaders []string RemoteIPHeaders []string
// List of network origins (IPv4 addresses, IPv4 CIDRs, IPv6 addresses or
// IPv6 CIDRs) from which to trust request's headers that contain
// alternative client IP when `(*gin.Engine).ForwardedByClientIP` is
// `true`.
TrustedProxies []string
// If set to a constant of value gin.Platform*, trusts the headers set by // If set to a constant of value gin.Platform*, trusts the headers set by
// that platform, for example to determine the client IP // that platform, for example to determine the client IP
TrustedPlatform string TrustedPlatform string
@ -147,6 +144,7 @@ type Engine struct {
pool sync.Pool pool sync.Pool
trees methodTrees trees methodTrees
maxParams uint16 maxParams uint16
trustedProxies []string
trustedCIDRs []*net.IPNet trustedCIDRs []*net.IPNet
} }
@ -174,7 +172,6 @@ func New() *Engine {
HandleMethodNotAllowed: false, HandleMethodNotAllowed: false,
ForwardedByClientIP: true, ForwardedByClientIP: true,
RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"},
TrustedProxies: []string{"0.0.0.0/0"},
TrustedPlatform: defaultPlatform, TrustedPlatform: defaultPlatform,
UseRawPath: false, UseRawPath: false,
RemoveExtraSlash: false, RemoveExtraSlash: false,
@ -183,7 +180,8 @@ func New() *Engine {
trees: make(methodTrees, 0, 9), trees: make(methodTrees, 0, 9),
delims: render.Delims{Left: "{{", Right: "}}"}, delims: render.Delims{Left: "{{", Right: "}}"},
secureJSONPrefix: "while(1);", secureJSONPrefix: "while(1);",
trustedCIDRs: []*net.IPNet{{IP: net.IP{0x0, 0x0, 0x0, 0x0}, Mask: net.IPMask{0x0, 0x0, 0x0, 0x0}}}, trustedProxies: []string{"0.0.0.0/0"},
trustedCIDRs: defaultTrustedCIDRs,
} }
engine.RouterGroup.engine = engine engine.RouterGroup.engine = engine
engine.pool.New = func() interface{} { engine.pool.New = func() interface{} {
@ -342,9 +340,9 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo {
func (engine *Engine) Run(addr ...string) (err error) { func (engine *Engine) Run(addr ...string) (err error) {
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies() if engine.isUnsafeTrustedProxies() {
if err != nil { debugPrint("[WARNING] You trusted all proxies, this is not safe. We recommend you to set a value.\n" +
return err "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.")
} }
address := resolveAddress(addr) address := resolveAddress(addr)
@ -354,12 +352,12 @@ func (engine *Engine) Run(addr ...string) (err error) {
} }
func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
if engine.TrustedProxies == nil { if engine.trustedProxies == nil {
return nil, nil return nil, nil
} }
cidr := make([]*net.IPNet, 0, len(engine.TrustedProxies)) cidr := make([]*net.IPNet, 0, len(engine.trustedProxies))
for _, trustedProxy := range engine.TrustedProxies { for _, trustedProxy := range engine.trustedProxies {
if !strings.Contains(trustedProxy, "/") { if !strings.Contains(trustedProxy, "/") {
ip := parseIP(trustedProxy) ip := parseIP(trustedProxy)
if ip == nil { if ip == nil {
@ -382,13 +380,25 @@ func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
return cidr, nil return cidr, nil
} }
// SetTrustedProxies set Engine.TrustedProxies // SetTrustedProxies set a list of network origins (IPv4 addresses,
// IPv4 CIDRs, IPv6 addresses or IPv6 CIDRs) from which to trust
// request's headers that contain alternative client IP when
// `(*gin.Engine).ForwardedByClientIP` is `true`. `TrustedProxies`
// feature is enabled by default, and it also trusts all proxies
// by default. If you want to disable this feature, use
// Engine.SetTrustedProxies(nil), then Context.ClientIP() will
// return the remote address directly.
func (engine *Engine) SetTrustedProxies(trustedProxies []string) error { func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {
engine.TrustedProxies = trustedProxies engine.trustedProxies = trustedProxies
return engine.parseTrustedProxies() return engine.parseTrustedProxies()
} }
// parseTrustedProxies parse Engine.TrustedProxies to Engine.trustedCIDRs // isUnsafeTrustedProxies equals Engine.trustedCIDRs and defaultTrustedCIDRs, it's not safe if true
func (engine *Engine) isUnsafeTrustedProxies() bool {
return reflect.DeepEqual(engine.trustedCIDRs, defaultTrustedCIDRs)
}
// parseTrustedProxies parse Engine.trustedProxies to Engine.trustedCIDRs
func (engine *Engine) parseTrustedProxies() error { func (engine *Engine) parseTrustedProxies() error {
trustedCIDRs, err := engine.prepareTrustedCIDRs() trustedCIDRs, err := engine.prepareTrustedCIDRs()
engine.trustedCIDRs = trustedCIDRs engine.trustedCIDRs = trustedCIDRs
@ -416,9 +426,9 @@ func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) {
debugPrint("Listening and serving HTTPS on %s\n", addr) debugPrint("Listening and serving HTTPS on %s\n", addr)
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies() if engine.isUnsafeTrustedProxies() {
if err != nil { debugPrint("[WARNING] You trusted all proxies, this is not safe. We recommend you to set a value.\n" +
return err "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.")
} }
err = http.ListenAndServeTLS(addr, certFile, keyFile, engine) err = http.ListenAndServeTLS(addr, certFile, keyFile, engine)
@ -432,9 +442,9 @@ func (engine *Engine) RunUnix(file string) (err error) {
debugPrint("Listening and serving HTTP on unix:/%s", file) debugPrint("Listening and serving HTTP on unix:/%s", file)
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies() if engine.isUnsafeTrustedProxies() {
if err != nil { debugPrint("[WARNING] You trusted all proxies, this is not safe. We recommend you to set a value.\n" +
return err "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.")
} }
listener, err := net.Listen("unix", file) listener, err := net.Listen("unix", file)
@ -455,9 +465,9 @@ func (engine *Engine) RunFd(fd int) (err error) {
debugPrint("Listening and serving HTTP on fd@%d", fd) debugPrint("Listening and serving HTTP on fd@%d", fd)
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies() if engine.isUnsafeTrustedProxies() {
if err != nil { debugPrint("[WARNING] You trusted all proxies, this is not safe. We recommend you to set a value.\n" +
return err "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.")
} }
f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd)) f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd))
@ -476,9 +486,9 @@ func (engine *Engine) RunListener(listener net.Listener) (err error) {
debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr()) debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr())
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies() if engine.isUnsafeTrustedProxies() {
if err != nil { debugPrint("[WARNING] You trusted all proxies, this is not safe. We recommend you to set a value.\n" +
return err "Please check https://pkg.go.dev/github.com/gin-gonic/gin#readme-don-t-trust-all-proxies for details.")
} }
err = http.Serve(listener, engine) err = http.Serve(listener, engine)

View File

@ -76,6 +76,12 @@ func TestRunEmpty(t *testing.T) {
testRequest(t, "http://localhost:8080/example") testRequest(t, "http://localhost:8080/example")
} }
func TestBadTrustedCIDRs(t *testing.T) {
router := New()
assert.Error(t, router.SetTrustedProxies([]string{"hello/world"}))
}
/* legacy tests
func TestBadTrustedCIDRsForRun(t *testing.T) { func TestBadTrustedCIDRsForRun(t *testing.T) {
os.Setenv("PORT", "") os.Setenv("PORT", "")
router := New() router := New()
@ -143,6 +149,7 @@ func TestBadTrustedCIDRsForRunTLS(t *testing.T) {
router.TrustedProxies = []string{"hello/world"} router.TrustedProxies = []string{"hello/world"}
assert.Error(t, router.RunTLS(":8080", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) assert.Error(t, router.RunTLS(":8080", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem"))
} }
*/
func TestRunTLS(t *testing.T) { func TestRunTLS(t *testing.T) {
router := New() router := New()

View File

@ -539,19 +539,15 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {
// valid ipv4 cidr // valid ipv4 cidr
{ {
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")} expectedTrustedCIDRs := []*net.IPNet{parseCIDR("0.0.0.0/0")}
r.TrustedProxies = []string{"0.0.0.0/0"} err := r.SetTrustedProxies([]string{"0.0.0.0/0"})
trustedCIDRs, err := r.prepareTrustedCIDRs()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
} }
// invalid ipv4 cidr // invalid ipv4 cidr
{ {
r.TrustedProxies = []string{"192.168.1.33/33"} err := r.SetTrustedProxies([]string{"192.168.1.33/33"})
_, err := r.prepareTrustedCIDRs()
assert.Error(t, err) assert.Error(t, err)
} }
@ -559,19 +555,16 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {
// valid ipv4 address // valid ipv4 address
{ {
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")} expectedTrustedCIDRs := []*net.IPNet{parseCIDR("192.168.1.33/32")}
r.TrustedProxies = []string{"192.168.1.33"}
trustedCIDRs, err := r.prepareTrustedCIDRs() err := r.SetTrustedProxies([]string{"192.168.1.33"})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
} }
// invalid ipv4 address // invalid ipv4 address
{ {
r.TrustedProxies = []string{"192.168.1.256"} err := r.SetTrustedProxies([]string{"192.168.1.256"})
_, err := r.prepareTrustedCIDRs()
assert.Error(t, err) assert.Error(t, err)
} }
@ -579,19 +572,15 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {
// valid ipv6 address // valid ipv6 address
{ {
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")} expectedTrustedCIDRs := []*net.IPNet{parseCIDR("2002:0000:0000:1234:abcd:ffff:c0a8:0101/128")}
r.TrustedProxies = []string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"} err := r.SetTrustedProxies([]string{"2002:0000:0000:1234:abcd:ffff:c0a8:0101"})
trustedCIDRs, err := r.prepareTrustedCIDRs()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
} }
// invalid ipv6 address // invalid ipv6 address
{ {
r.TrustedProxies = []string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"} err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101"})
_, err := r.prepareTrustedCIDRs()
assert.Error(t, err) assert.Error(t, err)
} }
@ -599,19 +588,15 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {
// valid ipv6 cidr // valid ipv6 cidr
{ {
expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")} expectedTrustedCIDRs := []*net.IPNet{parseCIDR("::/0")}
r.TrustedProxies = []string{"::/0"} err := r.SetTrustedProxies([]string{"::/0"})
trustedCIDRs, err := r.prepareTrustedCIDRs()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
} }
// invalid ipv6 cidr // invalid ipv6 cidr
{ {
r.TrustedProxies = []string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"} err := r.SetTrustedProxies([]string{"gggg:0000:0000:1234:abcd:ffff:c0a8:0101/129"})
_, err := r.prepareTrustedCIDRs()
assert.Error(t, err) assert.Error(t, err)
} }
@ -623,36 +608,32 @@ func TestPrepareTrustedCIRDsWith(t *testing.T) {
parseCIDR("192.168.0.0/16"), parseCIDR("192.168.0.0/16"),
parseCIDR("172.16.0.1/32"), parseCIDR("172.16.0.1/32"),
} }
r.TrustedProxies = []string{ err := r.SetTrustedProxies([]string{
"::/0", "::/0",
"192.168.0.0/16", "192.168.0.0/16",
"172.16.0.1", "172.16.0.1",
} })
trustedCIDRs, err := r.prepareTrustedCIDRs()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedTrustedCIDRs, trustedCIDRs) assert.Equal(t, expectedTrustedCIDRs, r.trustedCIDRs)
} }
// invalid combination // invalid combination
{ {
r.TrustedProxies = []string{ err := r.SetTrustedProxies([]string{
"::/0", "::/0",
"192.168.0.0/16", "192.168.0.0/16",
"172.16.0.256", "172.16.0.256",
} })
_, err := r.prepareTrustedCIDRs()
assert.Error(t, err) assert.Error(t, err)
} }
// nil value // nil value
{ {
r.TrustedProxies = nil err := r.SetTrustedProxies(nil)
trustedCIDRs, err := r.prepareTrustedCIDRs()
assert.Nil(t, trustedCIDRs) assert.Nil(t, r.trustedCIDRs)
assert.Nil(t, err) assert.Nil(t, err)
} }