diff --git a/context.go b/context.go index d2b98b65..0beeef91 100644 --- a/context.go +++ b/context.go @@ -8,6 +8,7 @@ import ( "errors" "io" "math" + "mime/multipart" "net" "net/http" "net/url" @@ -30,7 +31,10 @@ const ( MIMEMultipartPOSTForm = binding.MIMEMultipartPOSTForm ) -const abortIndex int8 = math.MaxInt8 / 2 +const ( + defaultMemory = 32 << 20 // 32 MB + abortIndex int8 = math.MaxInt8 / 2 +) // Context is the most important part of gin. It allows us to pass variables between middleware, // manage the flow, validate the JSON of a request and render a JSON response for example. @@ -291,7 +295,7 @@ func (c *Context) PostFormArray(key string) []string { func (c *Context) GetPostFormArray(key string) ([]string, bool) { req := c.Request req.ParseForm() - req.ParseMultipartForm(32 << 20) // 32 MB + req.ParseMultipartForm(defaultMemory) if values := req.PostForm[key]; len(values) > 0 { return values, true } @@ -303,6 +307,18 @@ func (c *Context) GetPostFormArray(key string) ([]string, bool) { return []string{}, false } +// FormFile returns the first file for the provided form key. +func (c *Context) FormFile(name string) (*multipart.FileHeader, error) { + _, fh, err := c.Request.FormFile(name) + return fh, err +} + +// MultipartForm is the parsed multipart form, including file uploads. +func (c *Context) MultipartForm() (*multipart.Form, error) { + err := c.Request.ParseMultipartForm(defaultMemory) + return c.Request.MultipartForm, err +} + // Bind checks the Content-Type to select a binding engine automatically, // Depending the "Content-Type" header different bindings are used: // "application/json" --> JSON binding diff --git a/context_test.go b/context_test.go index 81e18841..e488f423 100644 --- a/context_test.go +++ b/context_test.go @@ -53,6 +53,37 @@ func must(err error) { } } +func TestContextFormFile(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + w, err := mw.CreateFormFile("file", "test") + if assert.NoError(t, err) { + w.Write([]byte("test")) + } + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + f, err := c.FormFile("file") + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) + } +} + +func TestContextMultipartForm(t *testing.T) { + buf := new(bytes.Buffer) + mw := multipart.NewWriter(buf) + mw.WriteField("foo", "bar") + mw.Close() + c, _ := CreateTestContext(httptest.NewRecorder()) + c.Request, _ = http.NewRequest("POST", "/", buf) + c.Request.Header.Set("Content-Type", mw.FormDataContentType()) + f, err := c.MultipartForm() + if assert.NoError(t, err) { + assert.NotNil(t, f) + } +} + func TestContextReset(t *testing.T) { router := New() c := router.allocateContext()