Refactors binding module

This commit is contained in:
Manu Mtz-Almeida 2015-03-31 17:51:10 +02:00
parent 18880f9215
commit d4413b6e91
10 changed files with 367 additions and 299 deletions

View File

@ -4,282 +4,43 @@
package binding package binding
import ( import "net/http"
"encoding/json"
"encoding/xml" const (
"errors" MIMEJSON = "application/json"
"log" MIMEHTML = "text/html"
"net/http" MIMEXML = "application/xml"
"reflect" MIMEXML2 = "text/xml"
"strconv" MIMEPlain = "text/plain"
"strings" MIMEPOSTForm = "application/x-www-form-urlencoded"
MIMEMultipartPOSTForm = "multipart/form-data"
) )
type ( type Binding interface {
Binding interface { Name() string
Bind(*http.Request, interface{}) error Bind(*http.Request, interface{}) error
} }
// JSON binding
jsonBinding struct{}
// XML binding
xmlBinding struct{}
// form binding
formBinding struct{}
// multipart form binding
multipartFormBinding struct{}
)
const MAX_MEMORY = 1 * 1024 * 1024
var ( var (
JSON = jsonBinding{} JSON = jsonBinding{}
XML = xmlBinding{} XML = xmlBinding{}
Form = formBinding{} // todo GETForm = getFormBinding{}
MultipartForm = multipartFormBinding{} POSTForm = postFormBinding{}
) )
func (_ jsonBinding) Bind(req *http.Request, obj interface{}) error { func Default(method, contentType string) Binding {
decoder := json.NewDecoder(req.Body) if method == "GET" {
if err := decoder.Decode(obj); err == nil { return GETForm
return Validate(obj)
} else { } else {
return err switch contentType {
} case MIMEPOSTForm:
} return POSTForm
case MIMEJSON:
func (_ xmlBinding) Bind(req *http.Request, obj interface{}) error { return JSON
decoder := xml.NewDecoder(req.Body) case MIMEXML, MIMEXML2:
if err := decoder.Decode(obj); err == nil { return XML
return Validate(obj)
} else {
return err
}
}
func (_ formBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return err
}
if err := mapForm(obj, req.Form); err != nil {
return err
}
return Validate(obj)
}
func (_ multipartFormBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseMultipartForm(MAX_MEMORY); err != nil {
return err
}
if err := mapForm(obj, req.Form); err != nil {
return err
}
return Validate(obj)
}
func mapForm(ptr interface{}, form map[string][]string) error {
typ := reflect.TypeOf(ptr).Elem()
formStruct := reflect.ValueOf(ptr).Elem()
for i := 0; i < typ.NumField(); i++ {
typeField := typ.Field(i)
if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" {
structField := formStruct.Field(i)
if !structField.CanSet() {
continue
}
inputValue, exists := form[inputFieldName]
if !exists {
continue
}
numElems := len(inputValue)
if structField.Kind() == reflect.Slice && numElems > 0 {
sliceOf := structField.Type().Elem().Kind()
slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
for i := 0; i < numElems; i++ {
if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil {
return err
}
}
formStruct.Field(i).Set(slice)
} else {
if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
return err
}
}
}
}
return nil
}
func setIntField(val string, bitSize int, structField reflect.Value) error {
if val == "" {
val = "0"
}
intVal, err := strconv.ParseInt(val, 10, bitSize)
if err == nil {
structField.SetInt(intVal)
}
return err
}
func setUintField(val string, bitSize int, structField reflect.Value) error {
if val == "" {
val = "0"
}
uintVal, err := strconv.ParseUint(val, 10, bitSize)
if err == nil {
structField.SetUint(uintVal)
}
return err
}
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
switch valueKind {
case reflect.Int:
return setIntField(val, 0, structField)
case reflect.Int8:
return setIntField(val, 8, structField)
case reflect.Int16:
return setIntField(val, 16, structField)
case reflect.Int32:
return setIntField(val, 32, structField)
case reflect.Int64:
return setIntField(val, 64, structField)
case reflect.Uint:
return setUintField(val, 0, structField)
case reflect.Uint8:
return setUintField(val, 8, structField)
case reflect.Uint16:
return setUintField(val, 16, structField)
case reflect.Uint32:
return setUintField(val, 32, structField)
case reflect.Uint64:
return setUintField(val, 64, structField)
case reflect.Bool:
if val == "" {
val = "false"
}
boolVal, err := strconv.ParseBool(val)
if err != nil {
return err
} else {
structField.SetBool(boolVal)
}
case reflect.Float32:
if val == "" {
val = "0.0"
}
floatVal, err := strconv.ParseFloat(val, 32)
if err != nil {
return err
} else {
structField.SetFloat(floatVal)
}
case reflect.Float64:
if val == "" {
val = "0.0"
}
floatVal, err := strconv.ParseFloat(val, 64)
if err != nil {
return err
} else {
structField.SetFloat(floatVal)
}
case reflect.String:
structField.SetString(val)
}
return nil
}
// Don't pass in pointers to bind to. Can lead to bugs. See:
// https://github.com/codegangsta/martini-contrib/issues/40
// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
func ensureNotPointer(obj interface{}) {
if reflect.TypeOf(obj).Kind() == reflect.Ptr {
log.Panic("Pointers are not accepted as binding models")
}
}
func Validate(obj interface{}, parents ...string) error {
typ := reflect.TypeOf(obj)
val := reflect.ValueOf(obj)
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
val = val.Elem()
}
switch typ.Kind() {
case reflect.Struct:
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Allow ignored and unexported fields in the struct
if len(field.PkgPath) > 0 || field.Tag.Get("form") == "-" {
continue
}
fieldValue := val.Field(i).Interface()
zero := reflect.Zero(field.Type).Interface()
if strings.Index(field.Tag.Get("binding"), "required") > -1 {
fieldType := field.Type.Kind()
if fieldType == reflect.Struct {
if reflect.DeepEqual(zero, fieldValue) {
return errors.New("Required " + field.Name)
}
err := Validate(fieldValue, field.Name)
if err != nil {
return err
}
} else if reflect.DeepEqual(zero, fieldValue) {
if len(parents) > 0 {
return errors.New("Required " + field.Name + " on " + parents[0])
} else {
return errors.New("Required " + field.Name)
}
} else if fieldType == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct {
err := Validate(fieldValue)
if err != nil {
return err
}
}
} else {
fieldType := field.Type.Kind()
if fieldType == reflect.Struct {
if reflect.DeepEqual(zero, fieldValue) {
continue
}
err := Validate(fieldValue, field.Name)
if err != nil {
return err
}
} else if fieldType == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct {
err := Validate(fieldValue, field.Name)
if err != nil {
return err
}
}
}
}
case reflect.Slice:
for i := 0; i < val.Len(); i++ {
fieldValue := val.Index(i).Interface()
err := Validate(fieldValue)
if err != nil {
return err
}
}
default: default:
return nil return GETForm
}
} }
return nil
} }

