From 70325deb98d3f65551e83147ff11d8c32989a887 Mon Sep 17 00:00:00 2001 From: Manu Mtz-Almeida Date: Thu, 4 Jun 2015 13:15:22 +0200 Subject: [PATCH] c.ClientIP() performance improvement benchmark old ns/op new ns/op delta BenchmarkManyHandlers 4956 4463 -9.95% benchmark old allocs new allocs delta BenchmarkManyHandlers 16 13 -18.75% benchmark old bytes new bytes delta BenchmarkManyHandlers 256 216 -15.62% --- context.go | 20 +++++++++++++++----- context_test.go | 4 ++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/context.go b/context.go index 73f45581..d180dd87 100644 --- a/context.go +++ b/context.go @@ -251,12 +251,15 @@ func (c *Context) BindWith(obj interface{}, b binding.Binding) error { // Best effort algoritm to return the real client IP, it parses // X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. func (c *Context) ClientIP() string { - clientIP := strings.TrimSpace(c.Request.Header.Get("X-Real-IP")) + clientIP := strings.TrimSpace(c.requestHeader("X-Real-Ip")) if len(clientIP) > 0 { return clientIP } - clientIP = c.Request.Header.Get("X-Forwarded-For") - clientIP = strings.TrimSpace(strings.Split(clientIP, ",")[0]) + clientIP = c.requestHeader("X-Forwarded-For") + if index := strings.IndexByte(clientIP, ','); index >= 0 { + clientIP = clientIP[0:index] + } + clientIP = strings.TrimSpace(clientIP) if len(clientIP) > 0 { return clientIP } @@ -264,7 +267,14 @@ func (c *Context) ClientIP() string { } func (c *Context) ContentType() string { - return filterFlags(c.Request.Header.Get("Content-Type")) + return filterFlags(c.requestHeader("Content-Type")) +} + +func (c *Context) requestHeader(key string) string { + if values, _ := c.Request.Header[key]; len(values) > 0 { + return values[0] + } + return "" } /************************************/ @@ -414,7 +424,7 @@ func (c *Context) NegotiateFormat(offered ...string) string { panic("you must provide at least one offer") } if c.Accepted == nil { - c.Accepted = parseAccept(c.Request.Header.Get("Accept")) + c.Accepted = parseAccept(c.requestHeader("Accept")) } if len(c.Accepted) == 0 { return offered[0] diff --git a/context_test.go b/context_test.go index cc1c4527..f6083070 100644 --- a/context_test.go +++ b/context_test.go @@ -454,7 +454,7 @@ func TestContextClientIP(t *testing.T) { c.Request, _ = http.NewRequest("POST", "/", nil) c.Request.Header.Set("X-Real-IP", " 10.10.10.10 ") - c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20 , 30.30.30.30") + c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20, 30.30.30.30") c.Request.RemoteAddr = " 40.40.40.40 " assert.Equal(t, c.ClientIP(), "10.10.10.10") @@ -462,7 +462,7 @@ func TestContextClientIP(t *testing.T) { c.Request.Header.Del("X-Real-IP") assert.Equal(t, c.ClientIP(), "20.20.20.20") - c.Request.Header.Set("X-Forwarded-For", "30.30.30.30") + c.Request.Header.Set("X-Forwarded-For", "30.30.30.30 ") assert.Equal(t, c.ClientIP(), "30.30.30.30") c.Request.Header.Del("X-Forwarded-For")