From bb1fc2e0fe97c63dab1527baab88d01183853b8f Mon Sep 17 00:00:00 2001 From: Bence Vidosits <38434845+bvidosits@users.noreply.github.com> Date: Mon, 29 May 2023 01:59:35 +0000 Subject: [PATCH] fix Request.Context() checks (#3512) Co-authored-by: Bence Vidosits --- context.go | 15 +++++++++++---- context_test.go | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/context.go b/context.go index cb360879..420ff167 100644 --- a/context.go +++ b/context.go @@ -1180,9 +1180,16 @@ func (c *Context) SetAccepted(formats ...string) { /***** GOLANG.ORG/X/NET/CONTEXT *****/ /************************************/ +// hasRequestContext returns whether c.Request has Context and fallback. +func (c *Context) hasRequestContext() bool { + hasFallback := c.engine != nil && c.engine.ContextWithFallback + hasRequestContext := c.Request != nil && c.Request.Context() != nil + return hasFallback && hasRequestContext +} + // Deadline returns that there is no deadline (ok==false) when c.Request has no Context. func (c *Context) Deadline() (deadline time.Time, ok bool) { - if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { + if !c.hasRequestContext() { return } return c.Request.Context().Deadline() @@ -1190,7 +1197,7 @@ func (c *Context) Deadline() (deadline time.Time, ok bool) { // Done returns nil (chan which will wait forever) when c.Request has no Context. func (c *Context) Done() <-chan struct{} { - if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { + if !c.hasRequestContext() { return nil } return c.Request.Context().Done() @@ -1198,7 +1205,7 @@ func (c *Context) Done() <-chan struct{} { // Err returns nil when c.Request has no Context. func (c *Context) Err() error { - if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { + if !c.hasRequestContext() { return nil } return c.Request.Context().Err() @@ -1219,7 +1226,7 @@ func (c *Context) Value(key any) any { return val } } - if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil { + if !c.hasRequestContext() { return nil } return c.Request.Context().Value(key) diff --git a/context_test.go b/context_test.go index 18051235..70d47583 100644 --- a/context_test.go +++ b/context_test.go @@ -2176,6 +2176,24 @@ func TestRemoteIPFail(t *testing.T) { assert.False(t, trust) } +func TestHasRequestContext(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + assert.False(t, c.hasRequestContext(), "no request, no fallback") + c.engine.ContextWithFallback = true + assert.False(t, c.hasRequestContext(), "no request, has fallback") + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + assert.True(t, c.hasRequestContext(), "has request, has fallback") + c.Request, _ = http.NewRequestWithContext(nil, "", "", nil) //nolint:staticcheck + assert.False(t, c.hasRequestContext(), "has request with nil ctx, has fallback") + c.engine.ContextWithFallback = false + assert.False(t, c.hasRequestContext(), "has request, no fallback") + + c = &Context{} + assert.False(t, c.hasRequestContext(), "no request, no engine") + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + assert.False(t, c.hasRequestContext(), "has request, no engine") +} + func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) // enable ContextWithFallback feature flag