refactor(form_mapping.go): mapping ptr, struct and map (#1749)

* refactor(form_mapping.go): mapping ptr, struct and map

* fix #1672 correct work with ptr - not create value if field is not set
* avoid allocations on strings.Split() - change to strings.Index()
* fix #610 tag value "-" is mean ignoring field
* struct fields mapped like json.Unmarshal
* map fields mapped like json.Unmarshal

* fix after @thinkerou review
This commit is contained in:
Dmitry Kutakov 2019-03-03 09:39:43 +03:00 committed by 田欧
parent 893c6cae07
commit 0d50ce8597
4 changed files with 341 additions and 150 deletions

View File

@ -1836,24 +1836,6 @@ $ curl "http://localhost:8080/getd?field_x=hello&field_d=world"
{"d":"world","x":{"FieldX":"hello"}} {"d":"world","x":{"FieldX":"hello"}}
``` ```
**NOTE**: NOT support the follow style struct:
```go
type StructX struct {
X struct {} `form:"name_x"` // HERE have form
}
type StructY struct {
Y StructX `form:"name_y"` // HERE have form
}
type StructZ struct {
Z *StructZ `form:"name_z"` // HERE have form
}
```
In a word, only support nested custom struct which have no `form` now.
### Try to bind body into different structs ### Try to bind body into different structs
The normal methods for binding request body consumes `c.Request.Body` and they The normal methods for binding request body consumes `c.Request.Body` and they

View File

@ -12,6 +12,7 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@ -57,7 +58,6 @@ type FooStructForTimeTypeFailLocation struct {
} }
type FooStructForMapType struct { type FooStructForMapType struct {
// Unknown type: not support map
MapFoo map[string]interface{} `form:"map_foo"` MapFoo map[string]interface{} `form:"map_foo"`
} }
@ -303,7 +303,7 @@ func TestBindingFormInvalidName2(t *testing.T) {
func TestBindingFormForType(t *testing.T) { func TestBindingFormForType(t *testing.T) {
testFormBindingForType(t, "POST", testFormBindingForType(t, "POST",
"/", "/", "/", "/",
"map_foo=", "bar2=1", "Map") "map_foo={\"bar\":123}", "map_foo=1", "Map")
testFormBindingForType(t, "POST", testFormBindingForType(t, "POST",
"/", "/", "/", "/",
@ -508,20 +508,30 @@ func TestBindingYAMLFail(t *testing.T) {
`foo:\nbar`, `bar: foo`) `foo:\nbar`, `bar: foo`)
} }
func createFormPostRequest() *http.Request { func createFormPostRequest(t *testing.T) *http.Request {
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo")) req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo"))
assert.NoError(t, err)
req.Header.Set("Content-Type", MIMEPOSTForm) req.Header.Set("Content-Type", MIMEPOSTForm)
return req return req
} }
func createDefaultFormPostRequest() *http.Request { func createDefaultFormPostRequest(t *testing.T) *http.Request {
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar")) req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar"))
assert.NoError(t, err)
req.Header.Set("Content-Type", MIMEPOSTForm) req.Header.Set("Content-Type", MIMEPOSTForm)
return req return req
} }
func createFormPostRequestFail() *http.Request { func createFormPostRequestForMap(t *testing.T) *http.Request {
req, _ := http.NewRequest("POST", "/?map_foo=getfoo", bytes.NewBufferString("map_foo=bar")) req, err := http.NewRequest("POST", "/?map_foo=getfoo", bytes.NewBufferString("map_foo={\"bar\":123}"))
assert.NoError(t, err)
req.Header.Set("Content-Type", MIMEPOSTForm)
return req
}
func createFormPostRequestForMapFail(t *testing.T) *http.Request {
req, err := http.NewRequest("POST", "/?map_foo=getfoo", bytes.NewBufferString("map_foo=hello"))
assert.NoError(t, err)
req.Header.Set("Content-Type", MIMEPOSTForm) req.Header.Set("Content-Type", MIMEPOSTForm)
return req return req
} }
@ -535,26 +545,42 @@ func createFormMultipartRequest(t *testing.T) *http.Request {
assert.NoError(t, mw.SetBoundary(boundary)) assert.NoError(t, mw.SetBoundary(boundary))
assert.NoError(t, mw.WriteField("foo", "bar")) assert.NoError(t, mw.WriteField("foo", "bar"))
assert.NoError(t, mw.WriteField("bar", "foo")) assert.NoError(t, mw.WriteField("bar", "foo"))
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body) req, err := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body)
assert.NoError(t, err)
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary)
return req return req
} }
func createFormMultipartRequestFail(t *testing.T) *http.Request { func createFormMultipartRequestForMap(t *testing.T) *http.Request {
boundary := "--testboundary" boundary := "--testboundary"
body := new(bytes.Buffer) body := new(bytes.Buffer)
mw := multipart.NewWriter(body) mw := multipart.NewWriter(body)
defer mw.Close() defer mw.Close()
assert.NoError(t, mw.SetBoundary(boundary)) assert.NoError(t, mw.SetBoundary(boundary))
assert.NoError(t, mw.WriteField("map_foo", "bar")) assert.NoError(t, mw.WriteField("map_foo", "{\"bar\":123, \"name\":\"thinkerou\", \"pai\": 3.14}"))
req, _ := http.NewRequest("POST", "/?map_foo=getfoo", body) req, err := http.NewRequest("POST", "/?map_foo=getfoo", body)
assert.NoError(t, err)
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary)
return req
}
func createFormMultipartRequestForMapFail(t *testing.T) *http.Request {
boundary := "--testboundary"
body := new(bytes.Buffer)
mw := multipart.NewWriter(body)
defer mw.Close()
assert.NoError(t, mw.SetBoundary(boundary))
assert.NoError(t, mw.WriteField("map_foo", "3.14"))
req, err := http.NewRequest("POST", "/?map_foo=getfoo", body)
assert.NoError(t, err)
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary) req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary)
return req return req
} }
func TestBindingFormPost(t *testing.T) { func TestBindingFormPost(t *testing.T) {
req := createFormPostRequest() req := createFormPostRequest(t)
var obj FooBarStruct var obj FooBarStruct
assert.NoError(t, FormPost.Bind(req, &obj)) assert.NoError(t, FormPost.Bind(req, &obj))
@ -564,7 +590,7 @@ func TestBindingFormPost(t *testing.T) {
} }
func TestBindingDefaultValueFormPost(t *testing.T) { func TestBindingDefaultValueFormPost(t *testing.T) {
req := createDefaultFormPostRequest() req := createDefaultFormPostRequest(t)
var obj FooDefaultBarStruct var obj FooDefaultBarStruct
assert.NoError(t, FormPost.Bind(req, &obj)) assert.NoError(t, FormPost.Bind(req, &obj))
@ -572,8 +598,16 @@ func TestBindingDefaultValueFormPost(t *testing.T) {
assert.Equal(t, "hello", obj.Bar) assert.Equal(t, "hello", obj.Bar)
} }
func TestBindingFormPostFail(t *testing.T) { func TestBindingFormPostForMap(t *testing.T) {
req := createFormPostRequestFail() req := createFormPostRequestForMap(t)
var obj FooStructForMapType
err := FormPost.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64))
}
func TestBindingFormPostForMapFail(t *testing.T) {
req := createFormPostRequestForMapFail(t)
var obj FooStructForMapType var obj FooStructForMapType
err := FormPost.Bind(req, &obj) err := FormPost.Bind(req, &obj)
assert.Error(t, err) assert.Error(t, err)
@ -589,8 +623,18 @@ func TestBindingFormMultipart(t *testing.T) {
assert.Equal(t, "foo", obj.Bar) assert.Equal(t, "foo", obj.Bar)
} }
func TestBindingFormMultipartFail(t *testing.T) { func TestBindingFormMultipartForMap(t *testing.T) {
req := createFormMultipartRequestFail(t) req := createFormMultipartRequestForMap(t)
var obj FooStructForMapType
err := FormMultipart.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64))
assert.Equal(t, "thinkerou", obj.MapFoo["name"].(string))
assert.Equal(t, float64(3.14), obj.MapFoo["pai"].(float64))
}
func TestBindingFormMultipartForMapFail(t *testing.T) {
req := createFormMultipartRequestForMapFail(t)
var obj FooStructForMapType var obj FooStructForMapType
err := FormMultipart.Bind(req, &obj) err := FormMultipart.Bind(req, &obj)
assert.Error(t, err) assert.Error(t, err)
@ -773,6 +817,17 @@ func TestFormBindingFail(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} }
func TestFormBindingMultipartFail(t *testing.T) {
obj := FooBarStruct{}
req, err := http.NewRequest("POST", "/", strings.NewReader("foo=bar"))
assert.NoError(t, err)
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+";boundary=testboundary")
_, err = req.MultipartReader()
assert.NoError(t, err)
err = Form.Bind(req, &obj)
assert.Error(t, err)
}
func TestFormPostBindingFail(t *testing.T) { func TestFormPostBindingFail(t *testing.T) {
b := FormPost b := FormPost
assert.Equal(t, "form-urlencoded", b.Name()) assert.Equal(t, "form-urlencoded", b.Name())
@ -1109,7 +1164,8 @@ func testFormBindingForType(t *testing.T, method, path, badPath, body, badBody s
case "Map": case "Map":
obj := FooStructForMapType{} obj := FooStructForMapType{}
err := b.Bind(req, &obj) err := b.Bind(req, &obj)
assert.Error(t, err) assert.NoError(t, err)
assert.Equal(t, float64(123), obj.MapFoo["bar"].(float64))
case "SliceMap": case "SliceMap":
obj := FooStructForSliceMapType{} obj := FooStructForSliceMapType{}
err := b.Bind(req, &obj) err := b.Bind(req, &obj)
@ -1317,3 +1373,43 @@ func TestCanSet(t *testing.T) {
var c CanSetStruct var c CanSetStruct
assert.Nil(t, mapForm(&c, nil)) assert.Nil(t, mapForm(&c, nil))
} }
func formPostRequest(path, body string) *http.Request {
req := requestWithBody("POST", path, body)
req.Header.Add("Content-Type", MIMEPOSTForm)
return req
}
func TestBindingSliceDefault(t *testing.T) {
var s struct {
Friends []string `form:"friends,default=mike"`
}
req := formPostRequest("", "")
err := Form.Bind(req, &s)
assert.NoError(t, err)
assert.Len(t, s.Friends, 1)
assert.Equal(t, "mike", s.Friends[0])
}
func TestBindingStructField(t *testing.T) {
var s struct {
Opts struct {
Port int
} `form:"opts"`
}
req := formPostRequest("", `opts={"Port": 8000}`)
err := Form.Bind(req, &s)
assert.NoError(t, err)
assert.Equal(t, 8000, s.Opts.Port)
}
func TestBindingUnknownTypeChan(t *testing.T) {
var s struct {
Stop chan bool `form:"stop"`
}
req := formPostRequest("", "stop=true")
err := Form.Bind(req, &s)
assert.Error(t, err)
assert.Equal(t, errUnknownType, err)
}

View File

@ -5,6 +5,7 @@
package binding package binding
import ( import (
"encoding/json"
"errors" "errors"
"reflect" "reflect"
"strconv" "strconv"
@ -12,6 +13,8 @@ import (
"time" "time"
) )
var errUnknownType = errors.New("Unknown type")
func mapUri(ptr interface{}, m map[string][]string) error { func mapUri(ptr interface{}, m map[string][]string) error {
return mapFormByTag(ptr, m, "uri") return mapFormByTag(ptr, m, "uri")
} }
@ -20,124 +23,153 @@ func mapForm(ptr interface{}, form map[string][]string) error {
return mapFormByTag(ptr, form, "form") return mapFormByTag(ptr, form, "form")
} }
var emptyField = reflect.StructField{}
func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error { func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error {
typ := reflect.TypeOf(ptr).Elem() _, err := mapping(reflect.ValueOf(ptr), emptyField, form, tag)
val := reflect.ValueOf(ptr).Elem()
for i := 0; i < typ.NumField(); i++ {
typeField := typ.Field(i)
structField := val.Field(i)
if !structField.CanSet() {
continue
}
structFieldKind := structField.Kind()
inputFieldName := typeField.Tag.Get(tag)
inputFieldNameList := strings.Split(inputFieldName, ",")
inputFieldName = inputFieldNameList[0]
var defaultValue string
if len(inputFieldNameList) > 1 {
defaultList := strings.SplitN(inputFieldNameList[1], "=", 2)
if defaultList[0] == "default" {
defaultValue = defaultList[1]
}
}
if inputFieldName == "-" {
continue
}
if inputFieldName == "" {
inputFieldName = typeField.Name
// if "form" tag is nil, we inspect if the field is a struct or struct pointer.
// this would not make sense for JSON parsing but it does for a form
// since data is flatten
if structFieldKind == reflect.Ptr {
if !structField.Elem().IsValid() {
structField.Set(reflect.New(structField.Type().Elem()))
}
structField = structField.Elem()
structFieldKind = structField.Kind()
}
if structFieldKind == reflect.Struct {
err := mapFormByTag(structField.Addr().Interface(), form, tag)
if err != nil {
return err return err
}
continue
}
}
inputValue, exists := form[inputFieldName]
if !exists {
if defaultValue == "" {
continue
}
inputValue = make([]string, 1)
inputValue[0] = defaultValue
}
numElems := len(inputValue)
if structFieldKind == reflect.Slice && numElems > 0 {
sliceOf := structField.Type().Elem().Kind()
slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
for i := 0; i < numElems; i++ {
if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil {
return err
}
}
val.Field(i).Set(slice)
continue
}
if _, isTime := structField.Interface().(time.Time); isTime {
if err := setTimeField(inputValue[0], typeField, structField); err != nil {
return err
}
continue
}
if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
return err
}
}
return nil
} }
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { func mapping(value reflect.Value, field reflect.StructField, form map[string][]string, tag string) (bool, error) {
switch valueKind { var vKind = value.Kind()
case reflect.Int:
return setIntField(val, 0, structField) if vKind == reflect.Ptr {
case reflect.Int8: var isNew bool
return setIntField(val, 8, structField) vPtr := value
case reflect.Int16: if value.IsNil() {
return setIntField(val, 16, structField) isNew = true
case reflect.Int32: vPtr = reflect.New(value.Type().Elem())
return setIntField(val, 32, structField)
case reflect.Int64:
return setIntField(val, 64, structField)
case reflect.Uint:
return setUintField(val, 0, structField)
case reflect.Uint8:
return setUintField(val, 8, structField)
case reflect.Uint16:
return setUintField(val, 16, structField)
case reflect.Uint32:
return setUintField(val, 32, structField)
case reflect.Uint64:
return setUintField(val, 64, structField)
case reflect.Bool:
return setBoolField(val, structField)
case reflect.Float32:
return setFloatField(val, 32, structField)
case reflect.Float64:
return setFloatField(val, 64, structField)
case reflect.String:
structField.SetString(val)
case reflect.Ptr:
if !structField.Elem().IsValid() {
structField.Set(reflect.New(structField.Type().Elem()))
} }
structFieldElem := structField.Elem() isSetted, err := mapping(vPtr.Elem(), field, form, tag)
return setWithProperType(structFieldElem.Kind(), val, structFieldElem) if err != nil {
return false, err
}
if isNew && isSetted {
value.Set(vPtr)
}
return isSetted, nil
}
ok, err := tryToSetValue(value, field, form, tag)
if err != nil {
return false, err
}
if ok {
return true, nil
}
if vKind == reflect.Struct {
tValue := value.Type()
var isSetted bool
for i := 0; i < value.NumField(); i++ {
if !value.Field(i).CanSet() {
continue
}
ok, err := mapping(value.Field(i), tValue.Field(i), form, tag)
if err != nil {
return false, err
}
isSetted = isSetted || ok
}
return isSetted, nil
}
return false, nil
}
func tryToSetValue(value reflect.Value, field reflect.StructField, form map[string][]string, tag string) (bool, error) {
var tagValue, defaultValue string
var isDefaultExists bool
tagValue = field.Tag.Get(tag)
tagValue, opts := head(tagValue, ",")
if tagValue == "-" { // just ignoring this field
return false, nil
}
if tagValue == "" { // default value is FieldName
tagValue = field.Name
}
if tagValue == "" { // when field is "emptyField" variable
return false, nil
}
var opt string
for len(opts) > 0 {
opt, opts = head(opts, ",")
k, v := head(opt, "=")
switch k {
case "default":
isDefaultExists = true
defaultValue = v
}
}
vs, ok := form[tagValue]
if !ok && !isDefaultExists {
return false, nil
}
switch value.Kind() {
case reflect.Slice:
if !ok {
vs = []string{defaultValue}
}
return true, setSlice(vs, value, field)
default: default:
return errors.New("Unknown type") var val string
if !ok {
val = defaultValue
}
if len(vs) > 0 {
val = vs[0]
}
return true, setWithProperType(val, value, field)
}
}
func setWithProperType(val string, value reflect.Value, field reflect.StructField) error {
switch value.Kind() {
case reflect.Int:
return setIntField(val, 0, value)
case reflect.Int8:
return setIntField(val, 8, value)
case reflect.Int16:
return setIntField(val, 16, value)
case reflect.Int32:
return setIntField(val, 32, value)
case reflect.Int64:
return setIntField(val, 64, value)
case reflect.Uint:
return setUintField(val, 0, value)
case reflect.Uint8:
return setUintField(val, 8, value)
case reflect.Uint16:
return setUintField(val, 16, value)
case reflect.Uint32:
return setUintField(val, 32, value)
case reflect.Uint64:
return setUintField(val, 64, value)
case reflect.Bool:
return setBoolField(val, value)
case reflect.Float32:
return setFloatField(val, 32, value)
case reflect.Float64:
return setFloatField(val, 64, value)
case reflect.String:
value.SetString(val)
case reflect.Struct:
switch value.Interface().(type) {
case time.Time:
return setTimeField(val, field, value)
}
return json.Unmarshal([]byte(val), value.Addr().Interface())
case reflect.Map:
return json.Unmarshal([]byte(val), value.Addr().Interface())
default:
return errUnknownType
} }
return nil return nil
} }
@ -218,3 +250,23 @@ func setTimeField(val string, structField reflect.StructField, value reflect.Val
value.Set(reflect.ValueOf(t)) value.Set(reflect.ValueOf(t))
return nil return nil
} }
func setSlice(vals []string, value reflect.Value, field reflect.StructField) error {
slice := reflect.MakeSlice(value.Type(), len(vals), len(vals))
for i, s := range vals {
err := setWithProperType(s, slice.Index(i), field)
if err != nil {
return err
}
}
value.Set(slice)
return nil
}
func head(str, sep string) (head string, tail string) {
idx := strings.Index(str, sep)
if idx < 0 {
return str, ""
}
return str[:idx], str[idx+len(sep):]
}

