From f32f8cb6e5692ce26a0ca6a42c6fbdc10a868ae7 Mon Sep 17 00:00:00 2001 From: ekeyme Date: Mon, 5 Dec 2022 17:55:43 +0800 Subject: [PATCH] fix middleware called multiple times in HandleContext redirection --- gin.go | 2 ++ gin_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/gin.go b/gin.go index f95e5dda..e96b3683 100644 --- a/gin.go +++ b/gin.go @@ -583,10 +583,12 @@ func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Disclaimer: You can loop yourself to deal with this, use wisely. func (engine *Engine) HandleContext(c *Context) { oldIndexValue := c.index + oldHandlers := c.handlers c.reset() engine.handleHTTPRequest(c) c.index = oldIndexValue + c.handlers = oldHandlers } func (engine *Engine) handleHTTPRequest(c *Context) { diff --git a/gin_test.go b/gin_test.go index 8825ac7e..22510d81 100644 --- a/gin_test.go +++ b/gin_test.go @@ -571,6 +571,46 @@ func TestEngineHandleContextManyReEntries(t *testing.T) { assert.Equal(t, int64(expectValue), middlewareCounter) } +func TestEngineHandleContextReEntriesWithDifferentMiddlewares(t *testing.T) { + var ( + expectedResp = "ok" + v1MiddlewareCounter1, v2MiddlewareCounter1, v2MiddlewareCounter2 int64 + ) + + r := New() + + r.GET("/v1", + func(c *Context) { /* V1middleware1 */ + atomic.AddInt64(&v1MiddlewareCounter1, 1) + }, + func(c *Context) { + c.Request.URL.Path = "/v2" + r.HandleContext(c) + }) + + r.GET("/v2", + func(c *Context) { /* V2middleware1 */ + atomic.AddInt64(&v2MiddlewareCounter1, 1) + }, + func(c *Context) { /* v2middleware2 */ + atomic.AddInt64(&v2MiddlewareCounter2, 1) + }, + func(c *Context) { + c.String(200, expectedResp) + }) + + assert.NotPanics(t, func() { + w := PerformRequest(r, "GET", "/v1") + assert.Equal(t, 200, w.Code) + assert.Equal(t, expectedResp, w.Body.String()) + }) + + // all the middlewares should be called independently from each router + assert.Equal(t, int64(1), v1MiddlewareCounter1) + assert.Equal(t, int64(1), v2MiddlewareCounter1) + assert.Equal(t, int64(1), v2MiddlewareCounter2) +} + func TestPrepareTrustedCIRDsWith(t *testing.T) { r := New()