set engine.TrustedProxies For items that don't use gin.RUN (#2692)

Co-authored-by: Bo-Yi Wu <appleboy.tw@gmail.com>
This commit is contained in:
yiranzai 2021-05-25 13:47:35 +08:00 committed by GitHub
parent 328d0b8076
commit b5ca989875
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 109 additions and 28 deletions

View File

@ -1392,14 +1392,10 @@ func TestContextAbortWithError(t *testing.T) {
assert.True(t, c.IsAborted()) assert.True(t, c.IsAborted())
} }
func resetTrustedCIDRs(c *Context) {
c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
}
func TestContextClientIP(t *testing.T) { func TestContextClientIP(t *testing.T) {
c, _ := CreateTestContext(httptest.NewRecorder()) c, _ := CreateTestContext(httptest.NewRecorder())
c.Request, _ = http.NewRequest("POST", "/", nil) c.Request, _ = http.NewRequest("POST", "/", nil)
resetTrustedCIDRs(c) c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs()
resetContextForClientIPTests(c) resetContextForClientIPTests(c)
// Legacy tests (validating that the defaults don't break the // Legacy tests (validating that the defaults don't break the
@ -1428,57 +1424,47 @@ func TestContextClientIP(t *testing.T) {
resetContextForClientIPTests(c) resetContextForClientIPTests(c)
// No trusted proxies // No trusted proxies
c.engine.TrustedProxies = []string{} _ = c.engine.SetTrustedProxies([]string{})
resetTrustedCIDRs(c)
c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"} c.engine.RemoteIPHeaders = []string{"X-Forwarded-For"}
assert.Equal(t, "40.40.40.40", c.ClientIP()) assert.Equal(t, "40.40.40.40", c.ClientIP())
// Last proxy is trusted, but the RemoteAddr is not // Last proxy is trusted, but the RemoteAddr is not
c.engine.TrustedProxies = []string{"30.30.30.30"} _ = c.engine.SetTrustedProxies([]string{"30.30.30.30"})
resetTrustedCIDRs(c)
assert.Equal(t, "40.40.40.40", c.ClientIP()) assert.Equal(t, "40.40.40.40", c.ClientIP())
// Only trust RemoteAddr // Only trust RemoteAddr
c.engine.TrustedProxies = []string{"40.40.40.40"} _ = c.engine.SetTrustedProxies([]string{"40.40.40.40"})
resetTrustedCIDRs(c)
assert.Equal(t, "20.20.20.20", c.ClientIP()) assert.Equal(t, "20.20.20.20", c.ClientIP())
// All steps are trusted // All steps are trusted
c.engine.TrustedProxies = []string{"40.40.40.40", "30.30.30.30", "20.20.20.20"} _ = c.engine.SetTrustedProxies([]string{"40.40.40.40", "30.30.30.30", "20.20.20.20"})
resetTrustedCIDRs(c)
assert.Equal(t, "20.20.20.20", c.ClientIP()) assert.Equal(t, "20.20.20.20", c.ClientIP())
// Use CIDR // Use CIDR
c.engine.TrustedProxies = []string{"40.40.25.25/16", "30.30.30.30"} _ = c.engine.SetTrustedProxies([]string{"40.40.25.25/16", "30.30.30.30"})
resetTrustedCIDRs(c)
assert.Equal(t, "20.20.20.20", c.ClientIP()) assert.Equal(t, "20.20.20.20", c.ClientIP())
// Use hostname that resolves to all the proxies // Use hostname that resolves to all the proxies
c.engine.TrustedProxies = []string{"foo"} _ = c.engine.SetTrustedProxies([]string{"foo"})
resetTrustedCIDRs(c)
assert.Equal(t, "40.40.40.40", c.ClientIP()) assert.Equal(t, "40.40.40.40", c.ClientIP())
// Use hostname that returns an error // Use hostname that returns an error
c.engine.TrustedProxies = []string{"bar"} _ = c.engine.SetTrustedProxies([]string{"bar"})
resetTrustedCIDRs(c)
assert.Equal(t, "40.40.40.40", c.ClientIP()) assert.Equal(t, "40.40.40.40", c.ClientIP())
// X-Forwarded-For has a non-IP element // X-Forwarded-For has a non-IP element
c.engine.TrustedProxies = []string{"40.40.40.40"} _ = c.engine.SetTrustedProxies([]string{"40.40.40.40"})
resetTrustedCIDRs(c)
c.Request.Header.Set("X-Forwarded-For", " blah ") c.Request.Header.Set("X-Forwarded-For", " blah ")
assert.Equal(t, "40.40.40.40", c.ClientIP()) assert.Equal(t, "40.40.40.40", c.ClientIP())
// Result from LookupHost has non-IP element. This should never // Result from LookupHost has non-IP element. This should never
// happen, but we should test it to make sure we handle it // happen, but we should test it to make sure we handle it
// gracefully. // gracefully.
c.engine.TrustedProxies = []string{"baz"} _ = c.engine.SetTrustedProxies([]string{"baz"})
resetTrustedCIDRs(c)
c.Request.Header.Set("X-Forwarded-For", " 30.30.30.30 ") c.Request.Header.Set("X-Forwarded-For", " 30.30.30.30 ")
assert.Equal(t, "40.40.40.40", c.ClientIP()) assert.Equal(t, "40.40.40.40", c.ClientIP())
c.engine.TrustedProxies = []string{"40.40.40.40"} _ = c.engine.SetTrustedProxies([]string{"40.40.40.40"})
resetTrustedCIDRs(c)
c.Request.Header.Del("X-Forwarded-For") c.Request.Header.Del("X-Forwarded-For")
c.engine.RemoteIPHeaders = []string{"X-Forwarded-For", "X-Real-IP"} c.engine.RemoteIPHeaders = []string{"X-Forwarded-For", "X-Real-IP"}
assert.Equal(t, "10.10.10.10", c.ClientIP()) assert.Equal(t, "10.10.10.10", c.ClientIP())

38
gin.go
View File

@ -326,11 +326,11 @@ func iterate(path, method string, routes RoutesInfo, root *node) RoutesInfo {
func (engine *Engine) Run(addr ...string) (err error) { func (engine *Engine) Run(addr ...string) (err error) {
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
trustedCIDRs, err := engine.prepareTrustedCIDRs() err = engine.parseTrustedProxies()
if err != nil { if err != nil {
return err return err
} }
engine.trustedCIDRs = trustedCIDRs
address := resolveAddress(addr) address := resolveAddress(addr)
debugPrint("Listening and serving HTTP on %s\n", address) debugPrint("Listening and serving HTTP on %s\n", address)
err = http.ListenAndServe(address, engine) err = http.ListenAndServe(address, engine)
@ -366,6 +366,19 @@ func (engine *Engine) prepareTrustedCIDRs() ([]*net.IPNet, error) {
return cidr, nil return cidr, nil
} }
// SetTrustedProxies set Engine.TrustedProxies
func (engine *Engine) SetTrustedProxies(trustedProxies []string) error {
engine.TrustedProxies = trustedProxies
return engine.parseTrustedProxies()
}
// parseTrustedProxies parse Engine.TrustedProxies to Engine.trustedCIDRs
func (engine *Engine) parseTrustedProxies() error {
trustedCIDRs, err := engine.prepareTrustedCIDRs()
engine.trustedCIDRs = trustedCIDRs
return err
}
// parseIP parse a string representation of an IP and returns a net.IP with the // parseIP parse a string representation of an IP and returns a net.IP with the
// minimum byte representation or nil if input is invalid. // minimum byte representation or nil if input is invalid.
func parseIP(ip string) net.IP { func parseIP(ip string) net.IP {
@ -387,6 +400,11 @@ func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) {
debugPrint("Listening and serving HTTPS on %s\n", addr) debugPrint("Listening and serving HTTPS on %s\n", addr)
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies()
if err != nil {
return err
}
err = http.ListenAndServeTLS(addr, certFile, keyFile, engine) err = http.ListenAndServeTLS(addr, certFile, keyFile, engine)
return return
} }
@ -398,6 +416,11 @@ func (engine *Engine) RunUnix(file string) (err error) {
debugPrint("Listening and serving HTTP on unix:/%s", file) debugPrint("Listening and serving HTTP on unix:/%s", file)
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies()
if err != nil {
return err
}
listener, err := net.Listen("unix", file) listener, err := net.Listen("unix", file)
if err != nil { if err != nil {
return return
@ -416,6 +439,11 @@ func (engine *Engine) RunFd(fd int) (err error) {
debugPrint("Listening and serving HTTP on fd@%d", fd) debugPrint("Listening and serving HTTP on fd@%d", fd)
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies()
if err != nil {
return err
}
f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd)) f := os.NewFile(uintptr(fd), fmt.Sprintf("fd@%d", fd))
listener, err := net.FileListener(f) listener, err := net.FileListener(f)
if err != nil { if err != nil {
@ -431,6 +459,12 @@ func (engine *Engine) RunFd(fd int) (err error) {
func (engine *Engine) RunListener(listener net.Listener) (err error) { func (engine *Engine) RunListener(listener net.Listener) (err error) {
debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr()) debugPrint("Listening and serving HTTP on listener what's bind with address@%s", listener.Addr())
defer func() { debugPrintError(err) }() defer func() { debugPrintError(err) }()
err = engine.parseTrustedProxies()
if err != nil {
return err
}
err = http.Serve(listener, engine) err = http.Serve(listener, engine)
return return
} }

View File

@ -55,13 +55,74 @@ func TestRunEmpty(t *testing.T) {
testRequest(t, "http://localhost:8080/example") testRequest(t, "http://localhost:8080/example")
} }
func TestTrustedCIDRsForRun(t *testing.T) { func TestBadTrustedCIDRsForRun(t *testing.T) {
os.Setenv("PORT", "") os.Setenv("PORT", "")
router := New() router := New()
router.TrustedProxies = []string{"hello/world"} router.TrustedProxies = []string{"hello/world"}
assert.Error(t, router.Run(":8080")) assert.Error(t, router.Run(":8080"))
} }
func TestBadTrustedCIDRsForRunUnix(t *testing.T) {
router := New()
router.TrustedProxies = []string{"hello/world"}
unixTestSocket := filepath.Join(os.TempDir(), "unix_unit_test")
defer os.Remove(unixTestSocket)
go func() {
router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") })
assert.Error(t, router.RunUnix(unixTestSocket))
}()
// have to wait for the goroutine to start and run the server
// otherwise the main thread will complete
time.Sleep(5 * time.Millisecond)
}
func TestBadTrustedCIDRsForRunFd(t *testing.T) {
router := New()
router.TrustedProxies = []string{"hello/world"}
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
assert.NoError(t, err)
listener, err := net.ListenTCP("tcp", addr)
assert.NoError(t, err)
socketFile, err := listener.File()
assert.NoError(t, err)
go func() {
router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") })
assert.Error(t, router.RunFd(int(socketFile.Fd())))
}()
// have to wait for the goroutine to start and run the server
// otherwise the main thread will complete
time.Sleep(5 * time.Millisecond)
}
func TestBadTrustedCIDRsForRunListener(t *testing.T) {
router := New()
router.TrustedProxies = []string{"hello/world"}
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
assert.NoError(t, err)
listener, err := net.ListenTCP("tcp", addr)
assert.NoError(t, err)
go func() {
router.GET("/example", func(c *Context) { c.String(http.StatusOK, "it worked") })
assert.Error(t, router.RunListener(listener))
}()
// have to wait for the goroutine to start and run the server
// otherwise the main thread will complete
time.Sleep(5 * time.Millisecond)
}
func TestBadTrustedCIDRsForRunTLS(t *testing.T) {
os.Setenv("PORT", "")
router := New()
router.TrustedProxies = []string{"hello/world"}
assert.Error(t, router.RunTLS(":8080", "./testdata/certificate/cert.pem", "./testdata/certificate/key.pem"))
}
func TestRunTLS(t *testing.T) { func TestRunTLS(t *testing.T) {
router := New() router := New()
go func() { go func() {