diff --git a/context.go b/context.go index f9489a77..ab880b19 100644 --- a/context.go +++ b/context.go @@ -779,14 +779,27 @@ func (c *Context) ClientIP() string { } } - // It also checks if the remoteIP is a trusted proxy or not. - // In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks - // defined by Engine.SetTrustedProxies() - remoteIP := net.ParseIP(c.RemoteIP()) - if remoteIP == nil { - return "" + var ( + trusted bool + remoteIP net.IP + ) + // If gin is listening a unix socket, always trust it. + localAddr, ok := c.Request.Context().Value(http.LocalAddrContextKey).(net.Addr) + if ok && strings.HasPrefix(localAddr.Network(), "unix") { + trusted = true + } + + // Fallback + if !trusted { + // It also checks if the remoteIP is a trusted proxy or not. + // In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks + // defined by Engine.SetTrustedProxies() + remoteIP = net.ParseIP(c.RemoteIP()) + if remoteIP == nil { + return "" + } + trusted = c.engine.isTrustedProxy(remoteIP) } - trusted := c.engine.isTrustedProxy(remoteIP) if trusted && c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil { for _, headerName := range c.engine.RemoteIPHeaders { diff --git a/context_test.go b/context_test.go index b3e81c14..be247fac 100644 --- a/context_test.go +++ b/context_test.go @@ -1437,6 +1437,16 @@ func TestContextClientIP(t *testing.T) { c.engine.trustedCIDRs, _ = c.engine.prepareTrustedCIDRs() resetContextForClientIPTests(c) + // unix address + addr := &net.UnixAddr{Net: "unix", Name: "@"} + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), http.LocalAddrContextKey, addr)) + c.Request.RemoteAddr = addr.String() + assert.Equal(t, "20.20.20.20", c.ClientIP()) + + // reset + c.Request = c.Request.WithContext(context.Background()) + resetContextForClientIPTests(c) + // Legacy tests (validating that the defaults don't break the // (insecure!) old behaviour) assert.Equal(t, "20.20.20.20", c.ClientIP())