// Copyright 2014 Manu Martinez-Almeida. All rights reserved. // Use of this source code is governed by a MIT style // license that can be found in the LICENSE file. package binding import ( "errors" "fmt" "reflect" "strconv" "strings" "time" "github.com/gin-gonic/gin/internal/json" ) var errUnknownType = errors.New("Unknown type") func mapUri(ptr interface{}, m map[string][]string) error { return mapFormByTag(ptr, m, "uri") } func mapForm(ptr interface{}, form map[string][]string) error { return mapFormByTag(ptr, form, "form") } var emptyField = reflect.StructField{} func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error { _, err := mapping(reflect.ValueOf(ptr), emptyField, form, tag) return err } func mapping(value reflect.Value, field reflect.StructField, form map[string][]string, tag string) (bool, error) { var vKind = value.Kind() if vKind == reflect.Ptr { var isNew bool vPtr := value if value.IsNil() { isNew = true vPtr = reflect.New(value.Type().Elem()) } isSetted, err := mapping(vPtr.Elem(), field, form, tag) if err != nil { return false, err } if isNew && isSetted { value.Set(vPtr) } return isSetted, nil } ok, err := tryToSetValue(value, field, form, tag) if err != nil { return false, err } if ok { return true, nil } if vKind == reflect.Struct { tValue := value.Type() var isSetted bool for i := 0; i < value.NumField(); i++ { if !value.Field(i).CanSet() { continue } ok, err := mapping(value.Field(i), tValue.Field(i), form, tag) if err != nil { return false, err } isSetted = isSetted || ok } return isSetted, nil } return false, nil } func tryToSetValue(value reflect.Value, field reflect.StructField, form map[string][]string, tag string) (bool, error) { var tagValue, defaultValue string var isDefaultExists bool tagValue = field.Tag.Get(tag) tagValue, opts := head(tagValue, ",") if tagValue == "-" { // just ignoring this field return false, nil } if tagValue == "" { // default value is FieldName tagValue = field.Name } if tagValue == "" { // when field is "emptyField" variable return false, nil } var opt string for len(opts) > 0 { opt, opts = head(opts, ",") k, v := head(opt, "=") switch k { case "default": isDefaultExists = true defaultValue = v } } vs, ok := form[tagValue] if !ok && !isDefaultExists { return false, nil } switch value.Kind() { case reflect.Slice: if !ok { vs = []string{defaultValue} } return true, setSlice(vs, value, field) case reflect.Array: if !ok { vs = []string{defaultValue} } if len(vs) != value.Len() { return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String()) } return true, setArray(vs, value, field) default: var val string if !ok { val = defaultValue } if len(vs) > 0 { val = vs[0] } return true, setWithProperType(val, value, field) } } func setWithProperType(val string, value reflect.Value, field reflect.StructField) error { switch value.Kind() { case reflect.Int: return setIntField(val, 0, value) case reflect.Int8: return setIntField(val, 8, value) case reflect.Int16: return setIntField(val, 16, value) case reflect.Int32: return setIntField(val, 32, value) case reflect.Int64: switch value.Interface().(type) { case time.Duration: return setTimeDuration(val, value, field) } return setIntField(val, 64, value) case reflect.Uint: return setUintField(val, 0, value) case reflect.Uint8: return setUintField(val, 8, value) case reflect.Uint16: return setUintField(val, 16, value) case reflect.Uint32: return setUintField(val, 32, value) case reflect.Uint64: return setUintField(val, 64, value) case reflect.Bool: return setBoolField(val, value) case reflect.Float32: return setFloatField(val, 32, value) case reflect.Float64: return setFloatField(val, 64, value) case reflect.String: value.SetString(val) case reflect.Struct: switch value.Interface().(type) { case time.Time: return setTimeField(val, field, value) } return json.Unmarshal([]byte(val), value.Addr().Interface()) case reflect.Map: return json.Unmarshal([]byte(val), value.Addr().Interface()) default: return errUnknownType } return nil } func setIntField(val string, bitSize int, field reflect.Value) error { if val == "" { val = "0" } intVal, err := strconv.ParseInt(val, 10, bitSize) if err == nil { field.SetInt(intVal) } return err } func setUintField(val string, bitSize int, field reflect.Value) error { if val == "" { val = "0" } uintVal, err := strconv.ParseUint(val, 10, bitSize) if err == nil { field.SetUint(uintVal) } return err } func setBoolField(val string, field reflect.Value) error { if val == "" { val = "false" } boolVal, err := strconv.ParseBool(val) if err == nil { field.SetBool(boolVal) } return err } func setFloatField(val string, bitSize int, field reflect.Value) error { if val == "" { val = "0.0" } floatVal, err := strconv.ParseFloat(val, bitSize) if err == nil { field.SetFloat(floatVal) } return err } func setTimeField(val string, structField reflect.StructField, value reflect.Value) error { timeFormat := structField.Tag.Get("time_format") if timeFormat == "" { timeFormat = time.RFC3339 } if val == "" { value.Set(reflect.ValueOf(time.Time{})) return nil } l := time.Local if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC { l = time.UTC } if locTag := structField.Tag.Get("time_location"); locTag != "" { loc, err := time.LoadLocation(locTag) if err != nil { return err } l = loc } t, err := time.ParseInLocation(timeFormat, val, l) if err != nil { return err } value.Set(reflect.ValueOf(t)) return nil } func setArray(vals []string, value reflect.Value, field reflect.StructField) error { for i, s := range vals { err := setWithProperType(s, value.Index(i), field) if err != nil { return err } } return nil } func setSlice(vals []string, value reflect.Value, field reflect.StructField) error { slice := reflect.MakeSlice(value.Type(), len(vals), len(vals)) err := setArray(vals, slice, field) if err != nil { return err } value.Set(slice) return nil } func setTimeDuration(val string, value reflect.Value, field reflect.StructField) error { d, err := time.ParseDuration(val) if err != nil { return err } value.Set(reflect.ValueOf(d)) return nil } func head(str, sep string) (head string, tail string) { idx := strings.Index(str, sep) if idx < 0 { return str, "" } return str[:idx], str[idx+len(sep):] }