diff --git a/binding/form.go b/binding/form.go index f1f89195..0b28aa8a 100644 --- a/binding/form.go +++ b/binding/form.go @@ -4,7 +4,11 @@ package binding -import "net/http" +import ( + "mime/multipart" + "net/http" + "reflect" +) const defaultMemory = 32 * 1024 * 1024 @@ -53,13 +57,33 @@ func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error { if err := req.ParseMultipartForm(defaultMemory); err != nil { return err } - if err := mapForm(obj, req.MultipartForm.Value); err != nil { - return err - } - - if err := mapFiles(obj, req); err != nil { + if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil { return err } return validate(obj) } + +type multipartRequest http.Request + +var _ setter = (*multipartRequest)(nil) + +var ( + multipartFileHeaderStructType = reflect.TypeOf(multipart.FileHeader{}) +) + +// TrySet tries to set a value by the multipart request with the binding a form file +func (r *multipartRequest) TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSetted bool, err error) { + if value.Type() == multipartFileHeaderStructType { + _, file, err := (*http.Request)(r).FormFile(key) + if err != nil { + return false, err + } + if file != nil { + value.Set(reflect.ValueOf(*file)) + return true, nil + } + } + + return setByForm(value, field, r.MultipartForm.Value, key, opt) +} diff --git a/binding/form_mapping.go b/binding/form_mapping.go index fc33b1df..aaacf6c5 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -7,7 +7,6 @@ package binding import ( "errors" "fmt" - "net/http" "reflect" "strconv" "strings" @@ -16,34 +15,6 @@ import ( "github.com/gin-gonic/gin/internal/json" ) -func mapFiles(ptr interface{}, req *http.Request) error { - typ := reflect.TypeOf(ptr).Elem() - val := reflect.ValueOf(ptr).Elem() - for i := 0; i < typ.NumField(); i++ { - typeField := typ.Field(i) - structField := val.Field(i) - - t := fmt.Sprintf("%s", typeField.Type) - if string(t) != "*multipart.FileHeader" { - continue - } - - inputFieldName := typeField.Tag.Get("form") - if inputFieldName == "" { - inputFieldName = typeField.Name - } - - _, fileHeader, err := req.FormFile(inputFieldName) - if err != nil { - return err - } - - structField.Set(reflect.ValueOf(fileHeader)) - - } - return nil -} - var errUnknownType = errors.New("Unknown type") func mapUri(ptr interface{}, m map[string][]string) error { @@ -57,11 +28,29 @@ func mapForm(ptr interface{}, form map[string][]string) error { var emptyField = reflect.StructField{} func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error { - _, err := mapping(reflect.ValueOf(ptr), emptyField, form, tag) + return mappingByPtr(ptr, formSource(form), tag) +} + +// setter tries to set value on a walking by fields of a struct +type setter interface { + TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSetted bool, err error) +} + +type formSource map[string][]string + +var _ setter = formSource(nil) + +// TrySet tries to set a value by request's form source (like map[string][]string) +func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSetted bool, err error) { + return setByForm(value, field, form, tagValue, opt) +} + +func mappingByPtr(ptr interface{}, setter setter, tag string) error { + _, err := mapping(reflect.ValueOf(ptr), emptyField, setter, tag) return err } -func mapping(value reflect.Value, field reflect.StructField, form map[string][]string, tag string) (bool, error) { +func mapping(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) { var vKind = value.Kind() if vKind == reflect.Ptr { @@ -71,7 +60,7 @@ func mapping(value reflect.Value, field reflect.StructField, form map[string][]s isNew = true vPtr = reflect.New(value.Type().Elem()) } - isSetted, err := mapping(vPtr.Elem(), field, form, tag) + isSetted, err := mapping(vPtr.Elem(), field, setter, tag) if err != nil { return false, err } @@ -81,7 +70,7 @@ func mapping(value reflect.Value, field reflect.StructField, form map[string][]s return isSetted, nil } - ok, err := tryToSetValue(value, field, form, tag) + ok, err := tryToSetValue(value, field, setter, tag) if err != nil { return false, err } @@ -97,7 +86,7 @@ func mapping(value reflect.Value, field reflect.StructField, form map[string][]s if !value.Field(i).CanSet() { continue } - ok, err := mapping(value.Field(i), tValue.Field(i), form, tag) + ok, err := mapping(value.Field(i), tValue.Field(i), setter, tag) if err != nil { return false, err } @@ -108,9 +97,14 @@ func mapping(value reflect.Value, field reflect.StructField, form map[string][]s return false, nil } -func tryToSetValue(value reflect.Value, field reflect.StructField, form map[string][]string, tag string) (bool, error) { - var tagValue, defaultValue string - var isDefaultExists bool +type setOptions struct { + isDefaultExists bool + defaultValue string +} + +func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) { + var tagValue string + var setOpt setOptions tagValue = field.Tag.Get(tag) tagValue, opts := head(tagValue, ",") @@ -132,25 +126,29 @@ func tryToSetValue(value reflect.Value, field reflect.StructField, form map[stri k, v := head(opt, "=") switch k { case "default": - isDefaultExists = true - defaultValue = v + setOpt.isDefaultExists = true + setOpt.defaultValue = v } } + return setter.TrySet(value, field, tagValue, setOpt) +} + +func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSetted bool, err error) { vs, ok := form[tagValue] - if !ok && !isDefaultExists { + if !ok && !opt.isDefaultExists { return false, nil } switch value.Kind() { case reflect.Slice: if !ok { - vs = []string{defaultValue} + vs = []string{opt.defaultValue} } return true, setSlice(vs, value, field) case reflect.Array: if !ok { - vs = []string{defaultValue} + vs = []string{opt.defaultValue} } if len(vs) != value.Len() { return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String()) @@ -159,7 +157,7 @@ func tryToSetValue(value reflect.Value, field reflect.StructField, form map[stri default: var val string if !ok { - val = defaultValue + val = opt.defaultValue } if len(vs) > 0 {