From 14247a51e5808fbaacd64c52a80297fc14067a04 Mon Sep 17 00:00:00 2001 From: Javier Fabrizio Date: Mon, 10 Jul 2023 16:56:00 -0300 Subject: [PATCH] fix Request.Context() checks --- context.go | 16 +++++++++------- context_test.go | 32 +++++++++++++++++++++++++------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/context.go b/context.go index 420ff167..9286d5d4 100644 --- a/context.go +++ b/context.go @@ -1181,15 +1181,17 @@ func (c *Context) SetAccepted(formats ...string) { /************************************/ // hasRequestContext returns whether c.Request has Context and fallback. +func (c *Context) hasFallback() bool { + return c.engine != nil && c.engine.ContextWithFallback +} + func (c *Context) hasRequestContext() bool { - hasFallback := c.engine != nil && c.engine.ContextWithFallback - hasRequestContext := c.Request != nil && c.Request.Context() != nil - return hasFallback && hasRequestContext + return c.Request != nil && c.Request.Context() != nil } // Deadline returns that there is no deadline (ok==false) when c.Request has no Context. func (c *Context) Deadline() (deadline time.Time, ok bool) { - if !c.hasRequestContext() { + if !c.hasFallback() || !c.hasRequestContext() { return } return c.Request.Context().Deadline() @@ -1197,7 +1199,7 @@ func (c *Context) Deadline() (deadline time.Time, ok bool) { // Done returns nil (chan which will wait forever) when c.Request has no Context. func (c *Context) Done() <-chan struct{} { - if !c.hasRequestContext() { + if !c.hasFallback() || !c.hasRequestContext() { return nil } return c.Request.Context().Done() @@ -1205,7 +1207,7 @@ func (c *Context) Done() <-chan struct{} { // Err returns nil when c.Request has no Context. func (c *Context) Err() error { - if !c.hasRequestContext() { + if !c.hasFallback() || !c.hasRequestContext() { return nil } return c.Request.Context().Err() @@ -1226,7 +1228,7 @@ func (c *Context) Value(key any) any { return val } } - if !c.hasRequestContext() { + if !c.hasFallback() || !c.hasRequestContext() { return nil } return c.Request.Context().Value(key) diff --git a/context_test.go b/context_test.go index 70d47583..2c90bd9e 100644 --- a/context_test.go +++ b/context_test.go @@ -2178,20 +2178,38 @@ func TestRemoteIPFail(t *testing.T) { func TestHasRequestContext(t *testing.T) { c, _ := CreateTestContext(httptest.NewRecorder()) - assert.False(t, c.hasRequestContext(), "no request, no fallback") + assert.False(t, c.hasRequestContext()) c.engine.ContextWithFallback = true - assert.False(t, c.hasRequestContext(), "no request, has fallback") + assert.False(t, c.hasRequestContext()) c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) - assert.True(t, c.hasRequestContext(), "has request, has fallback") + assert.True(t, c.hasRequestContext()) c.Request, _ = http.NewRequestWithContext(nil, "", "", nil) //nolint:staticcheck - assert.False(t, c.hasRequestContext(), "has request with nil ctx, has fallback") + assert.False(t, c.hasRequestContext()) c.engine.ContextWithFallback = false - assert.False(t, c.hasRequestContext(), "has request, no fallback") + assert.False(t, c.hasRequestContext()) c = &Context{} - assert.False(t, c.hasRequestContext(), "no request, no engine") + assert.False(t, c.hasRequestContext()) c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) - assert.False(t, c.hasRequestContext(), "has request, no engine") + assert.True(t, c.hasRequestContext()) +} + +func TestHasFallback(t *testing.T) { + c, _ := CreateTestContext(httptest.NewRecorder()) + assert.False(t, c.hasFallback()) + c.engine.ContextWithFallback = true + assert.True(t, c.hasFallback()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + assert.True(t, c.hasFallback()) + c.Request, _ = http.NewRequestWithContext(nil, "", "", nil) //nolint:staticcheck + assert.True(t, c.hasFallback()) + c.engine.ContextWithFallback = false + assert.False(t, c.hasFallback()) + + c = &Context{} + assert.False(t, c.hasFallback()) + c.Request, _ = http.NewRequest(http.MethodGet, "/", nil) + assert.False(t, c.hasFallback()) } func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {