From f38c30a0d2f56c54397543f08902b00260c70ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Szafra=C5=84ski?= Date: Fri, 6 Sep 2019 07:56:59 +0200 Subject: [PATCH] feat(binding): add DisallowUnknownFields() in gin.Context.BindJSON() (#2028) --- binding/binding_test.go | 29 +++++++++++++++++++++++++++++ binding/json.go | 9 +++++++++ mode.go | 8 +++++++- mode_test.go | 6 ++++++ 4 files changed, 51 insertions(+), 1 deletion(-) diff --git a/binding/binding_test.go b/binding/binding_test.go index 3d08d693..caabaace 100644 --- a/binding/binding_test.go +++ b/binding/binding_test.go @@ -64,6 +64,10 @@ type FooStructUseNumber struct { Foo interface{} `json:"foo" binding:"required"` } +type FooStructDisallowUnknownFields struct { + Foo interface{} `json:"foo" binding:"required"` +} + type FooBarStructForTimeType struct { TimeFoo time.Time `form:"time_foo" time_format:"2006-01-02" time_utc:"1" time_location:"Asia/Chongqing"` TimeBar time.Time `form:"time_bar" time_format:"2006-01-02" time_utc:"1"` @@ -194,6 +198,12 @@ func TestBindingJSONUseNumber2(t *testing.T) { `{"foo": 123}`, `{"bar": "foo"}`) } +func TestBindingJSONDisallowUnknownFields(t *testing.T) { + testBodyBindingDisallowUnknownFields(t, JSON, + "/", "/", + `{"foo": "bar"}`, `{"foo": "bar", "what": "this"}`) +} + func TestBindingForm(t *testing.T) { testFormBinding(t, "POST", "/", "/", @@ -1162,6 +1172,25 @@ func testBodyBindingUseNumber2(t *testing.T, b Binding, name, path, badPath, bod assert.Error(t, err) } +func testBodyBindingDisallowUnknownFields(t *testing.T, b Binding, path, badPath, body, badBody string) { + EnableDecoderDisallowUnknownFields = true + defer func() { + EnableDecoderDisallowUnknownFields = false + }() + + obj := FooStructDisallowUnknownFields{} + req := requestWithBody("POST", path, body) + err := b.Bind(req, &obj) + assert.NoError(t, err) + assert.Equal(t, "bar", obj.Foo) + + obj = FooStructDisallowUnknownFields{} + req = requestWithBody("POST", badPath, badBody) + err = JSON.Bind(req, &obj) + assert.Error(t, err) + assert.Contains(t, err.Error(), "what") +} + func testBodyBindingFail(t *testing.T, b Binding, name, path, badPath, body, badBody string) { assert.Equal(t, name, b.Name()) diff --git a/binding/json.go b/binding/json.go index f968161b..d62e0705 100644 --- a/binding/json.go +++ b/binding/json.go @@ -18,6 +18,12 @@ import ( // interface{} as a Number instead of as a float64. var EnableDecoderUseNumber = false +// EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method +// on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to +// return an error when the destination is a struct and the input contains object +// keys which do not match any non-ignored, exported fields in the destination. +var EnableDecoderDisallowUnknownFields = false + type jsonBinding struct{} func (jsonBinding) Name() string { @@ -40,6 +46,9 @@ func decodeJSON(r io.Reader, obj interface{}) error { if EnableDecoderUseNumber { decoder.UseNumber() } + if EnableDecoderDisallowUnknownFields { + decoder.DisallowUnknownFields() + } if err := decoder.Decode(obj); err != nil { return err } diff --git a/mode.go b/mode.go index 8aa84aa8..c3c37fdc 100644 --- a/mode.go +++ b/mode.go @@ -71,12 +71,18 @@ func DisableBindValidation() { binding.Validator = nil } -// EnableJsonDecoderUseNumber sets true for binding.EnableDecoderUseNumberto to +// EnableJsonDecoderUseNumber sets true for binding.EnableDecoderUseNumber to // call the UseNumber method on the JSON Decoder instance. func EnableJsonDecoderUseNumber() { binding.EnableDecoderUseNumber = true } +// EnableJsonDisallowUnknownFields sets true for binding.EnableDecoderDisallowUnknownFields to +// call the DisallowUnknownFields method on the JSON Decoder instance. +func EnableJsonDecoderDisallowUnknownFields() { + binding.EnableDecoderDisallowUnknownFields = true +} + // Mode returns currently gin mode. func Mode() string { return modeName diff --git a/mode_test.go b/mode_test.go index 3dba5150..0c5a3234 100644 --- a/mode_test.go +++ b/mode_test.go @@ -45,3 +45,9 @@ func TestEnableJsonDecoderUseNumber(t *testing.T) { EnableJsonDecoderUseNumber() assert.True(t, binding.EnableDecoderUseNumber) } + +func TestEnableJsonDecoderDisallowUnknownFields(t *testing.T) { + assert.False(t, binding.EnableDecoderDisallowUnknownFields) + EnableJsonDecoderDisallowUnknownFields() + assert.True(t, binding.EnableDecoderDisallowUnknownFields) +}