From 8363db7e38c2e55258ae1cc5ed2995187f9fa210 Mon Sep 17 00:00:00 2001 From: TimAndy Date: Thu, 21 Nov 2024 14:43:06 +0800 Subject: [PATCH] feat: support custom json codec at runtime --- binding/form_mapping.go | 4 +- binding/json.go | 2 +- binding/json_test.go | 190 ++++++++++++++++++++++++++++++++++++++++ codec/api/json.go | 54 ++++++++++++ codec/json/api.go | 5 ++ codec/json/go_json.go | 43 ++++++--- codec/json/json.go | 43 ++++++--- codec/json/jsoniter.go | 46 +++++++--- codec/json/sonic.go | 46 +++++++--- errors.go | 4 +- errors_test.go | 6 +- go.mod | 2 +- render/json.go | 12 +-- render/render_test.go | 4 +- 14 files changed, 394 insertions(+), 67 deletions(-) create mode 100644 codec/api/json.go create mode 100644 codec/json/api.go diff --git a/binding/form_mapping.go b/binding/form_mapping.go index 35c22b17..8da1bb60 100644 --- a/binding/form_mapping.go +++ b/binding/form_mapping.go @@ -333,9 +333,9 @@ func setWithProperType(val string, value reflect.Value, field reflect.StructFiel case multipart.FileHeader: return nil } - return json.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface()) + return json.Api.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface()) case reflect.Map: - return json.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface()) + return json.Api.Unmarshal(bytesconv.StringToBytes(val), value.Addr().Interface()) case reflect.Ptr: if !value.Elem().IsValid() { value.Set(reflect.New(value.Type().Elem())) diff --git a/binding/json.go b/binding/json.go index f85fc1a2..fea3c9fd 100644 --- a/binding/json.go +++ b/binding/json.go @@ -42,7 +42,7 @@ func (jsonBinding) BindBody(body []byte, obj any) error { } func decodeJSON(r io.Reader, obj any) error { - decoder := json.NewDecoder(r) + decoder := json.Api.NewDecoder(r) if EnableDecoderUseNumber { decoder.UseNumber() } diff --git a/binding/json_test.go b/binding/json_test.go index fbd5c527..40d125dc 100644 --- a/binding/json_test.go +++ b/binding/json_test.go @@ -5,8 +5,17 @@ package binding import ( + "io" + "net/http/httptest" "testing" + "time" + "unsafe" + "github.com/gin-gonic/gin/codec/api" + "github.com/gin-gonic/gin/codec/json" + "github.com/gin-gonic/gin/render" + jsoniter "github.com/json-iterator/go" + "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -28,3 +37,184 @@ func TestJSONBindingBindBodyMap(t *testing.T) { assert.Equal(t, "FOO", s["foo"]) assert.Equal(t, "world", s["hello"]) } + +func TestCustomJsonCodec(t *testing.T) { + //Restore json encoding configuration after testing + oldMarshal := json.Api + defer func() { + json.Api = oldMarshal + }() + //Custom json api + json.Api = customJsonApi{} + + //test decode json + obj := customReq{} + err := jsonBinding{}.BindBody([]byte(`{"time_empty":null,"time_struct": "2001-12-05 10:01:02.345","time_nil":null,"time_pointer":"2002-12-05 10:01:02.345"}`), &obj) + require.NoError(t, err) + assert.Equal(t, zeroTime, obj.TimeEmpty) + assert.Equal(t, time.Date(2001, 12, 05, 10, 01, 02, 345000000, time.Local), obj.TimeStruct) + assert.Nil(t, obj.TimeNil) + assert.Equal(t, time.Date(2002, 12, 05, 10, 01, 02, 345000000, time.Local), *obj.TimePointer) + //test encode json + w := httptest.NewRecorder() + err2 := (render.PureJSON{Data: obj}).Render(w) + require.NoError(t, err2) + assert.Equal(t, "{\"time_empty\":null,\"time_struct\":\"2001-12-05 10:01:02.345\",\"time_nil\":null,\"time_pointer\":\"2002-12-05 10:01:02.345\"}\n", w.Body.String()) + assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) +} + +type customReq struct { + TimeEmpty time.Time `json:"time_empty"` + TimeStruct time.Time `json:"time_struct"` + TimeNil *time.Time `json:"time_nil"` + TimePointer *time.Time `json:"time_pointer"` +} + +var customConfig = jsoniter.Config{ + EscapeHTML: true, + SortMapKeys: true, + ValidateJsonRawMessage: true, +}.Froze() + +func init() { + customConfig.RegisterExtension(&TimeEx{}) + customConfig.RegisterExtension(&TimePointerEx{}) +} + +type customJsonApi struct { +} + +func (j customJsonApi) Marshal(v any) ([]byte, error) { + return customConfig.Marshal(v) +} + +func (j customJsonApi) Unmarshal(data []byte, v any) error { + return customConfig.Unmarshal(data, v) +} + +func (j customJsonApi) MarshalIndent(v any, prefix, indent string) ([]byte, error) { + return customConfig.MarshalIndent(v, prefix, indent) +} + +func (j customJsonApi) NewEncoder(writer io.Writer) api.JsonEncoder { + return customConfig.NewEncoder(writer) +} + +func (j customJsonApi) NewDecoder(reader io.Reader) api.JsonDecoder { + return customConfig.NewDecoder(reader) +} + +//region Time Extension + +var ( + zeroTime = time.Time{} + timeType = reflect2.TypeOfPtr((*time.Time)(nil)).Elem() + defaultTimeCodec = &timeCodec{} +) + +type TimeEx struct { + jsoniter.DummyExtension +} + +func (te *TimeEx) CreateDecoder(typ reflect2.Type) jsoniter.ValDecoder { + if typ == timeType { + return defaultTimeCodec + } + return nil +} + +func (te *TimeEx) CreateEncoder(typ reflect2.Type) jsoniter.ValEncoder { + if typ == timeType { + return defaultTimeCodec + } + return nil +} + +type timeCodec struct { +} + +func (tc timeCodec) IsEmpty(ptr unsafe.Pointer) bool { + t := *((*time.Time)(ptr)) + return t == zeroTime +} + +func (tc timeCodec) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { + t := *((*time.Time)(ptr)) + if t == zeroTime { + stream.WriteNil() + return + } + stream.WriteString(t.In(time.Local).Format("2006-01-02 15:04:05.000")) +} + +func (tc timeCodec) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { + ts := iter.ReadString() + if len(ts) == 0 { + *((*time.Time)(ptr)) = zeroTime + return + } + t, err := time.ParseInLocation("2006-01-02 15:04:05.000", ts, time.Local) + if err != nil { + panic(err) + } + *((*time.Time)(ptr)) = t +} + +//endregion + +//region *Time Extension + +var ( + timePointerType = reflect2.TypeOfPtr((**time.Time)(nil)).Elem() + defaultTimePointerCodec = &timePointerCodec{} +) + +type TimePointerEx struct { + jsoniter.DummyExtension +} + +func (tpe *TimePointerEx) CreateDecoder(typ reflect2.Type) jsoniter.ValDecoder { + if typ == timePointerType { + return defaultTimePointerCodec + } + return nil +} + +func (tpe *TimePointerEx) CreateEncoder(typ reflect2.Type) jsoniter.ValEncoder { + if typ == timePointerType { + return defaultTimePointerCodec + } + return nil +} + +type timePointerCodec struct { +} + +func (tpc timePointerCodec) IsEmpty(ptr unsafe.Pointer) bool { + t := *((**time.Time)(ptr)) + return t == nil || *t == zeroTime +} + +func (tpc timePointerCodec) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { + t := *((**time.Time)(ptr)) + if t == nil || *t == zeroTime { + stream.WriteNil() + return + } + stream.WriteString(t.In(time.Local).Format("2006-01-02 15:04:05.000")) +} + +func (tpc timePointerCodec) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { + ts := iter.ReadString() + if len(ts) == 0 { + *((**time.Time)(ptr)) = nil + return + } + t, err := time.ParseInLocation("2006-01-02 15:04:05.000", ts, time.Local) + if err != nil { + panic(err) + } + *((**time.Time)(ptr)) = &t +} + +//endregion diff --git a/codec/api/json.go b/codec/api/json.go new file mode 100644 index 00000000..fce49e78 --- /dev/null +++ b/codec/api/json.go @@ -0,0 +1,54 @@ +// Copyright 2022 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package api + +import "io" + +// JsonApi the api for json codec. +type JsonApi interface { + Marshal(v any) ([]byte, error) + Unmarshal(data []byte, v any) error + MarshalIndent(v any, prefix, indent string) ([]byte, error) + NewEncoder(writer io.Writer) JsonEncoder + NewDecoder(reader io.Reader) JsonDecoder +} + +// A JsonEncoder interface writes JSON values to an output stream. +type JsonEncoder interface { + // SetEscapeHTML specifies whether problematic HTML characters + // should be escaped inside JSON quoted strings. + // The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e + // to avoid certain safety problems that can arise when embedding JSON in HTML. + // + // In non-HTML settings where the escaping interferes with the readability + // of the output, SetEscapeHTML(false) disables this behavior. + SetEscapeHTML(on bool) + + // Encode writes the JSON encoding of v to the stream, + // followed by a newline character. + // + // See the documentation for Marshal for details about the + // conversion of Go values to JSON. + Encode(v interface{}) error +} + +// A JsonDecoder interface reads and decodes JSON values from an input stream. +type JsonDecoder interface { + // UseNumber causes the Decoder to unmarshal a number into an interface{} as a + // Number instead of as a float64. + UseNumber() + + // DisallowUnknownFields causes the Decoder to return an error when the destination + // is a struct and the input contains object keys which do not match any + // non-ignored, exported fields in the destination. + DisallowUnknownFields() + + // Decode reads the next JSON-encoded value from its + // input and stores it in the value pointed to by v. + // + // See the documentation for Unmarshal for details about + // the conversion of JSON into a Go value. + Decode(v interface{}) error +} diff --git a/codec/json/api.go b/codec/json/api.go new file mode 100644 index 00000000..3a92f69e --- /dev/null +++ b/codec/json/api.go @@ -0,0 +1,5 @@ +package json + +import "github.com/gin-gonic/gin/codec/api" + +var Api api.JsonApi diff --git a/codec/json/go_json.go b/codec/json/go_json.go index 47c35598..73c2c442 100644 --- a/codec/json/go_json.go +++ b/codec/json/go_json.go @@ -6,17 +6,36 @@ package json -import json "github.com/goccy/go-json" +import ( + "io" -var ( - // Marshal is exported by gin/json package. - Marshal = json.Marshal - // Unmarshal is exported by gin/json package. - Unmarshal = json.Unmarshal - // MarshalIndent is exported by gin/json package. - MarshalIndent = json.MarshalIndent - // NewDecoder is exported by gin/json package. - NewDecoder = json.NewDecoder - // NewEncoder is exported by gin/json package. - NewEncoder = json.NewEncoder + "github.com/gin-gonic/gin/codec/api" + "github.com/goccy/go-json" ) + +func init() { + Api = gojsonApi{} +} + +type gojsonApi struct { +} + +func (j gojsonApi) Marshal(v any) ([]byte, error) { + return json.Marshal(v) +} + +func (j gojsonApi) Unmarshal(data []byte, v any) error { + return json.Unmarshal(data, v) +} + +func (j gojsonApi) MarshalIndent(v any, prefix, indent string) ([]byte, error) { + return json.MarshalIndent(v, prefix, indent) +} + +func (j gojsonApi) NewEncoder(writer io.Writer) api.JsonEncoder { + return json.NewEncoder(writer) +} + +func (j gojsonApi) NewDecoder(reader io.Reader) api.JsonDecoder { + return json.NewDecoder(reader) +} diff --git a/codec/json/json.go b/codec/json/json.go index c7ee83eb..c4c3882f 100644 --- a/codec/json/json.go +++ b/codec/json/json.go @@ -6,17 +6,36 @@ package json -import "encoding/json" +import ( + "encoding/json" + "io" -var ( - // Marshal is exported by gin/json package. - Marshal = json.Marshal - // Unmarshal is exported by gin/json package. - Unmarshal = json.Unmarshal - // MarshalIndent is exported by gin/json package. - MarshalIndent = json.MarshalIndent - // NewDecoder is exported by gin/json package. - NewDecoder = json.NewDecoder - // NewEncoder is exported by gin/json package. - NewEncoder = json.NewEncoder + "github.com/gin-gonic/gin/codec/api" ) + +func init() { + Api = jsonApi{} +} + +type jsonApi struct { +} + +func (j jsonApi) Marshal(v any) ([]byte, error) { + return json.Marshal(v) +} + +func (j jsonApi) Unmarshal(data []byte, v any) error { + return json.Unmarshal(data, v) +} + +func (j jsonApi) MarshalIndent(v any, prefix, indent string) ([]byte, error) { + return json.MarshalIndent(v, prefix, indent) +} + +func (j jsonApi) NewEncoder(writer io.Writer) api.JsonEncoder { + return json.NewEncoder(writer) +} + +func (j jsonApi) NewDecoder(reader io.Reader) api.JsonDecoder { + return json.NewDecoder(reader) +} diff --git a/codec/json/jsoniter.go b/codec/json/jsoniter.go index 45ed16ba..f4ddb3cc 100644 --- a/codec/json/jsoniter.go +++ b/codec/json/jsoniter.go @@ -6,18 +6,38 @@ package json -import jsoniter "github.com/json-iterator/go" +import ( + "io" -var ( - json = jsoniter.ConfigCompatibleWithStandardLibrary - // Marshal is exported by gin/json package. - Marshal = json.Marshal - // Unmarshal is exported by gin/json package. - Unmarshal = json.Unmarshal - // MarshalIndent is exported by gin/json package. - MarshalIndent = json.MarshalIndent - // NewDecoder is exported by gin/json package. - NewDecoder = json.NewDecoder - // NewEncoder is exported by gin/json package. - NewEncoder = json.NewEncoder + "github.com/gin-gonic/gin/codec/api" + jsoniter "github.com/json-iterator/go" ) + +func init() { + Api = jsoniterApi{} +} + +var json = jsoniter.ConfigCompatibleWithStandardLibrary + +type jsoniterApi struct { +} + +func (j jsoniterApi) Marshal(v any) ([]byte, error) { + return json.Marshal(v) +} + +func (j jsoniterApi) Unmarshal(data []byte, v any) error { + return json.Unmarshal(data, v) +} + +func (j jsoniterApi) MarshalIndent(v any, prefix, indent string) ([]byte, error) { + return json.MarshalIndent(v, prefix, indent) +} + +func (j jsoniterApi) NewEncoder(writer io.Writer) api.JsonEncoder { + return json.NewEncoder(writer) +} + +func (j jsoniterApi) NewDecoder(reader io.Reader) api.JsonDecoder { + return json.NewDecoder(reader) +} diff --git a/codec/json/sonic.go b/codec/json/sonic.go index 529e16d0..22836fa8 100644 --- a/codec/json/sonic.go +++ b/codec/json/sonic.go @@ -6,18 +6,38 @@ package json -import "github.com/bytedance/sonic" +import ( + "io" -var ( - json = sonic.ConfigStd - // Marshal is exported by gin/json package. - Marshal = json.Marshal - // Unmarshal is exported by gin/json package. - Unmarshal = json.Unmarshal - // MarshalIndent is exported by gin/json package. - MarshalIndent = json.MarshalIndent - // NewDecoder is exported by gin/json package. - NewDecoder = json.NewDecoder - // NewEncoder is exported by gin/json package. - NewEncoder = json.NewEncoder + "github.com/bytedance/sonic" + "github.com/gin-gonic/gin/codec/api" ) + +func init() { + Api = sonicApi{} +} + +var json = sonic.ConfigStd + +type sonicApi struct { +} + +func (j sonicApi) Marshal(v any) ([]byte, error) { + return json.Marshal(v) +} + +func (j sonicApi) Unmarshal(data []byte, v any) error { + return json.Unmarshal(data, v) +} + +func (j sonicApi) MarshalIndent(v any, prefix, indent string) ([]byte, error) { + return json.MarshalIndent(v, prefix, indent) +} + +func (j sonicApi) NewEncoder(writer io.Writer) api.JsonEncoder { + return json.NewEncoder(writer) +} + +func (j sonicApi) NewDecoder(reader io.Reader) api.JsonDecoder { + return json.NewDecoder(reader) +} diff --git a/errors.go b/errors.go index a6fda985..6a839f4a 100644 --- a/errors.go +++ b/errors.go @@ -77,7 +77,7 @@ func (msg *Error) JSON() any { // MarshalJSON implements the json.Marshaller interface. func (msg *Error) MarshalJSON() ([]byte, error) { - return json.Marshal(msg.JSON()) + return json.Api.Marshal(msg.JSON()) } // Error implements the error interface. @@ -157,7 +157,7 @@ func (a errorMsgs) JSON() any { // MarshalJSON implements the json.Marshaller interface. func (a errorMsgs) MarshalJSON() ([]byte, error) { - return json.Marshal(a.JSON()) + return json.Api.Marshal(a.JSON()) } func (a errorMsgs) String() string { diff --git a/errors_test.go b/errors_test.go index 334751b8..c70cd111 100644 --- a/errors_test.go +++ b/errors_test.go @@ -33,7 +33,7 @@ func TestError(t *testing.T) { "meta": "some data", }, err.JSON()) - jsonBytes, _ := json.Marshal(err) + jsonBytes, _ := json.Api.Marshal(err) assert.Equal(t, "{\"error\":\"test error\",\"meta\":\"some data\"}", string(jsonBytes)) err.SetMeta(H{ //nolint: errcheck @@ -92,13 +92,13 @@ Error #03: third H{"error": "second", "meta": "some data"}, H{"error": "third", "status": "400"}, }, errs.JSON()) - jsonBytes, _ := json.Marshal(errs) + jsonBytes, _ := json.Api.Marshal(errs) assert.Equal(t, "[{\"error\":\"first\"},{\"error\":\"second\",\"meta\":\"some data\"},{\"error\":\"third\",\"status\":\"400\"}]", string(jsonBytes)) errs = errorMsgs{ {Err: errors.New("first"), Type: ErrorTypePrivate}, } assert.Equal(t, H{"error": "first"}, errs.JSON()) - jsonBytes, _ = json.Marshal(errs) + jsonBytes, _ = json.Api.Marshal(errs) assert.Equal(t, "{\"error\":\"first\"}", string(jsonBytes)) errs = errorMsgs{} diff --git a/go.mod b/go.mod index 035c2dea..027f0d5f 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/goccy/go-json v0.10.2 github.com/json-iterator/go v1.1.12 github.com/mattn/go-isatty v0.0.20 + github.com/modern-go/reflect2 v1.0.2 github.com/pelletier/go-toml/v2 v2.2.2 github.com/quic-go/quic-go v0.43.1 github.com/stretchr/testify v1.9.0 @@ -32,7 +33,6 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect - github.com/modern-go/reflect2 v1.0.2 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect diff --git a/render/json.go b/render/json.go index f7ada44a..afdf5c40 100644 --- a/render/json.go +++ b/render/json.go @@ -65,7 +65,7 @@ func (r JSON) WriteContentType(w http.ResponseWriter) { // WriteJSON marshals the given interface object and writes it with custom ContentType. func WriteJSON(w http.ResponseWriter, obj any) error { writeContentType(w, jsonContentType) - jsonBytes, err := json.Marshal(obj) + jsonBytes, err := json.Api.Marshal(obj) if err != nil { return err } @@ -76,7 +76,7 @@ func WriteJSON(w http.ResponseWriter, obj any) error { // Render (IndentedJSON) marshals the given interface object and writes it with custom ContentType. func (r IndentedJSON) Render(w http.ResponseWriter) error { r.WriteContentType(w) - jsonBytes, err := json.MarshalIndent(r.Data, "", " ") + jsonBytes, err := json.Api.MarshalIndent(r.Data, "", " ") if err != nil { return err } @@ -92,7 +92,7 @@ func (r IndentedJSON) WriteContentType(w http.ResponseWriter) { // Render (SecureJSON) marshals the given interface object and writes it with custom ContentType. func (r SecureJSON) Render(w http.ResponseWriter) error { r.WriteContentType(w) - jsonBytes, err := json.Marshal(r.Data) + jsonBytes, err := json.Api.Marshal(r.Data) if err != nil { return err } @@ -115,7 +115,7 @@ func (r SecureJSON) WriteContentType(w http.ResponseWriter) { // Render (JsonpJSON) marshals the given interface object and writes it and its callback with custom ContentType. func (r JsonpJSON) Render(w http.ResponseWriter) (err error) { r.WriteContentType(w) - ret, err := json.Marshal(r.Data) + ret, err := json.Api.Marshal(r.Data) if err != nil { return err } @@ -153,7 +153,7 @@ func (r JsonpJSON) WriteContentType(w http.ResponseWriter) { // Render (AsciiJSON) marshals the given interface object and writes it with custom ContentType. func (r AsciiJSON) Render(w http.ResponseWriter) (err error) { r.WriteContentType(w) - ret, err := json.Marshal(r.Data) + ret, err := json.Api.Marshal(r.Data) if err != nil { return err } @@ -179,7 +179,7 @@ func (r AsciiJSON) WriteContentType(w http.ResponseWriter) { // Render (PureJSON) writes custom ContentType and encodes the given interface object. func (r PureJSON) Render(w http.ResponseWriter) error { r.WriteContentType(w) - encoder := json.NewEncoder(w) + encoder := json.Api.NewEncoder(w) encoder.SetEscapeHTML(false) return encoder.Encode(r.Data) } diff --git a/render/render_test.go b/render/render_test.go index ad633b00..79a2e33f 100644 --- a/render/render_test.go +++ b/render/render_test.go @@ -15,7 +15,7 @@ import ( "strings" "testing" - "github.com/gin-gonic/gin/internal/json" + "github.com/gin-gonic/gin/codec/json" testdata "github.com/gin-gonic/gin/testdata/protoexample" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -173,7 +173,7 @@ func TestRenderJsonpJSONError(t *testing.T) { err = jsonpJSON.Render(ew) assert.Equal(t, `write "`+`(`+`" error`, err.Error()) - data, _ := json.Marshal(jsonpJSON.Data) // error was returned while writing data + data, _ := json.Api.Marshal(jsonpJSON.Data) // error was returned while writing data ew.bufString = string(data) err = jsonpJSON.Render(ew) assert.Equal(t, `write "`+string(data)+`" error`, err.Error())