143
binding/form_mapping.go Normal file
View File

@ -0,0 +1,143 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"errors"
"fmt"
"log"
"reflect"
"strconv"
)
func mapForm(ptr interface{}, form map[string][]string) error {
typ := reflect.TypeOf(ptr).Elem()
val := reflect.ValueOf(ptr).Elem()
for i := 0; i < typ.NumField(); i++ {
typeField := typ.Field(i)
structField := val.Field(i)
if !structField.CanSet() {
continue
}
inputFieldName := typeField.Tag.Get("form")
if inputFieldName == "" {
inputFieldName = typeField.Name
}
inputValue, exists := form[inputFieldName]
fmt.Println("Field: "+inputFieldName+" Value: ", inputValue)
if !exists {
continue
}
numElems := len(inputValue)
if structField.Kind() == reflect.Slice && numElems > 0 {
sliceOf := structField.Type().Elem().Kind()
slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
for i := 0; i < numElems; i++ {
if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil {
return err
}
}
val.Field(i).Set(slice)
} else {
if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
return err
}
}
}
return nil
}
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
switch valueKind {
case reflect.Int:
return setIntField(val, 0, structField)
case reflect.Int8:
return setIntField(val, 8, structField)
case reflect.Int16:
return setIntField(val, 16, structField)
case reflect.Int32:
return setIntField(val, 32, structField)
case reflect.Int64:
return setIntField(val, 64, structField)
case reflect.Uint:
return setUintField(val, 0, structField)
case reflect.Uint8:
return setUintField(val, 8, structField)
case reflect.Uint16:
return setUintField(val, 16, structField)
case reflect.Uint32:
return setUintField(val, 32, structField)
case reflect.Uint64:
return setUintField(val, 64, structField)
case reflect.Bool:
return setBoolField(val, structField)
case reflect.Float32:
return setFloatField(val, 32, structField)
case reflect.Float64:
return setFloatField(val, 64, structField)
case reflect.String:
structField.SetString(val)
default:
return errors.New("Unknown type")
}
return nil
}
func setIntField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0"
}
intVal, err := strconv.ParseInt(val, 10, bitSize)
if err == nil {
field.SetInt(intVal)
}
return err
}
func setUintField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0"
}
uintVal, err := strconv.ParseUint(val, 10, bitSize)
if err == nil {
field.SetUint(uintVal)
}
return err
}
func setBoolField(val string, field reflect.Value) error {
if val == "" {
val = "false"
}
boolVal, err := strconv.ParseBool(val)
if err == nil {
field.SetBool(boolVal)
}
return nil
}
func setFloatField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0.0"
}
floatVal, err := strconv.ParseFloat(val, bitSize)
if err == nil {
field.SetFloat(floatVal)
}
return err
}
// Don't pass in pointers to bind to. Can lead to bugs. See:
// https://github.com/codegangsta/martini-contrib/issues/40
// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
func ensureNotPointer(obj interface{}) {
if reflect.TypeOf(obj).Kind() == reflect.Ptr {
log.Panic("Pointers are not accepted as binding models")
}
}

