diff --git a/decode.go b/decode.go index e2cfcd4..d99749d 100644 --- a/decode.go +++ b/decode.go @@ -1,6 +1,7 @@ package json import ( + "context" "fmt" "io" "reflect" @@ -39,7 +40,7 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { } ctx := decoder.TakeRuntimeContext() ctx.Buf = src - ctx.Option.Flag = 0 + ctx.Option.Flags = 0 for _, optFunc := range optFuncs { optFunc(ctx.Option) } @@ -52,6 +53,36 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { return validateEndBuf(src, cursor) } +func unmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { + src := make([]byte, len(data)+1) // append nul byte to the end + copy(src, data) + + header := (*emptyInterface)(unsafe.Pointer(&v)) + + if err := validateType(header.typ, uintptr(header.ptr)); err != nil { + return err + } + dec, err := decoder.CompileToGetDecoder(header.typ) + if err != nil { + return err + } + rctx := decoder.TakeRuntimeContext() + rctx.Buf = src + rctx.Option.Flags = 0 + rctx.Option.Flags |= decoder.ContextOption + rctx.Option.Context = ctx + for _, optFunc := range optFuncs { + optFunc(rctx.Option) + } + cursor, err := dec.Decode(rctx, 0, 0, header.ptr) + if err != nil { + decoder.ReleaseRuntimeContext(rctx) + return err + } + decoder.ReleaseRuntimeContext(rctx) + return validateEndBuf(src, cursor) +} + func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { src := make([]byte, len(data)+1) // append nul byte to the end copy(src, data) @@ -68,7 +99,7 @@ func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) ctx := decoder.TakeRuntimeContext() ctx.Buf = src - ctx.Option.Flag = 0 + ctx.Option.Flags = 0 for _, optFunc := range optFuncs { optFunc(ctx.Option) } @@ -137,6 +168,14 @@ func (d *Decoder) Decode(v interface{}) error { return d.DecodeWithOption(v) } +// DecodeContext reads the next JSON-encoded value from its +// input and stores it in the value pointed to by v with context.Context. +func (d *Decoder) DecodeContext(ctx context.Context, v interface{}) error { + d.s.Option.Flags |= decoder.ContextOption + d.s.Option.Context = ctx + return d.DecodeWithOption(v) +} + func (d *Decoder) DecodeWithOption(v interface{}, optFuncs ...DecodeOptionFunc) error { header := (*emptyInterface)(unsafe.Pointer(&v)) typ := header.typ diff --git a/decode_test.go b/decode_test.go index f3c2622..f578de2 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2,6 +2,7 @@ package json_test import ( "bytes" + "context" "encoding" stdjson "encoding/json" "errors" @@ -3620,3 +3621,48 @@ func TestDecodeEscapedCharField(t *testing.T) { } }) } + +type unmarshalContextKey struct{} + +type unmarshalContextStructType struct { + v int +} + +func (t *unmarshalContextStructType) UnmarshalJSON(ctx context.Context, b []byte) error { + v := ctx.Value(unmarshalContextKey{}) + s, ok := v.(string) + if !ok { + return fmt.Errorf("failed to propagate parent context.Context") + } + if s != "hello" { + return fmt.Errorf("failed to propagate parent context.Context") + } + t.v = 100 + return nil +} + +func TestDecodeContextOption(t *testing.T) { + src := []byte("10") + buf := bytes.NewBuffer(src) + + t.Run("UnmarshalContext", func(t *testing.T) { + ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello") + var v unmarshalContextStructType + if err := json.UnmarshalContext(ctx, src, &v); err != nil { + t.Fatal(err) + } + if v.v != 100 { + t.Fatal("failed to decode with context") + } + }) + t.Run("DecodeContext", func(t *testing.T) { + ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello") + var v unmarshalContextStructType + if err := json.NewDecoder(buf).DecodeContext(ctx, &v); err != nil { + t.Fatal(err) + } + if v.v != 100 { + t.Fatal("failed to decode with context") + } + }) +} diff --git a/encode.go b/encode.go index ea2bc7c..7f198bd 100644 --- a/encode.go +++ b/encode.go @@ -1,6 +1,7 @@ package json import ( + "context" "io" "unsafe" @@ -35,6 +36,7 @@ func (e *Encoder) Encode(v interface{}) error { // EncodeWithOption call Encode with EncodeOption. func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc) error { ctx := encoder.TakeRuntimeContext() + ctx.Option.Flag = 0 err := e.encodeWithOption(ctx, v, optFuncs...) @@ -42,8 +44,20 @@ func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc) return err } +// EncodeContext call Encode with context.Context and EncodeOption. +func (e *Encoder) EncodeContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) error { + rctx := encoder.TakeRuntimeContext() + rctx.Option.Flag = 0 + rctx.Option.Flag |= encoder.ContextOption + rctx.Option.Context = ctx + + err := e.encodeWithOption(rctx, v, optFuncs...) + + encoder.ReleaseRuntimeContext(rctx) + return err +} + func (e *Encoder) encodeWithOption(ctx *encoder.RuntimeContext, v interface{}, optFuncs ...EncodeOptionFunc) error { - ctx.Option.Flag = 0 if e.enabledHTMLEscape { ctx.Option.Flag |= encoder.HTMLEscapeOption } @@ -94,6 +108,33 @@ func (e *Encoder) SetIndent(prefix, indent string) { e.enabledIndent = true } +func marshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) { + rctx := encoder.TakeRuntimeContext() + rctx.Option.Flag = 0 + rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.ContextOption + rctx.Option.Context = ctx + for _, optFunc := range optFuncs { + optFunc(rctx.Option) + } + + buf, err := encode(rctx, v) + if err != nil { + encoder.ReleaseRuntimeContext(rctx) + return nil, err + } + + // this line exists to escape call of `runtime.makeslicecopy` . + // if use `make([]byte, len(buf)-1)` and `copy(copied, buf)`, + // dst buffer size and src buffer size are differrent. + // in this case, compiler uses `runtime.makeslicecopy`, but it is slow. + buf = buf[:len(buf)-1] + copied := make([]byte, len(buf)) + copy(copied, buf) + + encoder.ReleaseRuntimeContext(rctx) + return copied, nil +} + func marshal(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) { ctx := encoder.TakeRuntimeContext() diff --git a/encode_test.go b/encode_test.go index 715bf4f..41a14cc 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2,6 +2,7 @@ package json_test import ( "bytes" + "context" "encoding" stdjson "encoding/json" "errors" @@ -1918,3 +1919,42 @@ func TestEncodeMapKeyTypeInterface(t *testing.T) { t.Fatal("expected error") } } + +type marshalContextKey struct{} + +type marshalContextStructType struct{} + +func (t *marshalContextStructType) MarshalJSON(ctx context.Context) ([]byte, error) { + v := ctx.Value(marshalContextKey{}) + s, ok := v.(string) + if !ok { + return nil, fmt.Errorf("failed to propagate parent context.Context") + } + if s != "hello" { + return nil, fmt.Errorf("failed to propagate parent context.Context") + } + return []byte(`"success"`), nil +} + +func TestEncodeContextOption(t *testing.T) { + t.Run("MarshalContext", func(t *testing.T) { + ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello") + b, err := json.MarshalContext(ctx, &marshalContextStructType{}) + if err != nil { + t.Fatal(err) + } + if string(b) != `"success"` { + t.Fatal("failed to encode with MarshalerContext") + } + }) + t.Run("EncodeContext", func(t *testing.T) { + ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello") + buf := bytes.NewBuffer([]byte{}) + if err := json.NewEncoder(buf).EncodeContext(ctx, &marshalContextStructType{}); err != nil { + t.Fatal(err) + } + if buf.String() != "\"success\"\n" { + t.Fatal("failed to encode with EncodeContext") + } + }) +} diff --git a/internal/decoder/compile.go b/internal/decoder/compile.go index b57af80..bd56687 100644 --- a/internal/decoder/compile.go +++ b/internal/decoder/compile.go @@ -60,7 +60,7 @@ func compileToGetDecoderSlowPath(typeptr uintptr, typ *runtime.Type) (Decoder, e func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { switch { - case runtime.PtrTo(typ).Implements(unmarshalJSONType): + case implementsUnmarshalJSONType(runtime.PtrTo(typ)): return newUnmarshalJSONDecoder(runtime.PtrTo(typ), "", ""), nil case runtime.PtrTo(typ).Implements(unmarshalTextType): return newUnmarshalTextDecoder(runtime.PtrTo(typ), "", ""), nil @@ -70,7 +70,7 @@ func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (De func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { switch { - case runtime.PtrTo(typ).Implements(unmarshalJSONType): + case implementsUnmarshalJSONType(runtime.PtrTo(typ)): return newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName), nil case runtime.PtrTo(typ).Implements(unmarshalTextType): return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil @@ -133,7 +133,7 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode func isStringTagSupportedType(typ *runtime.Type) bool { switch { - case runtime.PtrTo(typ).Implements(unmarshalJSONType): + case implementsUnmarshalJSONType(runtime.PtrTo(typ)): return false case runtime.PtrTo(typ).Implements(unmarshalTextType): return false @@ -494,3 +494,7 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo structDec.tryOptimize() return structDec, nil } + +func implementsUnmarshalJSONType(typ *runtime.Type) bool { + return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType) +} diff --git a/internal/decoder/interface.go b/internal/decoder/interface.go index 8c82758..ea1b4aa 100644 --- a/internal/decoder/interface.go +++ b/internal/decoder/interface.go @@ -117,6 +117,21 @@ func decodeStreamUnmarshaler(s *Stream, depth int64, unmarshaler json.Unmarshale return nil } +func decodeStreamUnmarshalerContext(s *Stream, depth int64, unmarshaler unmarshalerContext) error { + start := s.cursor + if err := s.skipValue(depth); err != nil { + return err + } + src := s.buf[start:s.cursor] + dst := make([]byte, len(src)) + copy(dst, src) + + if err := unmarshaler.UnmarshalJSON(s.Option.Context, dst); err != nil { + return err + } + return nil +} + func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarshaler) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor @@ -134,6 +149,23 @@ func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarsh return end, nil } +func decodeUnmarshalerContext(ctx *RuntimeContext, buf []byte, cursor, depth int64, unmarshaler unmarshalerContext) (int64, error) { + cursor = skipWhiteSpace(buf, cursor) + start := cursor + end, err := skipValue(buf, cursor, depth) + if err != nil { + return 0, err + } + src := buf[start:end] + dst := make([]byte, len(src)) + copy(dst, src) + + if err := unmarshaler.UnmarshalJSON(ctx.Option.Context, dst); err != nil { + return 0, err + } + return end, nil +} + func decodeStreamTextUnmarshaler(s *Stream, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error { start := s.cursor if err := s.skipValue(depth); err != nil { @@ -260,6 +292,9 @@ func (d *interfaceDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer })) rv := reflect.ValueOf(runtimeInterfaceValue) if rv.NumMethod() > 0 && rv.CanInterface() { + if u, ok := rv.Interface().(unmarshalerContext); ok { + return decodeStreamUnmarshalerContext(s, depth, u) + } if u, ok := rv.Interface().(json.Unmarshaler); ok { return decodeStreamUnmarshaler(s, depth, u) } @@ -317,6 +352,9 @@ func (d *interfaceDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p un })) rv := reflect.ValueOf(runtimeInterfaceValue) if rv.NumMethod() > 0 && rv.CanInterface() { + if u, ok := rv.Interface().(unmarshalerContext); ok { + return decodeUnmarshalerContext(ctx, buf, cursor, depth, u) + } if u, ok := rv.Interface().(json.Unmarshaler); ok { return decodeUnmarshaler(buf, cursor, depth, u) } diff --git a/internal/decoder/option.go b/internal/decoder/option.go index 1603512..e41f876 100644 --- a/internal/decoder/option.go +++ b/internal/decoder/option.go @@ -1,11 +1,15 @@ package decoder -type OptionFlag int +import "context" + +type OptionFlags uint8 const ( - FirstWinOption OptionFlag = 1 << iota + FirstWinOption OptionFlags = 1 << iota + ContextOption ) type Option struct { - Flag OptionFlag + Flags OptionFlags + Context context.Context } diff --git a/internal/decoder/struct.go b/internal/decoder/struct.go index b99ab0a..d467b0d 100644 --- a/internal/decoder/struct.go +++ b/internal/decoder/struct.go @@ -665,7 +665,7 @@ func (d *structDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) e seenFields map[int]struct{} seenFieldNum int ) - firstWin := (s.Option.Flag & FirstWinOption) != 0 + firstWin := (s.Option.Flags & FirstWinOption) != 0 if firstWin { seenFields = make(map[int]struct{}, d.fieldUniqueNameNum) } @@ -752,7 +752,7 @@ func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf seenFields map[int]struct{} seenFieldNum int ) - firstWin := (ctx.Option.Flag & FirstWinOption) != 0 + firstWin := (ctx.Option.Flags & FirstWinOption) != 0 if firstWin { seenFields = make(map[int]struct{}, d.fieldUniqueNameNum) } diff --git a/internal/decoder/type.go b/internal/decoder/type.go index 419ad4a..70e9907 100644 --- a/internal/decoder/type.go +++ b/internal/decoder/type.go @@ -1,6 +1,7 @@ package decoder import ( + "context" "encoding" "encoding/json" "reflect" @@ -17,7 +18,12 @@ const ( maxDecodeNestingDepth = 10000 ) +type unmarshalerContext interface { + UnmarshalJSON(context.Context, []byte) error +} + var ( - unmarshalJSONType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() - unmarshalTextType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + unmarshalJSONType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + unmarshalJSONContextType = reflect.TypeOf((*unmarshalerContext)(nil)).Elem() + unmarshalTextType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() ) diff --git a/internal/decoder/unmarshal_json.go b/internal/decoder/unmarshal_json.go index a5e4fa6..d90f39c 100644 --- a/internal/decoder/unmarshal_json.go +++ b/internal/decoder/unmarshal_json.go @@ -46,9 +46,16 @@ func (d *unmarshalJSONDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Poi typ: d.typ, ptr: p, })) - if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil { - d.annotateError(s.cursor, err) - return err + if (s.Option.Flags & ContextOption) != 0 { + if err := v.(unmarshalerContext).UnmarshalJSON(s.Option.Context, dst); err != nil { + d.annotateError(s.cursor, err) + return err + } + } else { + if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil { + d.annotateError(s.cursor, err) + return err + } } return nil } @@ -69,9 +76,16 @@ func (d *unmarshalJSONDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, typ: d.typ, ptr: p, })) - if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil { - d.annotateError(cursor, err) - return 0, err + if (ctx.Option.Flags & ContextOption) != 0 { + if err := v.(unmarshalerContext).UnmarshalJSON(ctx.Option.Context, dst); err != nil { + d.annotateError(cursor, err) + return 0, err + } + } else { + if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil { + d.annotateError(cursor, err) + return 0, err + } } return end, nil } diff --git a/internal/encoder/compiler.go b/internal/encoder/compiler.go index 56a9604..9654361 100644 --- a/internal/encoder/compiler.go +++ b/internal/encoder/compiler.go @@ -1,6 +1,7 @@ package encoder import ( + "context" "encoding" "encoding/json" "fmt" @@ -13,13 +14,18 @@ import ( "github.com/goccy/go-json/internal/runtime" ) +type marshalerContext interface { + MarshalJSON(context.Context) ([]byte, error) +} + var ( - marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() - marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() - jsonNumberType = reflect.TypeOf(json.Number("")) - cachedOpcodeSets []*OpcodeSet - cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet - typeAddr *runtime.TypeAddr + marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem() + marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + jsonNumberType = reflect.TypeOf(json.Number("")) + cachedOpcodeSets []*OpcodeSet + cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet + typeAddr *runtime.TypeAddr ) func init() { @@ -110,7 +116,7 @@ func compileHead(ctx *compileContext) (*Opcode, error) { elem := typ.Elem() if elem.Kind() == reflect.Uint8 { p := runtime.PtrTo(elem) - if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) { + if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { if isPtr { return compileBytesPtr(ctx) } @@ -340,14 +346,14 @@ func optimizeStructEnd(c *Opcode) { } func implementsMarshalJSON(typ *runtime.Type) bool { - if !typ.Implements(marshalJSONType) { + if !implementsMarshalJSONType(typ) { return false } if typ.Kind() != reflect.Ptr { return true } // type kind is reflect.Ptr - if !typ.Elem().Implements(marshalJSONType) { + if !implementsMarshalJSONType(typ.Elem()) { return true } // needs to dereference @@ -384,7 +390,7 @@ func compile(ctx *compileContext, isPtr bool) (*Opcode, error) { elem := typ.Elem() if elem.Kind() == reflect.Uint8 { p := runtime.PtrTo(elem) - if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) { + if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { return compileBytes(ctx) } } @@ -527,9 +533,12 @@ func compilePtr(ctx *compileContext) (*Opcode, error) { func compileMarshalJSON(ctx *compileContext) (*Opcode, error) { code := newOpCode(ctx, OpMarshalJSON) typ := ctx.typ - if !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) { + if isPtrMarshalJSONType(typ) { code.Flags |= AddrForMarshalerFlags } + if typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType) { + code.Flags |= MarshalerContextFlags + } if isNilableType(typ) { code.Flags |= IsNilableTypeFlags } else { @@ -920,7 +929,7 @@ func compileSlice(ctx *compileContext) (*Opcode, error) { func compileListElem(ctx *compileContext) (*Opcode, error) { typ := ctx.typ switch { - case !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType): + case isPtrMarshalJSONType(typ): return compileMarshalJSON(ctx) case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): return compileMarshalText(ctx) @@ -1534,8 +1543,12 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { return ret, nil } +func implementsMarshalJSONType(typ *runtime.Type) bool { + return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType) +} + func isPtrMarshalJSONType(typ *runtime.Type) bool { - return !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) + return !implementsMarshalJSONType(typ) && implementsMarshalJSONType(runtime.PtrTo(typ)) } func isPtrMarshalTextType(typ *runtime.Type) bool { diff --git a/internal/encoder/context.go b/internal/encoder/context.go index ffc1d8b..61b8908 100644 --- a/internal/encoder/context.go +++ b/internal/encoder/context.go @@ -1,6 +1,7 @@ package encoder import ( + "context" "sync" "unsafe" @@ -104,6 +105,7 @@ var ( ) type RuntimeContext struct { + Context context.Context Buf []byte MarshalBuf []byte Ptrs []uintptr diff --git a/internal/encoder/encoder.go b/internal/encoder/encoder.go index 5368ce3..35b8159 100644 --- a/internal/encoder/encoder.go +++ b/internal/encoder/encoder.go @@ -365,13 +365,27 @@ func AppendMarshalJSON(ctx *RuntimeContext, code *Opcode, b []byte, v interface{ } } v = rv.Interface() - marshaler, ok := v.(json.Marshaler) - if !ok { - return AppendNull(ctx, b), nil - } - bb, err := marshaler.MarshalJSON() - if err != nil { - return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} + var bb []byte + if (code.Flags & MarshalerContextFlags) != 0 { + marshaler, ok := v.(marshalerContext) + if !ok { + return AppendNull(ctx, b), nil + } + b, err := marshaler.MarshalJSON(ctx.Option.Context) + if err != nil { + return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} + } + bb = b + } else { + marshaler, ok := v.(json.Marshaler) + if !ok { + return AppendNull(ctx, b), nil + } + b, err := marshaler.MarshalJSON() + if err != nil { + return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} + } + bb = b } marshalBuf := ctx.MarshalBuf[:0] marshalBuf = append(append(marshalBuf, bb...), nul) @@ -395,13 +409,27 @@ func AppendMarshalJSONIndent(ctx *RuntimeContext, code *Opcode, b []byte, v inte } } v = rv.Interface() - marshaler, ok := v.(json.Marshaler) - if !ok { - return AppendNull(ctx, b), nil - } - bb, err := marshaler.MarshalJSON() - if err != nil { - return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} + var bb []byte + if (code.Flags & MarshalerContextFlags) != 0 { + marshaler, ok := v.(marshalerContext) + if !ok { + return AppendNull(ctx, b), nil + } + b, err := marshaler.MarshalJSON(ctx.Option.Context) + if err != nil { + return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} + } + bb = b + } else { + marshaler, ok := v.(json.Marshaler) + if !ok { + return AppendNull(ctx, b), nil + } + b, err := marshaler.MarshalJSON() + if err != nil { + return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} + } + bb = b } marshalBuf := ctx.MarshalBuf[:0] marshalBuf = append(append(marshalBuf, bb...), nul) diff --git a/internal/encoder/opcode.go b/internal/encoder/opcode.go index 15a695d..3330207 100644 --- a/internal/encoder/opcode.go +++ b/internal/encoder/opcode.go @@ -10,7 +10,7 @@ import ( const uintptrSize = 4 << (^uintptr(0) >> 63) -type OpFlags uint8 +type OpFlags uint16 const ( AnonymousHeadFlags OpFlags = 1 << 0 @@ -21,6 +21,7 @@ const ( AddrForMarshalerFlags OpFlags = 1 << 5 IsNextOpPtrTypeFlags OpFlags = 1 << 6 IsNilableTypeFlags OpFlags = 1 << 7 + MarshalerContextFlags OpFlags = 1 << 8 ) type Opcode struct { @@ -32,9 +33,8 @@ type Opcode struct { Key string // struct field key Offset uint32 // offset size from struct header PtrNum uint8 // pointer number: e.g. double pointer is 2. - Flags OpFlags NumBitSize uint8 - _ [1]uint8 // 1 + Flags OpFlags Type *runtime.Type // go type PrevField *Opcode // prev struct field diff --git a/internal/encoder/option.go b/internal/encoder/option.go index 76a43ae..f5f1f04 100644 --- a/internal/encoder/option.go +++ b/internal/encoder/option.go @@ -1,5 +1,7 @@ package encoder +import "context" + type OptionFlag uint8 const ( @@ -8,11 +10,13 @@ const ( UnorderedMapOption DebugOption ColorizeOption + ContextOption ) type Option struct { Flag OptionFlag ColorScheme *ColorScheme + Context context.Context } type EncodeFormat struct { diff --git a/json.go b/json.go index 601e164..5c9448d 100644 --- a/json.go +++ b/json.go @@ -2,6 +2,7 @@ package json import ( "bytes" + "context" "encoding/json" "github.com/goccy/go-json/internal/encoder" @@ -13,6 +14,12 @@ type Marshaler interface { MarshalJSON() ([]byte, error) } +// MarshalerContext is the interface implemented by types that +// can marshal themselves into valid JSON with context.Context. +type MarshalerContext interface { + MarshalJSON(context.Context) ([]byte, error) +} + // Unmarshaler is the interface implemented by types // that can unmarshal a JSON description of themselves. // The input can be assumed to be a valid encoding of @@ -25,6 +32,12 @@ type Unmarshaler interface { UnmarshalJSON([]byte) error } +// UnmarshalerContext is the interface implemented by types +// that can unmarshal with context.Context a JSON description of themselves. +type UnmarshalerContext interface { + UnmarshalJSON(context.Context, []byte) error +} + // Marshal returns the JSON encoding of v. // // Marshal traverses the value v recursively. @@ -158,11 +171,16 @@ func Marshal(v interface{}) ([]byte, error) { return MarshalWithOption(v) } -// MarshalNoEscape +// MarshalNoEscape returns the JSON encoding of v and doesn't escape v. func MarshalNoEscape(v interface{}) ([]byte, error) { return marshalNoEscape(v) } +// MarshalContext returns the JSON encoding of v with context.Context and EncodeOption. +func MarshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) { + return marshalContext(ctx, v, optFuncs...) +} + // MarshalWithOption returns the JSON encoding of v with EncodeOption. func MarshalWithOption(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) { return marshal(v, optFuncs...) @@ -258,6 +276,13 @@ func Unmarshal(data []byte, v interface{}) error { return unmarshal(data, v) } +// UnmarshalContext parses the JSON-encoded data and stores the result +// in the value pointed to by v. If you implement the UnmarshalerContext interface, +// call it with ctx as an argument. +func UnmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { + return unmarshalContext(ctx, data, v) +} + func UnmarshalWithOption(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { return unmarshal(data, v, optFuncs...) } diff --git a/option.go b/option.go index c8ec1b3..ad65091 100644 --- a/option.go +++ b/option.go @@ -41,6 +41,6 @@ type DecodeOptionFunc func(*DecodeOption) // This behavior has a performance advantage as it allows the subsequent strings to be skipped if all fields have been evaluated. func DecodeFieldPriorityFirstWin() DecodeOptionFunc { return func(opt *DecodeOption) { - opt.Flag |= decoder.FirstWinOption + opt.Flags |= decoder.FirstWinOption } }