mirror of https://github.com/goccy/go-json.git
Support context for MarshalJSON and UnmarshalJSON
This commit is contained in:
parent
a2ba5e8bcc
commit
cd7fb7392f
43
decode.go
43
decode.go
|
@ -1,6 +1,7 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
@ -39,7 +40,7 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
|
|||
}
|
||||
ctx := decoder.TakeRuntimeContext()
|
||||
ctx.Buf = src
|
||||
ctx.Option.Flag = 0
|
||||
ctx.Option.Flags = 0
|
||||
for _, optFunc := range optFuncs {
|
||||
optFunc(ctx.Option)
|
||||
}
|
||||
|
@ -52,6 +53,36 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
|
|||
return validateEndBuf(src, cursor)
|
||||
}
|
||||
|
||||
func unmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
|
||||
src := make([]byte, len(data)+1) // append nul byte to the end
|
||||
copy(src, data)
|
||||
|
||||
header := (*emptyInterface)(unsafe.Pointer(&v))
|
||||
|
||||
if err := validateType(header.typ, uintptr(header.ptr)); err != nil {
|
||||
return err
|
||||
}
|
||||
dec, err := decoder.CompileToGetDecoder(header.typ)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rctx := decoder.TakeRuntimeContext()
|
||||
rctx.Buf = src
|
||||
rctx.Option.Flags = 0
|
||||
rctx.Option.Flags |= decoder.ContextOption
|
||||
rctx.Option.Context = ctx
|
||||
for _, optFunc := range optFuncs {
|
||||
optFunc(rctx.Option)
|
||||
}
|
||||
cursor, err := dec.Decode(rctx, 0, 0, header.ptr)
|
||||
if err != nil {
|
||||
decoder.ReleaseRuntimeContext(rctx)
|
||||
return err
|
||||
}
|
||||
decoder.ReleaseRuntimeContext(rctx)
|
||||
return validateEndBuf(src, cursor)
|
||||
}
|
||||
|
||||
func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
|
||||
src := make([]byte, len(data)+1) // append nul byte to the end
|
||||
copy(src, data)
|
||||
|
@ -68,7 +99,7 @@ func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc)
|
|||
|
||||
ctx := decoder.TakeRuntimeContext()
|
||||
ctx.Buf = src
|
||||
ctx.Option.Flag = 0
|
||||
ctx.Option.Flags = 0
|
||||
for _, optFunc := range optFuncs {
|
||||
optFunc(ctx.Option)
|
||||
}
|
||||
|
@ -137,6 +168,14 @@ func (d *Decoder) Decode(v interface{}) error {
|
|||
return d.DecodeWithOption(v)
|
||||
}
|
||||
|
||||
// DecodeContext reads the next JSON-encoded value from its
|
||||
// input and stores it in the value pointed to by v with context.Context.
|
||||
func (d *Decoder) DecodeContext(ctx context.Context, v interface{}) error {
|
||||
d.s.Option.Flags |= decoder.ContextOption
|
||||
d.s.Option.Context = ctx
|
||||
return d.DecodeWithOption(v)
|
||||
}
|
||||
|
||||
func (d *Decoder) DecodeWithOption(v interface{}, optFuncs ...DecodeOptionFunc) error {
|
||||
header := (*emptyInterface)(unsafe.Pointer(&v))
|
||||
typ := header.typ
|
||||
|
|
|
@ -2,6 +2,7 @@ package json_test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding"
|
||||
stdjson "encoding/json"
|
||||
"errors"
|
||||
|
@ -3620,3 +3621,48 @@ func TestDecodeEscapedCharField(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
type unmarshalContextKey struct{}
|
||||
|
||||
type unmarshalContextStructType struct {
|
||||
v int
|
||||
}
|
||||
|
||||
func (t *unmarshalContextStructType) UnmarshalJSON(ctx context.Context, b []byte) error {
|
||||
v := ctx.Value(unmarshalContextKey{})
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to propagate parent context.Context")
|
||||
}
|
||||
if s != "hello" {
|
||||
return fmt.Errorf("failed to propagate parent context.Context")
|
||||
}
|
||||
t.v = 100
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDecodeContextOption(t *testing.T) {
|
||||
src := []byte("10")
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
t.Run("UnmarshalContext", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello")
|
||||
var v unmarshalContextStructType
|
||||
if err := json.UnmarshalContext(ctx, src, &v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if v.v != 100 {
|
||||
t.Fatal("failed to decode with context")
|
||||
}
|
||||
})
|
||||
t.Run("DecodeContext", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello")
|
||||
var v unmarshalContextStructType
|
||||
if err := json.NewDecoder(buf).DecodeContext(ctx, &v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if v.v != 100 {
|
||||
t.Fatal("failed to decode with context")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
43
encode.go
43
encode.go
|
@ -1,6 +1,7 @@
|
|||
package json
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"unsafe"
|
||||
|
||||
|
@ -35,6 +36,7 @@ func (e *Encoder) Encode(v interface{}) error {
|
|||
// EncodeWithOption call Encode with EncodeOption.
|
||||
func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc) error {
|
||||
ctx := encoder.TakeRuntimeContext()
|
||||
ctx.Option.Flag = 0
|
||||
|
||||
err := e.encodeWithOption(ctx, v, optFuncs...)
|
||||
|
||||
|
@ -42,8 +44,20 @@ func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc)
|
|||
return err
|
||||
}
|
||||
|
||||
// EncodeContext call Encode with context.Context and EncodeOption.
|
||||
func (e *Encoder) EncodeContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) error {
|
||||
rctx := encoder.TakeRuntimeContext()
|
||||
rctx.Option.Flag = 0
|
||||
rctx.Option.Flag |= encoder.ContextOption
|
||||
rctx.Option.Context = ctx
|
||||
|
||||
err := e.encodeWithOption(rctx, v, optFuncs...)
|
||||
|
||||
encoder.ReleaseRuntimeContext(rctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *Encoder) encodeWithOption(ctx *encoder.RuntimeContext, v interface{}, optFuncs ...EncodeOptionFunc) error {
|
||||
ctx.Option.Flag = 0
|
||||
if e.enabledHTMLEscape {
|
||||
ctx.Option.Flag |= encoder.HTMLEscapeOption
|
||||
}
|
||||
|
@ -94,6 +108,33 @@ func (e *Encoder) SetIndent(prefix, indent string) {
|
|||
e.enabledIndent = true
|
||||
}
|
||||
|
||||
func marshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
|
||||
rctx := encoder.TakeRuntimeContext()
|
||||
rctx.Option.Flag = 0
|
||||
rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.ContextOption
|
||||
rctx.Option.Context = ctx
|
||||
for _, optFunc := range optFuncs {
|
||||
optFunc(rctx.Option)
|
||||
}
|
||||
|
||||
buf, err := encode(rctx, v)
|
||||
if err != nil {
|
||||
encoder.ReleaseRuntimeContext(rctx)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// this line exists to escape call of `runtime.makeslicecopy` .
|
||||
// if use `make([]byte, len(buf)-1)` and `copy(copied, buf)`,
|
||||
// dst buffer size and src buffer size are differrent.
|
||||
// in this case, compiler uses `runtime.makeslicecopy`, but it is slow.
|
||||
buf = buf[:len(buf)-1]
|
||||
copied := make([]byte, len(buf))
|
||||
copy(copied, buf)
|
||||
|
||||
encoder.ReleaseRuntimeContext(rctx)
|
||||
return copied, nil
|
||||
}
|
||||
|
||||
func marshal(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
|
||||
ctx := encoder.TakeRuntimeContext()
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package json_test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding"
|
||||
stdjson "encoding/json"
|
||||
"errors"
|
||||
|
@ -1918,3 +1919,42 @@ func TestEncodeMapKeyTypeInterface(t *testing.T) {
|
|||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
type marshalContextKey struct{}
|
||||
|
||||
type marshalContextStructType struct{}
|
||||
|
||||
func (t *marshalContextStructType) MarshalJSON(ctx context.Context) ([]byte, error) {
|
||||
v := ctx.Value(marshalContextKey{})
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to propagate parent context.Context")
|
||||
}
|
||||
if s != "hello" {
|
||||
return nil, fmt.Errorf("failed to propagate parent context.Context")
|
||||
}
|
||||
return []byte(`"success"`), nil
|
||||
}
|
||||
|
||||
func TestEncodeContextOption(t *testing.T) {
|
||||
t.Run("MarshalContext", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello")
|
||||
b, err := json.MarshalContext(ctx, &marshalContextStructType{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(b) != `"success"` {
|
||||
t.Fatal("failed to encode with MarshalerContext")
|
||||
}
|
||||
})
|
||||
t.Run("EncodeContext", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello")
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
if err := json.NewEncoder(buf).EncodeContext(ctx, &marshalContextStructType{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if buf.String() != "\"success\"\n" {
|
||||
t.Fatal("failed to encode with EncodeContext")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ func compileToGetDecoderSlowPath(typeptr uintptr, typ *runtime.Type) (Decoder, e
|
|||
|
||||
func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
|
||||
switch {
|
||||
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
|
||||
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
|
||||
return newUnmarshalJSONDecoder(runtime.PtrTo(typ), "", ""), nil
|
||||
case runtime.PtrTo(typ).Implements(unmarshalTextType):
|
||||
return newUnmarshalTextDecoder(runtime.PtrTo(typ), "", ""), nil
|
||||
|
@ -70,7 +70,7 @@ func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (De
|
|||
|
||||
func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
|
||||
switch {
|
||||
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
|
||||
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
|
||||
return newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName), nil
|
||||
case runtime.PtrTo(typ).Implements(unmarshalTextType):
|
||||
return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
|
||||
|
@ -133,7 +133,7 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode
|
|||
|
||||
func isStringTagSupportedType(typ *runtime.Type) bool {
|
||||
switch {
|
||||
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
|
||||
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
|
||||
return false
|
||||
case runtime.PtrTo(typ).Implements(unmarshalTextType):
|
||||
return false
|
||||
|
@ -494,3 +494,7 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
|
|||
structDec.tryOptimize()
|
||||
return structDec, nil
|
||||
}
|
||||
|
||||
func implementsUnmarshalJSONType(typ *runtime.Type) bool {
|
||||
return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType)
|
||||
}
|
||||
|
|
|
@ -117,6 +117,21 @@ func decodeStreamUnmarshaler(s *Stream, depth int64, unmarshaler json.Unmarshale
|
|||
return nil
|
||||
}
|
||||
|
||||
func decodeStreamUnmarshalerContext(s *Stream, depth int64, unmarshaler unmarshalerContext) error {
|
||||
start := s.cursor
|
||||
if err := s.skipValue(depth); err != nil {
|
||||
return err
|
||||
}
|
||||
src := s.buf[start:s.cursor]
|
||||
dst := make([]byte, len(src))
|
||||
copy(dst, src)
|
||||
|
||||
if err := unmarshaler.UnmarshalJSON(s.Option.Context, dst); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarshaler) (int64, error) {
|
||||
cursor = skipWhiteSpace(buf, cursor)
|
||||
start := cursor
|
||||
|
@ -134,6 +149,23 @@ func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarsh
|
|||
return end, nil
|
||||
}
|
||||
|
||||
func decodeUnmarshalerContext(ctx *RuntimeContext, buf []byte, cursor, depth int64, unmarshaler unmarshalerContext) (int64, error) {
|
||||
cursor = skipWhiteSpace(buf, cursor)
|
||||
start := cursor
|
||||
end, err := skipValue(buf, cursor, depth)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
src := buf[start:end]
|
||||
dst := make([]byte, len(src))
|
||||
copy(dst, src)
|
||||
|
||||
if err := unmarshaler.UnmarshalJSON(ctx.Option.Context, dst); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return end, nil
|
||||
}
|
||||
|
||||
func decodeStreamTextUnmarshaler(s *Stream, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error {
|
||||
start := s.cursor
|
||||
if err := s.skipValue(depth); err != nil {
|
||||
|
@ -260,6 +292,9 @@ func (d *interfaceDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer
|
|||
}))
|
||||
rv := reflect.ValueOf(runtimeInterfaceValue)
|
||||
if rv.NumMethod() > 0 && rv.CanInterface() {
|
||||
if u, ok := rv.Interface().(unmarshalerContext); ok {
|
||||
return decodeStreamUnmarshalerContext(s, depth, u)
|
||||
}
|
||||
if u, ok := rv.Interface().(json.Unmarshaler); ok {
|
||||
return decodeStreamUnmarshaler(s, depth, u)
|
||||
}
|
||||
|
@ -317,6 +352,9 @@ func (d *interfaceDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p un
|
|||
}))
|
||||
rv := reflect.ValueOf(runtimeInterfaceValue)
|
||||
if rv.NumMethod() > 0 && rv.CanInterface() {
|
||||
if u, ok := rv.Interface().(unmarshalerContext); ok {
|
||||
return decodeUnmarshalerContext(ctx, buf, cursor, depth, u)
|
||||
}
|
||||
if u, ok := rv.Interface().(json.Unmarshaler); ok {
|
||||
return decodeUnmarshaler(buf, cursor, depth, u)
|
||||
}
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
package decoder
|
||||
|
||||
type OptionFlag int
|
||||
import "context"
|
||||
|
||||
type OptionFlags uint8
|
||||
|
||||
const (
|
||||
FirstWinOption OptionFlag = 1 << iota
|
||||
FirstWinOption OptionFlags = 1 << iota
|
||||
ContextOption
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
Flag OptionFlag
|
||||
Flags OptionFlags
|
||||
Context context.Context
|
||||
}
|
||||
|
|
|
@ -665,7 +665,7 @@ func (d *structDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) e
|
|||
seenFields map[int]struct{}
|
||||
seenFieldNum int
|
||||
)
|
||||
firstWin := (s.Option.Flag & FirstWinOption) != 0
|
||||
firstWin := (s.Option.Flags & FirstWinOption) != 0
|
||||
if firstWin {
|
||||
seenFields = make(map[int]struct{}, d.fieldUniqueNameNum)
|
||||
}
|
||||
|
@ -752,7 +752,7 @@ func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
|
|||
seenFields map[int]struct{}
|
||||
seenFieldNum int
|
||||
)
|
||||
firstWin := (ctx.Option.Flag & FirstWinOption) != 0
|
||||
firstWin := (ctx.Option.Flags & FirstWinOption) != 0
|
||||
if firstWin {
|
||||
seenFields = make(map[int]struct{}, d.fieldUniqueNameNum)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package decoder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
|
@ -17,7 +18,12 @@ const (
|
|||
maxDecodeNestingDepth = 10000
|
||||
)
|
||||
|
||||
type unmarshalerContext interface {
|
||||
UnmarshalJSON(context.Context, []byte) error
|
||||
}
|
||||
|
||||
var (
|
||||
unmarshalJSONType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
|
||||
unmarshalTextType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
unmarshalJSONType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
|
||||
unmarshalJSONContextType = reflect.TypeOf((*unmarshalerContext)(nil)).Elem()
|
||||
unmarshalTextType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
|
||||
)
|
||||
|
|
|
@ -46,9 +46,16 @@ func (d *unmarshalJSONDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Poi
|
|||
typ: d.typ,
|
||||
ptr: p,
|
||||
}))
|
||||
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil {
|
||||
d.annotateError(s.cursor, err)
|
||||
return err
|
||||
if (s.Option.Flags & ContextOption) != 0 {
|
||||
if err := v.(unmarshalerContext).UnmarshalJSON(s.Option.Context, dst); err != nil {
|
||||
d.annotateError(s.cursor, err)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil {
|
||||
d.annotateError(s.cursor, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -69,9 +76,16 @@ func (d *unmarshalJSONDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
|
|||
typ: d.typ,
|
||||
ptr: p,
|
||||
}))
|
||||
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil {
|
||||
d.annotateError(cursor, err)
|
||||
return 0, err
|
||||
if (ctx.Option.Flags & ContextOption) != 0 {
|
||||
if err := v.(unmarshalerContext).UnmarshalJSON(ctx.Option.Context, dst); err != nil {
|
||||
d.annotateError(cursor, err)
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil {
|
||||
d.annotateError(cursor, err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return end, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package encoder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -13,13 +14,18 @@ import (
|
|||
"github.com/goccy/go-json/internal/runtime"
|
||||
)
|
||||
|
||||
type marshalerContext interface {
|
||||
MarshalJSON(context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
var (
|
||||
marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
|
||||
marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
||||
jsonNumberType = reflect.TypeOf(json.Number(""))
|
||||
cachedOpcodeSets []*OpcodeSet
|
||||
cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
|
||||
typeAddr *runtime.TypeAddr
|
||||
marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
|
||||
marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem()
|
||||
marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
|
||||
jsonNumberType = reflect.TypeOf(json.Number(""))
|
||||
cachedOpcodeSets []*OpcodeSet
|
||||
cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
|
||||
typeAddr *runtime.TypeAddr
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -110,7 +116,7 @@ func compileHead(ctx *compileContext) (*Opcode, error) {
|
|||
elem := typ.Elem()
|
||||
if elem.Kind() == reflect.Uint8 {
|
||||
p := runtime.PtrTo(elem)
|
||||
if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) {
|
||||
if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
|
||||
if isPtr {
|
||||
return compileBytesPtr(ctx)
|
||||
}
|
||||
|
@ -340,14 +346,14 @@ func optimizeStructEnd(c *Opcode) {
|
|||
}
|
||||
|
||||
func implementsMarshalJSON(typ *runtime.Type) bool {
|
||||
if !typ.Implements(marshalJSONType) {
|
||||
if !implementsMarshalJSONType(typ) {
|
||||
return false
|
||||
}
|
||||
if typ.Kind() != reflect.Ptr {
|
||||
return true
|
||||
}
|
||||
// type kind is reflect.Ptr
|
||||
if !typ.Elem().Implements(marshalJSONType) {
|
||||
if !implementsMarshalJSONType(typ.Elem()) {
|
||||
return true
|
||||
}
|
||||
// needs to dereference
|
||||
|
@ -384,7 +390,7 @@ func compile(ctx *compileContext, isPtr bool) (*Opcode, error) {
|
|||
elem := typ.Elem()
|
||||
if elem.Kind() == reflect.Uint8 {
|
||||
p := runtime.PtrTo(elem)
|
||||
if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) {
|
||||
if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
|
||||
return compileBytes(ctx)
|
||||
}
|
||||
}
|
||||
|
@ -527,9 +533,12 @@ func compilePtr(ctx *compileContext) (*Opcode, error) {
|
|||
func compileMarshalJSON(ctx *compileContext) (*Opcode, error) {
|
||||
code := newOpCode(ctx, OpMarshalJSON)
|
||||
typ := ctx.typ
|
||||
if !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) {
|
||||
if isPtrMarshalJSONType(typ) {
|
||||
code.Flags |= AddrForMarshalerFlags
|
||||
}
|
||||
if typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType) {
|
||||
code.Flags |= MarshalerContextFlags
|
||||
}
|
||||
if isNilableType(typ) {
|
||||
code.Flags |= IsNilableTypeFlags
|
||||
} else {
|
||||
|
@ -920,7 +929,7 @@ func compileSlice(ctx *compileContext) (*Opcode, error) {
|
|||
func compileListElem(ctx *compileContext) (*Opcode, error) {
|
||||
typ := ctx.typ
|
||||
switch {
|
||||
case !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType):
|
||||
case isPtrMarshalJSONType(typ):
|
||||
return compileMarshalJSON(ctx)
|
||||
case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
|
||||
return compileMarshalText(ctx)
|
||||
|
@ -1534,8 +1543,12 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func implementsMarshalJSONType(typ *runtime.Type) bool {
|
||||
return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType)
|
||||
}
|
||||
|
||||
func isPtrMarshalJSONType(typ *runtime.Type) bool {
|
||||
return !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType)
|
||||
return !implementsMarshalJSONType(typ) && implementsMarshalJSONType(runtime.PtrTo(typ))
|
||||
}
|
||||
|
||||
func isPtrMarshalTextType(typ *runtime.Type) bool {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package encoder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
|
@ -104,6 +105,7 @@ var (
|
|||
)
|
||||
|
||||
type RuntimeContext struct {
|
||||
Context context.Context
|
||||
Buf []byte
|
||||
MarshalBuf []byte
|
||||
Ptrs []uintptr
|
||||
|
|
|
@ -365,13 +365,27 @@ func AppendMarshalJSON(ctx *RuntimeContext, code *Opcode, b []byte, v interface{
|
|||
}
|
||||
}
|
||||
v = rv.Interface()
|
||||
marshaler, ok := v.(json.Marshaler)
|
||||
if !ok {
|
||||
return AppendNull(ctx, b), nil
|
||||
}
|
||||
bb, err := marshaler.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
|
||||
var bb []byte
|
||||
if (code.Flags & MarshalerContextFlags) != 0 {
|
||||
marshaler, ok := v.(marshalerContext)
|
||||
if !ok {
|
||||
return AppendNull(ctx, b), nil
|
||||
}
|
||||
b, err := marshaler.MarshalJSON(ctx.Option.Context)
|
||||
if err != nil {
|
||||
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
|
||||
}
|
||||
bb = b
|
||||
} else {
|
||||
marshaler, ok := v.(json.Marshaler)
|
||||
if !ok {
|
||||
return AppendNull(ctx, b), nil
|
||||
}
|
||||
b, err := marshaler.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
|
||||
}
|
||||
bb = b
|
||||
}
|
||||
marshalBuf := ctx.MarshalBuf[:0]
|
||||
marshalBuf = append(append(marshalBuf, bb...), nul)
|
||||
|
@ -395,13 +409,27 @@ func AppendMarshalJSONIndent(ctx *RuntimeContext, code *Opcode, b []byte, v inte
|
|||
}
|
||||
}
|
||||
v = rv.Interface()
|
||||
marshaler, ok := v.(json.Marshaler)
|
||||
if !ok {
|
||||
return AppendNull(ctx, b), nil
|
||||
}
|
||||
bb, err := marshaler.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
|
||||
var bb []byte
|
||||
if (code.Flags & MarshalerContextFlags) != 0 {
|
||||
marshaler, ok := v.(marshalerContext)
|
||||
if !ok {
|
||||
return AppendNull(ctx, b), nil
|
||||
}
|
||||
b, err := marshaler.MarshalJSON(ctx.Option.Context)
|
||||
if err != nil {
|
||||
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
|
||||
}
|
||||
bb = b
|
||||
} else {
|
||||
marshaler, ok := v.(json.Marshaler)
|
||||
if !ok {
|
||||
return AppendNull(ctx, b), nil
|
||||
}
|
||||
b, err := marshaler.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
|
||||
}
|
||||
bb = b
|
||||
}
|
||||
marshalBuf := ctx.MarshalBuf[:0]
|
||||
marshalBuf = append(append(marshalBuf, bb...), nul)
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
const uintptrSize = 4 << (^uintptr(0) >> 63)
|
||||
|
||||
type OpFlags uint8
|
||||
type OpFlags uint16
|
||||
|
||||
const (
|
||||
AnonymousHeadFlags OpFlags = 1 << 0
|
||||
|
@ -21,6 +21,7 @@ const (
|
|||
AddrForMarshalerFlags OpFlags = 1 << 5
|
||||
IsNextOpPtrTypeFlags OpFlags = 1 << 6
|
||||
IsNilableTypeFlags OpFlags = 1 << 7
|
||||
MarshalerContextFlags OpFlags = 1 << 8
|
||||
)
|
||||
|
||||
type Opcode struct {
|
||||
|
@ -32,9 +33,8 @@ type Opcode struct {
|
|||
Key string // struct field key
|
||||
Offset uint32 // offset size from struct header
|
||||
PtrNum uint8 // pointer number: e.g. double pointer is 2.
|
||||
Flags OpFlags
|
||||
NumBitSize uint8
|
||||
_ [1]uint8 // 1
|
||||
Flags OpFlags
|
||||
|
||||
Type *runtime.Type // go type
|
||||
PrevField *Opcode // prev struct field
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package encoder
|
||||
|
||||
import "context"
|
||||
|
||||
type OptionFlag uint8
|
||||
|
||||
const (
|
||||
|
@ -8,11 +10,13 @@ const (
|
|||
UnorderedMapOption
|
||||
DebugOption
|
||||
ColorizeOption
|
||||
ContextOption
|
||||
)
|
||||
|
||||
type Option struct {
|
||||
Flag OptionFlag
|
||||
ColorScheme *ColorScheme
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
type EncodeFormat struct {
|
||||
|
|
27
json.go
27
json.go
|
@ -2,6 +2,7 @@ package json
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/goccy/go-json/internal/encoder"
|
||||
|
@ -13,6 +14,12 @@ type Marshaler interface {
|
|||
MarshalJSON() ([]byte, error)
|
||||
}
|
||||
|
||||
// MarshalerContext is the interface implemented by types that
|
||||
// can marshal themselves into valid JSON with context.Context.
|
||||
type MarshalerContext interface {
|
||||
MarshalJSON(context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
// Unmarshaler is the interface implemented by types
|
||||
// that can unmarshal a JSON description of themselves.
|
||||
// The input can be assumed to be a valid encoding of
|
||||
|
@ -25,6 +32,12 @@ type Unmarshaler interface {
|
|||
UnmarshalJSON([]byte) error
|
||||
}
|
||||
|
||||
// UnmarshalerContext is the interface implemented by types
|
||||
// that can unmarshal with context.Context a JSON description of themselves.
|
||||
type UnmarshalerContext interface {
|
||||
UnmarshalJSON(context.Context, []byte) error
|
||||
}
|
||||
|
||||
// Marshal returns the JSON encoding of v.
|
||||
//
|
||||
// Marshal traverses the value v recursively.
|
||||
|
@ -158,11 +171,16 @@ func Marshal(v interface{}) ([]byte, error) {
|
|||
return MarshalWithOption(v)
|
||||
}
|
||||
|
||||
// MarshalNoEscape
|
||||
// MarshalNoEscape returns the JSON encoding of v and doesn't escape v.
|
||||
func MarshalNoEscape(v interface{}) ([]byte, error) {
|
||||
return marshalNoEscape(v)
|
||||
}
|
||||
|
||||
// MarshalContext returns the JSON encoding of v with context.Context and EncodeOption.
|
||||
func MarshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
|
||||
return marshalContext(ctx, v, optFuncs...)
|
||||
}
|
||||
|
||||
// MarshalWithOption returns the JSON encoding of v with EncodeOption.
|
||||
func MarshalWithOption(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
|
||||
return marshal(v, optFuncs...)
|
||||
|
@ -258,6 +276,13 @@ func Unmarshal(data []byte, v interface{}) error {
|
|||
return unmarshal(data, v)
|
||||
}
|
||||
|
||||
// UnmarshalContext parses the JSON-encoded data and stores the result
|
||||
// in the value pointed to by v. If you implement the UnmarshalerContext interface,
|
||||
// call it with ctx as an argument.
|
||||
func UnmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
|
||||
return unmarshalContext(ctx, data, v)
|
||||
}
|
||||
|
||||
func UnmarshalWithOption(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
|
||||
return unmarshal(data, v, optFuncs...)
|
||||
}
|
||||
|
|
|
@ -41,6 +41,6 @@ type DecodeOptionFunc func(*DecodeOption)
|
|||
// This behavior has a performance advantage as it allows the subsequent strings to be skipped if all fields have been evaluated.
|
||||
func DecodeFieldPriorityFirstWin() DecodeOptionFunc {
|
||||
return func(opt *DecodeOption) {
|
||||
opt.Flag |= decoder.FirstWinOption
|
||||
opt.Flags |= decoder.FirstWinOption
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue