diff --git a/context.go b/context.go index 39f09135..a092565d 100644 --- a/context.go +++ b/context.go @@ -7,7 +7,6 @@ package gin import ( "errors" "log" - "net" "math" "net/http" "strings" @@ -135,109 +134,44 @@ func (c *Context) Set(key string, item interface{}) { } // Get returns the value for the given key or an error if the key does not exist. -func (c *Context) Get(key string) (interface{}, error) { +func (c *Context) Get(key string) (value interface{}, ok bool) { if c.Keys != nil { - value, ok := c.Keys[key] - if ok { - return value, nil - } + value, ok = c.Keys[key] } - return nil, errors.New("Key %s does not exist") + return } // MustGet returns the value for the given key or panics if the value doesn't exist. func (c *Context) MustGet(key string) interface{} { - value, err := c.Get(key) - if err != nil { - log.Panic(err.Error()) + if value, exists := c.Get(key); exists { + return value + } else { + log.Panicf("Key %s does not exist", key) } - return value -} - -func ipInMasks(ip net.IP, masks []interface{}) bool { - for _, proxy := range masks { - var mask *net.IPNet - var err error - - switch t := proxy.(type) { - case string: - if _, mask, err = net.ParseCIDR(t); err != nil { - log.Panic(err) - } - case net.IP: - mask = &net.IPNet{IP: t, Mask: net.CIDRMask(len(t)*8, len(t)*8)} - case net.IPNet: - mask = &t - } - - if mask.Contains(ip) { - return true - } - } - - return false -} - -// the ForwardedFor middleware unwraps the X-Forwarded-For headers, be careful to only use this -// middleware if you've got servers in front of this server. The list with (known) proxies and -// local ips are being filtered out of the forwarded for list, giving the last not local ip being -// the real client ip. -func ForwardedFor(proxies ...interface{}) HandlerFunc { - if len(proxies) == 0 { - // default to local ips - var reservedLocalIps = []string{"10.0.0.0/8", "127.0.0.1/32", "172.16.0.0/12", "192.168.0.0/16"} - - proxies = make([]interface{}, len(reservedLocalIps)) - - for i, v := range reservedLocalIps { - proxies[i] = v - } - } - - return func(c *Context) { - // the X-Forwarded-For header contains an array with left most the client ip, then - // comma separated, all proxies the request passed. The last proxy appears - // as the remote address of the request. Returning the client - // ip to comply with default RemoteAddr response. - - // check if remoteaddr is local ip or in list of defined proxies - remoteIp := net.ParseIP(strings.Split(c.Request.RemoteAddr, ":")[0]) - - if !ipInMasks(remoteIp, proxies) { - return - } - - if forwardedFor := c.Request.Header.Get("X-Forwarded-For"); forwardedFor != "" { - parts := strings.Split(forwardedFor, ",") - - for i := len(parts) - 1; i >= 0; i-- { - part := parts[i] - - ip := net.ParseIP(strings.TrimSpace(part)) - - if ipInMasks(ip, proxies) { - continue - } - - // returning remote addr conform the original remote addr format - c.Request.RemoteAddr = ip.String() + ":0" - - // remove forwarded for address - c.Request.Header.Set("X-Forwarded-For", "") - return - } - } - } -} - -func (c *Context) ClientIP() string { - return c.Request.RemoteAddr + return nil } /************************************/ /********* PARSING REQUEST **********/ /************************************/ +func (c *Context) ClientIP() string { + clientIP := c.Request.Header.Get("X-Real-IP") + if len(clientIP) > 0 { + return clientIP + } + clientIP = c.Request.Header.Get("X-Forwarded-For") + clientIP = strings.Split(clientIP, ",")[0] + if len(clientIP) > 0 { + return clientIP + } + return c.Request.RemoteAddr +} + +func (c *Context) ContentType() string { + return filterFlags(c.Request.Header.Get("Content-Type")) +} + // This function checks the Content-Type to select a binding engine automatically, // Depending the "Content-Type" header different bindings are used: // "application/json" --> JSON binding diff --git a/deprecated.go b/deprecated.go index 2f53c08d..a1a10244 100644 --- a/deprecated.go +++ b/deprecated.go @@ -5,7 +5,10 @@ package gin import ( + "log" + "net" "net/http" + "strings" "github.com/gin-gonic/gin/binding" ) @@ -46,3 +49,79 @@ func (engine *Engine) LoadHTMLTemplates(pattern string) { func (engine *Engine) NotFound404(handlers ...HandlerFunc) { engine.NoRoute(handlers...) } + +// the ForwardedFor middleware unwraps the X-Forwarded-For headers, be careful to only use this +// middleware if you've got servers in front of this server. The list with (known) proxies and +// local ips are being filtered out of the forwarded for list, giving the last not local ip being +// the real client ip. +func ForwardedFor(proxies ...interface{}) HandlerFunc { + if len(proxies) == 0 { + // default to local ips + var reservedLocalIps = []string{"10.0.0.0/8", "127.0.0.1/32", "172.16.0.0/12", "192.168.0.0/16"} + + proxies = make([]interface{}, len(reservedLocalIps)) + + for i, v := range reservedLocalIps { + proxies[i] = v + } + } + + return func(c *Context) { + // the X-Forwarded-For header contains an array with left most the client ip, then + // comma separated, all proxies the request passed. The last proxy appears + // as the remote address of the request. Returning the client + // ip to comply with default RemoteAddr response. + + // check if remoteaddr is local ip or in list of defined proxies + remoteIp := net.ParseIP(strings.Split(c.Request.RemoteAddr, ":")[0]) + + if !ipInMasks(remoteIp, proxies) { + return + } + + if forwardedFor := c.Request.Header.Get("X-Forwarded-For"); forwardedFor != "" { + parts := strings.Split(forwardedFor, ",") + + for i := len(parts) - 1; i >= 0; i-- { + part := parts[i] + + ip := net.ParseIP(strings.TrimSpace(part)) + + if ipInMasks(ip, proxies) { + continue + } + + // returning remote addr conform the original remote addr format + c.Request.RemoteAddr = ip.String() + ":0" + + // remove forwarded for address + c.Request.Header.Set("X-Forwarded-For", "") + return + } + } + } +} + +func ipInMasks(ip net.IP, masks []interface{}) bool { + for _, proxy := range masks { + var mask *net.IPNet + var err error + + switch t := proxy.(type) { + case string: + if _, mask, err = net.ParseCIDR(t); err != nil { + log.Panic(err) + } + case net.IP: + mask = &net.IPNet{IP: t, Mask: net.CIDRMask(len(t)*8, len(t)*8)} + case net.IPNet: + mask = &t + } + + if mask.Contains(ip) { + return true + } + } + + return false +}