View File

@ -0,0 +1,61 @@
// Copyright 2019 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var form = map[string][]string{
"name": {"mike"},
"friends": {"anna", "nicole"},
"id_number": {"12345678"},
"id_date": {"2018-01-20"},
}
type structFull struct {
Name string `form:"name"`
Age int `form:"age,default=25"`
Friends []string `form:"friends"`
ID *struct {
Number string `form:"id_number"`
DateOfIssue time.Time `form:"id_date" time_format:"2006-01-02" time_utc:"true"`
}
Nationality *string `form:"nationality"`
}
func BenchmarkMapFormFull(b *testing.B) {
var s structFull
for i := 0; i < b.N; i++ {
mapForm(&s, form)
}
b.StopTimer()
t := b
assert.Equal(t, "mike", s.Name)
assert.Equal(t, 25, s.Age)
assert.Equal(t, []string{"anna", "nicole"}, s.Friends)
assert.Equal(t, "12345678", s.ID.Number)
assert.Equal(t, time.Date(2018, 1, 20, 0, 0, 0, 0, time.UTC), s.ID.DateOfIssue)
assert.Nil(t, s.Nationality)
}
type structName struct {
Name string `form:"name"`
}
func BenchmarkMapFormName(b *testing.B) {
var s structName
for i := 0; i < b.N; i++ {
mapForm(&s, form)
}
b.StopTimer()
t := b
assert.Equal(t, "mike", s.Name)
}