forked from mirror/gin
Enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value()
This commit is contained in:
parent
92ba8e17aa
commit
f197a8bae0
|
@ -1158,7 +1158,7 @@ func (c *Context) SetAccepted(formats ...string) {
|
|||
|
||||
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
|
||||
func (c *Context) Deadline() (deadline time.Time, ok bool) {
|
||||
if c.Request == nil || c.Request.Context() == nil {
|
||||
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
|
||||
return
|
||||
}
|
||||
return c.Request.Context().Deadline()
|
||||
|
@ -1166,7 +1166,7 @@ func (c *Context) Deadline() (deadline time.Time, ok bool) {
|
|||
|
||||
// Done returns nil (chan which will wait forever) when c.Request has no Context.
|
||||
func (c *Context) Done() <-chan struct{} {
|
||||
if c.Request == nil || c.Request.Context() == nil {
|
||||
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Request.Context().Done()
|
||||
|
@ -1174,7 +1174,7 @@ func (c *Context) Done() <-chan struct{} {
|
|||
|
||||
// Err returns nil when c.Request has no Context.
|
||||
func (c *Context) Err() error {
|
||||
if c.Request == nil || c.Request.Context() == nil {
|
||||
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Request.Context().Err()
|
||||
|
@ -1195,7 +1195,7 @@ func (c *Context) Value(key any) any {
|
|||
return val
|
||||
}
|
||||
}
|
||||
if c.Request == nil || c.Request.Context() == nil {
|
||||
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
|
||||
return nil
|
||||
}
|
||||
return c.Request.Context().Value(key)
|
||||
|
|
115
context_test.go
115
context_test.go
|
@ -2097,12 +2097,18 @@ func TestRemoteIPFail(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
|
||||
c := &Context{}
|
||||
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c.engine.ContextWithFallback = true
|
||||
|
||||
deadline, ok := c.Deadline()
|
||||
assert.Zero(t, deadline)
|
||||
assert.False(t, ok)
|
||||
|
||||
c2 := &Context{}
|
||||
c2, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c2.engine.ContextWithFallback = true
|
||||
|
||||
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
||||
d := time.Now().Add(time.Second)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||
|
@ -2114,10 +2120,16 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
|
||||
c := &Context{}
|
||||
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c.engine.ContextWithFallback = true
|
||||
|
||||
assert.Nil(t, c.Done())
|
||||
|
||||
c2 := &Context{}
|
||||
c2, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c2.engine.ContextWithFallback = true
|
||||
|
||||
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c2.Request = c2.Request.WithContext(ctx)
|
||||
|
@ -2126,10 +2138,16 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
|
||||
c := &Context{}
|
||||
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c.engine.ContextWithFallback = true
|
||||
|
||||
assert.Nil(t, c.Err())
|
||||
|
||||
c2 := &Context{}
|
||||
c2, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c2.engine.ContextWithFallback = true
|
||||
|
||||
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c2.Request = c2.Request.WithContext(ctx)
|
||||
|
@ -2138,9 +2156,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
|
|||
assert.EqualError(t, c2.Err(), context.Canceled.Error())
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
||||
type contextKey string
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
getContextAndKey func() (*Context, any)
|
||||
|
@ -2150,7 +2168,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
|||
name: "c with struct context key",
|
||||
getContextAndKey: func() (*Context, any) {
|
||||
var key struct{}
|
||||
c := &Context{}
|
||||
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c.engine.ContextWithFallback = true
|
||||
c.Request, _ = http.NewRequest("POST", "/", nil)
|
||||
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value"))
|
||||
return c, key
|
||||
|
@ -2160,7 +2180,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
|||
{
|
||||
name: "c with string context key",
|
||||
getContextAndKey: func() (*Context, any) {
|
||||
c := &Context{}
|
||||
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c.engine.ContextWithFallback = true
|
||||
c.Request, _ = http.NewRequest("POST", "/", nil)
|
||||
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
|
||||
return c, contextKey("key")
|
||||
|
@ -2170,7 +2192,10 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
|||
{
|
||||
name: "c with nil http.Request",
|
||||
getContextAndKey: func() (*Context, any) {
|
||||
c := &Context{}
|
||||
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c.engine.ContextWithFallback = true
|
||||
c.Request = nil
|
||||
return c, "key"
|
||||
},
|
||||
value: nil,
|
||||
|
@ -2178,7 +2203,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
|||
{
|
||||
name: "c with nil http.Request.Context()",
|
||||
getContextAndKey: func() (*Context, any) {
|
||||
c := &Context{}
|
||||
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||
// enable ContextWithFallback feature flag
|
||||
c.engine.ContextWithFallback = true
|
||||
c.Request, _ = http.NewRequest("POST", "/", nil)
|
||||
return c, "key"
|
||||
},
|
||||
|
@ -2193,6 +2220,70 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestContextCopyShouldNotCancel(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
ensureRequestIsOver := make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
r := New()
|
||||
r.GET("/", func(ginctx *Context) {
|
||||
wg.Add(1)
|
||||
|
||||
ginctx = ginctx.Copy()
|
||||
|
||||
// start async goroutine for calling srv
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
<-ensureRequestIsOver // ensure request is done
|
||||
|
||||
req, err := http.NewRequestWithContext(ginctx, http.MethodGet, srv.URL, nil)
|
||||
must(err)
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Error(fmt.Errorf("request error: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
must(err)
|
||||
go func() {
|
||||
s := &http.Server{
|
||||
Handler: r,
|
||||
}
|
||||
|
||||
must(s.Serve(l))
|
||||
}()
|
||||
|
||||
addr := strings.Split(l.Addr().String(), ":")
|
||||
res, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/", addr[len(addr)-1]))
|
||||
if err != nil {
|
||||
t.Error(fmt.Errorf("request error: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
close(ensureRequestIsOver)
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestContextAddParam(t *testing.T) {
|
||||
c := &Context{}
|
||||
id := "id"
|
||||
|
|
3
gin.go
3
gin.go
|
@ -147,6 +147,9 @@ type Engine struct {
|
|||
// UseH2C enable h2c support.
|
||||
UseH2C bool
|
||||
|
||||
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
|
||||
ContextWithFallback bool
|
||||
|
||||
delims render.Delims
|
||||
secureJSONPrefix string
|
||||
HTMLRender render.HTMLRender
|
||||
|
|
Loading…
Reference in New Issue