mirror of https://github.com/gin-gonic/gin.git
Enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value()
This commit is contained in:
parent
92ba8e17aa
commit
f197a8bae0
|
@ -1158,7 +1158,7 @@ func (c *Context) SetAccepted(formats ...string) {
|
||||||
|
|
||||||
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
|
// Deadline returns that there is no deadline (ok==false) when c.Request has no Context.
|
||||||
func (c *Context) Deadline() (deadline time.Time, ok bool) {
|
func (c *Context) Deadline() (deadline time.Time, ok bool) {
|
||||||
if c.Request == nil || c.Request.Context() == nil {
|
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return c.Request.Context().Deadline()
|
return c.Request.Context().Deadline()
|
||||||
|
@ -1166,7 +1166,7 @@ func (c *Context) Deadline() (deadline time.Time, ok bool) {
|
||||||
|
|
||||||
// Done returns nil (chan which will wait forever) when c.Request has no Context.
|
// Done returns nil (chan which will wait forever) when c.Request has no Context.
|
||||||
func (c *Context) Done() <-chan struct{} {
|
func (c *Context) Done() <-chan struct{} {
|
||||||
if c.Request == nil || c.Request.Context() == nil {
|
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.Request.Context().Done()
|
return c.Request.Context().Done()
|
||||||
|
@ -1174,7 +1174,7 @@ func (c *Context) Done() <-chan struct{} {
|
||||||
|
|
||||||
// Err returns nil when c.Request has no Context.
|
// Err returns nil when c.Request has no Context.
|
||||||
func (c *Context) Err() error {
|
func (c *Context) Err() error {
|
||||||
if c.Request == nil || c.Request.Context() == nil {
|
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.Request.Context().Err()
|
return c.Request.Context().Err()
|
||||||
|
@ -1195,7 +1195,7 @@ func (c *Context) Value(key any) any {
|
||||||
return val
|
return val
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if c.Request == nil || c.Request.Context() == nil {
|
if !c.engine.ContextWithFallback || c.Request == nil || c.Request.Context() == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.Request.Context().Value(key)
|
return c.Request.Context().Value(key)
|
||||||
|
|
115
context_test.go
115
context_test.go
|
@ -2097,12 +2097,18 @@ func TestRemoteIPFail(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
|
func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
|
||||||
c := &Context{}
|
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c.engine.ContextWithFallback = true
|
||||||
|
|
||||||
deadline, ok := c.Deadline()
|
deadline, ok := c.Deadline()
|
||||||
assert.Zero(t, deadline)
|
assert.Zero(t, deadline)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
|
|
||||||
c2 := &Context{}
|
c2, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c2.engine.ContextWithFallback = true
|
||||||
|
|
||||||
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
||||||
d := time.Now().Add(time.Second)
|
d := time.Now().Add(time.Second)
|
||||||
ctx, cancel := context.WithDeadline(context.Background(), d)
|
ctx, cancel := context.WithDeadline(context.Background(), d)
|
||||||
|
@ -2114,10 +2120,16 @@ func TestContextWithFallbackDeadlineFromRequestContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
|
func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
|
||||||
c := &Context{}
|
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c.engine.ContextWithFallback = true
|
||||||
|
|
||||||
assert.Nil(t, c.Done())
|
assert.Nil(t, c.Done())
|
||||||
|
|
||||||
c2 := &Context{}
|
c2, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c2.engine.ContextWithFallback = true
|
||||||
|
|
||||||
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
c2.Request = c2.Request.WithContext(ctx)
|
c2.Request = c2.Request.WithContext(ctx)
|
||||||
|
@ -2126,10 +2138,16 @@ func TestContextWithFallbackDoneFromRequestContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
|
func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
|
||||||
c := &Context{}
|
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c.engine.ContextWithFallback = true
|
||||||
|
|
||||||
assert.Nil(t, c.Err())
|
assert.Nil(t, c.Err())
|
||||||
|
|
||||||
c2 := &Context{}
|
c2, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c2.engine.ContextWithFallback = true
|
||||||
|
|
||||||
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
c2.Request, _ = http.NewRequest(http.MethodGet, "/", nil)
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
c2.Request = c2.Request.WithContext(ctx)
|
c2.Request = c2.Request.WithContext(ctx)
|
||||||
|
@ -2138,9 +2156,9 @@ func TestContextWithFallbackErrFromRequestContext(t *testing.T) {
|
||||||
assert.EqualError(t, c2.Err(), context.Canceled.Error())
|
assert.EqualError(t, c2.Err(), context.Canceled.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
type contextKey string
|
|
||||||
|
|
||||||
func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
getContextAndKey func() (*Context, any)
|
getContextAndKey func() (*Context, any)
|
||||||
|
@ -2150,7 +2168,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
||||||
name: "c with struct context key",
|
name: "c with struct context key",
|
||||||
getContextAndKey: func() (*Context, any) {
|
getContextAndKey: func() (*Context, any) {
|
||||||
var key struct{}
|
var key struct{}
|
||||||
c := &Context{}
|
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c.engine.ContextWithFallback = true
|
||||||
c.Request, _ = http.NewRequest("POST", "/", nil)
|
c.Request, _ = http.NewRequest("POST", "/", nil)
|
||||||
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value"))
|
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), key, "value"))
|
||||||
return c, key
|
return c, key
|
||||||
|
@ -2160,7 +2180,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "c with string context key",
|
name: "c with string context key",
|
||||||
getContextAndKey: func() (*Context, any) {
|
getContextAndKey: func() (*Context, any) {
|
||||||
c := &Context{}
|
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c.engine.ContextWithFallback = true
|
||||||
c.Request, _ = http.NewRequest("POST", "/", nil)
|
c.Request, _ = http.NewRequest("POST", "/", nil)
|
||||||
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
|
c.Request = c.Request.WithContext(context.WithValue(context.TODO(), contextKey("key"), "value"))
|
||||||
return c, contextKey("key")
|
return c, contextKey("key")
|
||||||
|
@ -2170,7 +2192,10 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "c with nil http.Request",
|
name: "c with nil http.Request",
|
||||||
getContextAndKey: func() (*Context, any) {
|
getContextAndKey: func() (*Context, any) {
|
||||||
c := &Context{}
|
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c.engine.ContextWithFallback = true
|
||||||
|
c.Request = nil
|
||||||
return c, "key"
|
return c, "key"
|
||||||
},
|
},
|
||||||
value: nil,
|
value: nil,
|
||||||
|
@ -2178,7 +2203,9 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "c with nil http.Request.Context()",
|
name: "c with nil http.Request.Context()",
|
||||||
getContextAndKey: func() (*Context, any) {
|
getContextAndKey: func() (*Context, any) {
|
||||||
c := &Context{}
|
c, _ := CreateTestContext(httptest.NewRecorder())
|
||||||
|
// enable ContextWithFallback feature flag
|
||||||
|
c.engine.ContextWithFallback = true
|
||||||
c.Request, _ = http.NewRequest("POST", "/", nil)
|
c.Request, _ = http.NewRequest("POST", "/", nil)
|
||||||
return c, "key"
|
return c, "key"
|
||||||
},
|
},
|
||||||
|
@ -2193,6 +2220,70 @@ func TestContextWithFallbackValueFromRequestContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextCopyShouldNotCancel(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
ensureRequestIsOver := make(chan struct{})
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
|
r := New()
|
||||||
|
r.GET("/", func(ginctx *Context) {
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
ginctx = ginctx.Copy()
|
||||||
|
|
||||||
|
// start async goroutine for calling srv
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
<-ensureRequestIsOver // ensure request is done
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ginctx, http.MethodGet, srv.URL, nil)
|
||||||
|
must(err)
|
||||||
|
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(fmt.Errorf("request error: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
|
||||||
|
l, err := net.Listen("tcp", ":0")
|
||||||
|
must(err)
|
||||||
|
go func() {
|
||||||
|
s := &http.Server{
|
||||||
|
Handler: r,
|
||||||
|
}
|
||||||
|
|
||||||
|
must(s.Serve(l))
|
||||||
|
}()
|
||||||
|
|
||||||
|
addr := strings.Split(l.Addr().String(), ":")
|
||||||
|
res, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/", addr[len(addr)-1]))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(fmt.Errorf("request error: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(ensureRequestIsOver)
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
t.Error(fmt.Errorf("unexpected status code: %s", res.Status))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func TestContextAddParam(t *testing.T) {
|
func TestContextAddParam(t *testing.T) {
|
||||||
c := &Context{}
|
c := &Context{}
|
||||||
id := "id"
|
id := "id"
|
||||||
|
|
3
gin.go
3
gin.go
|
@ -147,6 +147,9 @@ type Engine struct {
|
||||||
// UseH2C enable h2c support.
|
// UseH2C enable h2c support.
|
||||||
UseH2C bool
|
UseH2C bool
|
||||||
|
|
||||||
|
// ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil.
|
||||||
|
ContextWithFallback bool
|
||||||
|
|
||||||
delims render.Delims
|
delims render.Delims
|
||||||
secureJSONPrefix string
|
secureJSONPrefix string
|
||||||
HTMLRender render.HTMLRender
|
HTMLRender render.HTMLRender
|
||||||
|
|
Loading…
Reference in New Issue