feat(context): add ContextWithFallback feature flag (#3166) (#3172)

Enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value()
This commit is contained in:
wei 2022-06-06 18:43:53 +08:00 committed by GitHub
parent 92ba8e17aa
commit f197a8bae0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 16 deletions

View File

@ -1158,7 +1158,7 @@ func (c *Context) SetAccepted(formats ...string) {
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context. // Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
func (c *Context) Deadline() (deadline time.Time, ok bool) { func (c *Context) Deadline() (deadline time.Time, ok bool) {
if c.Request == nil || c.Request.Context() == nil { if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return return
} }
return c.Request.Context().Deadline() return c.Request.Context().Deadline()
@ -1166,7 +1166,7 @@ func (c *Context) Deadline() (deadline time.Time, ok bool) {
// Done returns nil (chan which will wait forever) when c.Request has no Context. // Done returns nil (chan which will wait forever) when c.Request has no Context.
func (c *Context) Done() <-chan struct{} { func (c *Context) Done() <-chan struct{} {
if c.Request == nil || c.Request.Context() == nil { if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil return nil
} }
return c.Request.Context().Done() return c.Request.Context().Done()
@ -1174,7 +1174,7 @@ func (c *Context) Done() <-chan struct{} {
// Err returns nil when c.Request has no Context. // Err returns nil when c.Request has no Context.
func (c *Context) Err() error { func (c *Context) Err() error {
if c.Request == nil || c.Request.Context() == nil { if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil return nil
} }
return c.Request.Context().Err() return c.Request.Context().Err()
@ -1195,7 +1195,7 @@ func (c *Context) Value(key any) any {
return val return val
} }
} }
if c.Request == nil || c.Request.Context() == nil { if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
return nil return nil
} }
return c.Request.Context().Value(key) return c.Request.Context().Value(key)

View File

@ -2097,12 +2097,18 @@ func TestRemoteIPFail(t *testing.T) {
} }
func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) { func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
c := &Context{} c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
deadline, ok := c.Deadline() deadline, ok := c.Deadline()
assert.Zero(t, deadline) assert.Zero(t, deadline)
assert.False(t, ok) assert.False(t, ok)
c2 := &Context{} c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
d := time.Now().Add(time.Second) d := time.Now().Add(time.Second)
ctx, cancel := context.WithDeadline(context.Background(), d) ctx, cancel := context.WithDeadline(context.Background(), d)
@ -2114,10 +2120,16 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
} }
func TestContextWithFallbackDoneFromRequestContext(t *testing.T) { func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
c := &Context{} c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
assert.Nil(t, c.Done()) assert.Nil(t, c.Done())
c2 := &Context{} c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c2.Request = c2.Request.WithContext(ctx) c2.Request = c2.Request.WithContext(ctx)
@ -2126,10 +2138,16 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
} }
func TestContextWithFallbackErrFromRequestContext(t *testing.T) { func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
c := &Context{} c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
assert.Nil(t, c.Err()) assert.Nil(t, c.Err())
c2 := &Context{} c2, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c2.engine.ContextWithFallback = true
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil) c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c2.Request = c2.Request.WithContext(ctx) c2.Request = c2.Request.WithContext(ctx)
@ -2138,9 +2156,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
assert.EqualError(t, c2.Err(), context.Canceled.Error()) assert.EqualError(t, c2.Err(), context.Canceled.Error())
} }
type contextKey string
func TestContextWithFallbackValueFromRequestContext(t *testing.T) { func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
type contextKey string
tests := []struct { tests := []struct {
name string name string
getContextAndKey func() (*Context, any) getContextAndKey func() (*Context, any)
@ -2150,7 +2168,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
name: "c with struct context key", name: "c with struct context key",
getContextAndKey: func() (*Context, any) { getContextAndKey: func() (*Context, any) {
var key struct{} var key struct{}
c := &Context{} c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil) c.Request, _ = http.NewRequest("POST", "/", nil)
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value")) c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value"))
return c, key return c, key
@ -2160,7 +2180,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
{ {
name: "c with string context key", name: "c with string context key",
getContextAndKey: func() (*Context, any) { getContextAndKey: func() (*Context, any) {
c := &Context{} c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil) c.Request, _ = http.NewRequest("POST", "/", nil)
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value")) c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
return c, contextKey("key") return c, contextKey("key")
@ -2170,7 +2192,10 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
{ {
name: "c with nil http.Request", name: "c with nil http.Request",
getContextAndKey: func() (*Context, any) { getContextAndKey: func() (*Context, any) {
c := &Context{} c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request = nil
return c, "key" return c, "key"
}, },
value: nil, value: nil,
@ -2178,7 +2203,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
{ {
name: "c with nil http.Request.Context()", name: "c with nil http.Request.Context()",
getContextAndKey: func() (*Context, any) { getContextAndKey: func() (*Context, any) {
c := &Context{} c, _ := CreateTestContext(httptest.NewRecorder())
// enable ContextWithFallback feature flag
c.engine.ContextWithFallback = true
c.Request, _ = http.NewRequest("POST", "/", nil) c.Request, _ = http.NewRequest("POST", "/", nil)
return c, "key" return c, "key"
}, },
@ -2193,6 +2220,70 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
} }
} }
func TestContextCopyShouldNotCancel(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
ensureRequestIsOver := make(chan struct{})
wg := &sync.WaitGroup{}
r := New()
r.GET("/", func(ginctx *Context) {
wg.Add(1)
ginctx = ginctx.Copy()
// start async goroutine for calling srv
go func() {
defer wg.Done()
<-ensureRequestIsOver // ensure request is done
req, err := http.NewRequestWithContext(ginctx, http.MethodGet, srv.URL, nil)
must(err)
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Error(fmt.Errorf("request error: %w", err))
return
}
if res.StatusCode != http.StatusOK {
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
}
}()
})
l, err := net.Listen("tcp", ":0")
must(err)
go func() {
s := &http.Server{
Handler: r,
}
must(s.Serve(l))
}()
addr := strings.Split(l.Addr().String(), ":")
res, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/", addr[len(addr)-1]))
if err != nil {
t.Error(fmt.Errorf("request error: %w", err))
return
}
close(ensureRequestIsOver)
if res.StatusCode != http.StatusOK {
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
return
}
wg.Wait()
}
func TestContextAddParam(t *testing.T) { func TestContextAddParam(t *testing.T) {
c := &Context{} c := &Context{}
id := "id" id := "id"

3
gin.go
View File

@ -147,6 +147,9 @@ type Engine struct {
// UseH2C enable h2c support. // UseH2C enable h2c support.
UseH2C bool UseH2C bool
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
ContextWithFallback bool
delims render.Delims delims render.Delims
secureJSONPrefix string secureJSONPrefix string
HTMLRender render.HTMLRender HTMLRender render.HTMLRender