diff --git a/context_test.go b/context_test.go index 993c632f..aaa358e9 100644 --- a/context_test.go +++ b/context_test.go @@ -1392,14 +1392,10 @@ func TestContextAbortWithError(t *testing.T) { assert.True(t, c.IsAborted()) } -func resetTrustedCIDRs(c *Context) { - c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs() -} - func TestContextClientIP(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) c.Request, _ = http.NewRequest("POST", "/", nil) - resetTrustedCIDRs(c) + c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs() resetContextForClientIPTests(c) // Legacy tests (validating that the defaults don't break the @@ -1428,57 +1424,47 @@ func TestContextClientIP(t *testing.T) { resetContextForClientIPTests(c) // No trusted proxies - c.engine.TrustedProxies = []string{} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{}) c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"} assert.Equal(t, "40.40.40.40", c.ClientIP()) // Last proxy is trusted, but the RemoteAddr is not - c.engine.TrustedProxies = []string{"30.30.30.30"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"30.30.30.30"}) assert.Equal(t, "40.40.40.40", c.ClientIP()) // Only trust RemoteAddr - c.engine.TrustedProxies = []string{"40.40.40.40"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"40.40.40.40"}) assert.Equal(t, "20.20.20.20", c.ClientIP()) // All steps are trusted - c.engine.TrustedProxies = []string{"40.40.40.40", "30.30.30.30", "20.20.20.20"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"40.40.40.40", "30.30.30.30", "20.20.20.20"}) assert.Equal(t, "20.20.20.20", c.ClientIP()) // Use CIDR - c.engine.TrustedProxies = []string{"40.40.25.25/16", "30.30.30.30"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"40.40.25.25/16", "30.30.30.30"}) assert.Equal(t, "20.20.20.20", c.ClientIP()) // Use hostname that resolves to all the proxies - c.engine.TrustedProxies = []string{"foo"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"foo"}) assert.Equal(t, "40.40.40.40", c.ClientIP()) // Use hostname that returns an error - c.engine.TrustedProxies = []string{"bar"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"bar"}) assert.Equal(t, "40.40.40.40", c.ClientIP()) // X-Forwarded-For has a non-IP element - c.engine.TrustedProxies = []string{"40.40.40.40"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"40.40.40.40"}) c.Request.Header.Set("X-Forwarded-For", " blah ") assert.Equal(t, "40.40.40.40", c.ClientIP()) // Result from LookupHost has non-IP element. This should never // happen, but we should test it to make sure we handle it // gracefully. - c.engine.TrustedProxies = []string{"baz"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"baz"}) c.Request.Header.Set("X-Forwarded-For", " 30.30.30.30 ") assert.Equal(t, "40.40.40.40", c.ClientIP()) - c.engine.TrustedProxies = []string{"40.40.40.40"} - resetTrustedCIDRs(c) + _ = c.engine.SetTrustedProxies([]string{"40.40.40.40"}) c.Request.Header.Del("X-Forwarded-For") c.engine.RemoteIPHeaders = []string{"X-Forwarded-For", "X-Real-IP"} assert.Equal(t, "10.10.10.10", c.ClientIP()) diff --git a/gin.go b/gin.go index 03a0e127..56e5c768 100644 --- a/gin.go +++ b/gin.go @@ -326,11 +326,11 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo { func (engine *Engine) Run(addr ...string) (err error) { defer func() { debugPrintError(err) }() - trustedCIDRs, err := engine.prepareTrustedCIDRs() + err = engine.parseTrustedProxies() if err != nil { return err } - engine.trustedCIDRs = trustedCIDRs + address := resolveAddress(addr) debugPrint("Listening and serving HTTP on %s\n", address) err = http.ListenAndServe(address, engine) @@ -366,6 +366,19 @@ func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) { return cidr, nil } +// SetTrustedProxies set Engine.TrustedProxies +func (engine *Engine) SetTrustedProxies(trustedProxies []string) error { + engine.TrustedProxies = trustedProxies + return engine.parseTrustedProxies() +} + +// parseTrustedProxies parse Engine.TrustedProxies to Engine.trustedCIDRs +func (engine *Engine) parseTrustedProxies() error { + trustedCIDRs, err := engine.prepareTrustedCIDRs() + engine.trustedCIDRs = trustedCIDRs + return err +} + // parseIP parse a string representation of an IP and returns a net.IP with the // minimum byte representation or nil if input is invalid. func parseIP(ip string) net.IP { @@ -387,6 +400,11 @@ func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) { debugPrint("Listening and serving HTTPS on %s\n", addr) defer func() { debugPrintError(err) }() + err = engine.parseTrustedProxies() + if err != nil { + return err + } + err = http.ListenAndServeTLS(addr, certFile, keyFile, engine) return } @@ -398,6 +416,11 @@ func (engine *Engine) RunUnix(file string) (err error) { debugPrint("Listening and serving HTTP on unix:/%s", file) defer func() { debugPrintError(err) }() + err = engine.parseTrustedProxies() + if err != nil { + return err + } + listener, err := net.Listen("unix", file) if err != nil { return @@ -416,6 +439,11 @@ func (engine *Engine) RunFd(fd int) (err error) { debugPrint("Listening and serving HTTP on fd@%d", fd) defer func() { debugPrintError(err) }() + err = engine.parseTrustedProxies() + if err != nil { + return err + } + f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd)) listener, err := net.FileListener(f) if err != nil { @@ -431,6 +459,12 @@ func (engine *Engine) RunFd(fd int) (err error) { func (engine *Engine) RunListener(listener net.Listener) (err error) { debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr()) defer func() { debugPrintError(err) }() + + err = engine.parseTrustedProxies() + if err != nil { + return err + } + err = http.Serve(listener, engine) return } diff --git a/gin_integration_test.go b/gin_integration_test.go index fd972657..2eb2d52b 100644 --- a/gin_integration_test.go +++ b/gin_integration_test.go @@ -55,13 +55,74 @@ func TestRunEmpty(t *testing.T) { testRequest(t, "http://localhost:8080/example") } -func TestTrustedCIDRsForRun(t *testing.T) { +func TestBadTrustedCIDRsForRun(t *testing.T) { os.Setenv("PORT", "") router := New() router.TrustedProxies = []string{"hello/world"} assert.Error(t, router.Run(":8080")) } +func TestBadTrustedCIDRsForRunUnix(t *testing.T) { + router := New() + router.TrustedProxies = []string{"hello/world"} + + unixTestSocket := filepath.Join(os.TempDir(), "unix_unit_test") + + defer os.Remove(unixTestSocket) + + go func() { + router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") }) + assert.Error(t, router.RunUnix(unixTestSocket)) + }() + // have to wait for the goroutine to start and run the server + // otherwise the main thread will complete + time.Sleep(5 * time.Millisecond) +} + +func TestBadTrustedCIDRsForRunFd(t *testing.T) { + router := New() + router.TrustedProxies = []string{"hello/world"} + + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + assert.NoError(t, err) + listener, err := net.ListenTCP("tcp", addr) + assert.NoError(t, err) + socketFile, err := listener.File() + assert.NoError(t, err) + + go func() { + router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") }) + assert.Error(t, router.RunFd(int(socketFile.Fd()))) + }() + // have to wait for the goroutine to start and run the server + // otherwise the main thread will complete + time.Sleep(5 * time.Millisecond) +} + +func TestBadTrustedCIDRsForRunListener(t *testing.T) { + router := New() + router.TrustedProxies = []string{"hello/world"} + + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + assert.NoError(t, err) + listener, err := net.ListenTCP("tcp", addr) + assert.NoError(t, err) + go func() { + router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") }) + assert.Error(t, router.RunListener(listener)) + }() + // have to wait for the goroutine to start and run the server + // otherwise the main thread will complete + time.Sleep(5 * time.Millisecond) +} + +func TestBadTrustedCIDRsForRunTLS(t *testing.T) { + os.Setenv("PORT", "") + router := New() + router.TrustedProxies = []string{"hello/world"} + assert.Error(t, router.RunTLS(":8080", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem")) +} + func TestRunTLS(t *testing.T) { router := New() go func() {