23
binding/get_form.go Normal file
View File

@ -0,0 +1,23 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import "net/http"
type getFormBinding struct{}
func (_ getFormBinding) Name() string {
return "get_form"
}
func (_ getFormBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return err
}
if err := mapForm(obj, req.Form); err != nil {
return err
}
return Validate(obj)
}

26
binding/json.go Normal file
View File

@ -0,0 +1,26 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"encoding/json"
"net/http"
)
type jsonBinding struct{}
func (_ jsonBinding) Name() string {
return "json"
}
func (_ jsonBinding) Bind(req *http.Request, obj interface{}) error {
decoder := json.NewDecoder(req.Body)
if err := decoder.Decode(obj); err == nil {
return Validate(obj)
} else {
return err
}
}

23
binding/post_form.go Normal file
View File

@ -0,0 +1,23 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import "net/http"
type postFormBinding struct{}
func (_ postFormBinding) Name() string {
return "post_form"
}
func (_ postFormBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return err
}
if err := mapForm(obj, req.PostForm); err != nil {
return err
}
return Validate(obj)
}

79
binding/validate.go Normal file
View File

@ -0,0 +1,79 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"errors"
"reflect"
"strings"
)
func Validate(obj interface{}) error {
return validate(obj, "{{ROOT}}")
}
func validate(obj interface{}, parent string) error {
typ, val := inspectObject(obj)
switch typ.Kind() {
case reflect.Struct:
return validateStruct(typ, val, parent)
case reflect.Slice:
return validateSlice(typ, val, parent)
default:
return errors.New("The object is not a slice or struct.")
}
}
func inspectObject(obj interface{}) (typ reflect.Type, val reflect.Value) {
typ = reflect.TypeOf(obj)
val = reflect.ValueOf(obj)
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
val = val.Elem()
}
return
}
func validateSlice(typ reflect.Type, val reflect.Value, parent string) error {
if typ.Elem().Kind() == reflect.Struct {
for i := 0; i < val.Len(); i++ {
itemValue := val.Index(i).Interface()
if err := validate(itemValue, parent); err != nil {
return err
}
}
}
return nil
}
func validateStruct(typ reflect.Type, val reflect.Value, parent string) error {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Allow ignored and unexported fields in the struct
// TODO should include || field.Tag.Get("form") == "-"
if len(field.PkgPath) > 0 {
continue
}
fieldValue := val.Field(i).Interface()
requiredField := strings.Index(field.Tag.Get("binding"), "required") > -1
if requiredField {
zero := reflect.Zero(field.Type).Interface()
if reflect.DeepEqual(zero, fieldValue) {
return errors.New("Required " + field.Name + " in " + parent)
}
}
fieldType := field.Type.Kind()
if fieldType == reflect.Struct || fieldType == reflect.Slice {
if err := validate(fieldValue, field.Name); err != nil {
return err
}
}
}
return nil
}

25
binding/xml.go Normal file
View File

@ -0,0 +1,25 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package binding
import (
"encoding/xml"
"net/http"
)
type xmlBinding struct{}
func (_ xmlBinding) Name() string {
return "xml"
}
func (_ xmlBinding) Bind(req *http.Request, obj interface{}) error {
decoder := xml.NewDecoder(req.Body)
if err := decoder.Decode(obj); err == nil {
return Validate(obj)
} else {
return err
}
}

View File

