From c4914f0ff71996de663f0b5327aa8a0d1bf4476f Mon Sep 17 00:00:00 2001 From: Manu Mtz-Almeida Date: Wed, 20 May 2015 00:39:52 +0200 Subject: [PATCH] Adds WrapF() and WrapH() --- utils.go | 8 +++++++- utils_test.go | 25 ++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/utils.go b/utils.go index 476eb832..eaf6a5a7 100644 --- a/utils.go +++ b/utils.go @@ -13,12 +13,18 @@ import ( "strings" ) -func Wrap(f http.HandlerFunc) HandlerFunc { +func WrapF(f http.HandlerFunc) HandlerFunc { return func(c *Context) { f(c.Writer, c.Request) } } +func WrapH(h http.Handler) HandlerFunc { + return func(c *Context) { + h.ServeHTTP(c.Writer, c.Request) + } +} + type H map[string]interface{} // Allows type H to be used with xml.Marshal diff --git a/utils_test.go b/utils_test.go index 5bd4cc15..19fab5d9 100644 --- a/utils_test.go +++ b/utils_test.go @@ -16,21 +16,40 @@ func init() { SetMode(TestMode) } +type testStruct struct { + T *testing.T +} + +func (t *testStruct) ServeHTTP(w http.ResponseWriter, req *http.Request) { + assert.Equal(t.T, req.Method, "POST") + assert.Equal(t.T, req.URL.Path, "/path") + w.WriteHeader(500) + fmt.Fprint(w, "hello") +} + func TestWrap(t *testing.T) { router := New() - router.GET("/path", Wrap(func(w http.ResponseWriter, req *http.Request) { + router.POST("/path", WrapH(&testStruct{t})) + router.GET("/path2", WrapF(func(w http.ResponseWriter, req *http.Request) { assert.Equal(t, req.Method, "GET") - assert.Equal(t, req.URL.Path, "/path") + assert.Equal(t, req.URL.Path, "/path2") w.WriteHeader(400) fmt.Fprint(w, "hola!") })) - w := performRequest(router, "GET", "/path") + w := performRequest(router, "POST", "/path") + assert.Equal(t, w.Code, 500) + assert.Equal(t, w.Body.String(), "hello") + w = performRequest(router, "GET", "/path2") assert.Equal(t, w.Code, 400) assert.Equal(t, w.Body.String(), "hola!") } +func TestWrapH(t *testing.T) { + +} + func TestLastChar(t *testing.T) { assert.Equal(t, lastChar("hola"), uint8('a')) assert.Equal(t, lastChar("adios"), uint8('s'))