@ -179,21 +179,7 @@ func (c *Context) ContentType() string {
// else --> returns an error // else --> returns an error
// if Parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. It decodes the json payload into the struct specified as a pointer.Like ParseBody() but this method also writes a 400 error if the json is not valid. // if Parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input. It decodes the json payload into the struct specified as a pointer.Like ParseBody() but this method also writes a 400 error if the json is not valid.
func (c *Context) Bind(obj interface{}) bool { func (c *Context) Bind(obj interface{}) bool {
var b binding.Binding b := binding.Default(c.Request.Method, c.ContentType())
ctype := filterFlags(c.Request.Header.Get("Content-Type"))
switch {
case c.Request.Method == "GET" || ctype == MIMEPOSTForm:
b = binding.Form
case ctype == MIMEMultipartPOSTForm:
b = binding.MultipartForm
case ctype == MIMEJSON:
b = binding.JSON
case ctype == MIMEXML || ctype == MIMEXML2:
b = binding.XML
default:
c.Fail(400, errors.New("unknown content-type: "+ctype))
return false
}
return c.BindWith(obj, b) return c.BindWith(obj, b)
} }
@ -283,18 +269,18 @@ type Negotiate struct {
func (c *Context) Negotiate(code int, config Negotiate) { func (c *Context) Negotiate(code int, config Negotiate) {
switch c.NegotiateFormat(config.Offered...) { switch c.NegotiateFormat(config.Offered...) {
case MIMEJSON: case binding.MIMEJSON:
data := chooseData(config.JSONData, config.Data) data := chooseData(config.JSONData, config.Data)
c.JSON(code, data) c.JSON(code, data)
case MIMEHTML: case binding.MIMEHTML:
data := chooseData(config.HTMLData, config.Data)
if len(config.HTMLPath) == 0 { if len(config.HTMLPath) == 0 {
log.Panic("negotiate config is wrong. html path is needed") log.Panic("negotiate config is wrong. html path is needed")
} }
data := chooseData(config.HTMLData, config.Data)
c.HTML(code, config.HTMLPath, data) c.HTML(code, config.HTMLPath, data)
case MIMEXML: case binding.MIMEXML:
data := chooseData(config.XMLData, config.Data) data := chooseData(config.XMLData, config.Data)
c.XML(code, data) c.XML(code, data)

View File

@ -13,6 +13,16 @@ import (
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
) )
const (
MIMEJSON = binding.MIMEJSON
MIMEHTML = binding.MIMEHTML
MIMEXML = binding.MIMEXML
MIMEXML2 = binding.MIMEXML2
MIMEPlain = binding.MIMEPlain
MIMEPOSTForm = binding.MIMEPOSTForm
MIMEMultipartPOSTForm = binding.MIMEMultipartPOSTForm
)
// DEPRECATED, use Bind() instead. // DEPRECATED, use Bind() instead.
// Like ParseBody() but this method also writes a 400 error if the json is not valid. // Like ParseBody() but this method also writes a 400 error if the json is not valid.
func (c *Context) EnsureBody(item interface{}) bool { func (c *Context) EnsureBody(item interface{}) bool {

14
gin.go
View File

@ -9,19 +9,11 @@ import (
"net/http" "net/http"
"sync" "sync"
"github.com/gin-gonic/gin/binding"
"github.com/gin-gonic/gin/render" "github.com/gin-gonic/gin/render"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
) )
const (
MIMEJSON = "application/json"
MIMEHTML = "text/html"
MIMEXML = "application/xml"
MIMEXML2 = "text/xml"
MIMEPlain = "text/plain"
MIMEPOSTForm = "application/x-www-form-urlencoded"
MIMEMultipartPOSTForm = "multipart/form-data"
)
type ( type (
HandlerFunc func(*Context) HandlerFunc func(*Context)
@ -147,7 +139,7 @@ func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
c.Next() c.Next()
if !c.Writer.Written() { if !c.Writer.Written() {
if c.Writer.Status() == 404 { if c.Writer.Status() == 404 {
c.Data(-1, MIMEPlain, engine.Default404Body) c.Data(-1, binding.MIMEPlain, engine.Default404Body)
} else { } else {
c.Writer.WriteHeaderNow() c.Writer.WriteHeaderNow()
} }
@ -162,7 +154,7 @@ func (engine *Engine) handle405(w http.ResponseWriter, req *http.Request) {
c.Next() c.Next()
if !c.Writer.Written() { if !c.Writer.Written() {
if c.Writer.Status() == 405 { if c.Writer.Status() == 405 {
c.Data(-1, MIMEPlain, engine.Default405Body) c.Data(-1, binding.MIMEPlain, engine.Default405Body)
} else { } else {
c.Writer.WriteHeaderNow() c.Writer.WriteHeaderNow()
} }