Support context for MarshalJSON and UnmarshalJSON

This commit is contained in:
Masaaki Goshima 2021-06-12 17:06:26 +09:00
parent a2ba5e8bcc
commit cd7fb7392f
17 changed files with 355 additions and 51 deletions

View File

@ -1,6 +1,7 @@
package json package json
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"reflect" "reflect"
@ -39,7 +40,7 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
} }
ctx := decoder.TakeRuntimeContext() ctx := decoder.TakeRuntimeContext()
ctx.Buf = src ctx.Buf = src
ctx.Option.Flag = 0 ctx.Option.Flags = 0
for _, optFunc := range optFuncs { for _, optFunc := range optFuncs {
optFunc(ctx.Option) optFunc(ctx.Option)
} }
@ -52,6 +53,36 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
return validateEndBuf(src, cursor) return validateEndBuf(src, cursor)
} }
func unmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
src := make([]byte, len(data)+1) // append nul byte to the end
copy(src, data)
header := (*emptyInterface)(unsafe.Pointer(&v))
if err := validateType(header.typ, uintptr(header.ptr)); err != nil {
return err
}
dec, err := decoder.CompileToGetDecoder(header.typ)
if err != nil {
return err
}
rctx := decoder.TakeRuntimeContext()
rctx.Buf = src
rctx.Option.Flags = 0
rctx.Option.Flags |= decoder.ContextOption
rctx.Option.Context = ctx
for _, optFunc := range optFuncs {
optFunc(rctx.Option)
}
cursor, err := dec.Decode(rctx, 0, 0, header.ptr)
if err != nil {
decoder.ReleaseRuntimeContext(rctx)
return err
}
decoder.ReleaseRuntimeContext(rctx)
return validateEndBuf(src, cursor)
}
func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
src := make([]byte, len(data)+1) // append nul byte to the end src := make([]byte, len(data)+1) // append nul byte to the end
copy(src, data) copy(src, data)
@ -68,7 +99,7 @@ func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc)
ctx := decoder.TakeRuntimeContext() ctx := decoder.TakeRuntimeContext()
ctx.Buf = src ctx.Buf = src
ctx.Option.Flag = 0 ctx.Option.Flags = 0
for _, optFunc := range optFuncs { for _, optFunc := range optFuncs {
optFunc(ctx.Option) optFunc(ctx.Option)
} }
@ -137,6 +168,14 @@ func (d *Decoder) Decode(v interface{}) error {
return d.DecodeWithOption(v) return d.DecodeWithOption(v)
} }
// DecodeContext reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v with context.Context.
func (d *Decoder) DecodeContext(ctx context.Context, v interface{}) error {
d.s.Option.Flags |= decoder.ContextOption
d.s.Option.Context = ctx
return d.DecodeWithOption(v)
}
func (d *Decoder) DecodeWithOption(v interface{}, optFuncs ...DecodeOptionFunc) error { func (d *Decoder) DecodeWithOption(v interface{}, optFuncs ...DecodeOptionFunc) error {
header := (*emptyInterface)(unsafe.Pointer(&v)) header := (*emptyInterface)(unsafe.Pointer(&v))
typ := header.typ typ := header.typ

View File

@ -2,6 +2,7 @@ package json_test
import ( import (
"bytes" "bytes"
"context"
"encoding" "encoding"
stdjson "encoding/json" stdjson "encoding/json"
"errors" "errors"
@ -3620,3 +3621,48 @@ func TestDecodeEscapedCharField(t *testing.T) {
} }
}) })
} }
type unmarshalContextKey struct{}
type unmarshalContextStructType struct {
v int
}
func (t *unmarshalContextStructType) UnmarshalJSON(ctx context.Context, b []byte) error {
v := ctx.Value(unmarshalContextKey{})
s, ok := v.(string)
if !ok {
return fmt.Errorf("failed to propagate parent context.Context")
}
if s != "hello" {
return fmt.Errorf("failed to propagate parent context.Context")
}
t.v = 100
return nil
}
func TestDecodeContextOption(t *testing.T) {
src := []byte("10")
buf := bytes.NewBuffer(src)
t.Run("UnmarshalContext", func(t *testing.T) {
ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello")
var v unmarshalContextStructType
if err := json.UnmarshalContext(ctx, src, &v); err != nil {
t.Fatal(err)
}
if v.v != 100 {
t.Fatal("failed to decode with context")
}
})
t.Run("DecodeContext", func(t *testing.T) {
ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello")
var v unmarshalContextStructType
if err := json.NewDecoder(buf).DecodeContext(ctx, &v); err != nil {
t.Fatal(err)
}
if v.v != 100 {
t.Fatal("failed to decode with context")
}
})
}

View File

@ -1,6 +1,7 @@
package json package json
import ( import (
"context"
"io" "io"
"unsafe" "unsafe"
@ -35,6 +36,7 @@ func (e *Encoder) Encode(v interface{}) error {
// EncodeWithOption call Encode with EncodeOption. // EncodeWithOption call Encode with EncodeOption.
func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc) error { func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc) error {
ctx := encoder.TakeRuntimeContext() ctx := encoder.TakeRuntimeContext()
ctx.Option.Flag = 0
err := e.encodeWithOption(ctx, v, optFuncs...) err := e.encodeWithOption(ctx, v, optFuncs...)
@ -42,8 +44,20 @@ func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc)
return err return err
} }
// EncodeContext call Encode with context.Context and EncodeOption.
func (e *Encoder) EncodeContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) error {
rctx := encoder.TakeRuntimeContext()
rctx.Option.Flag = 0
rctx.Option.Flag |= encoder.ContextOption
rctx.Option.Context = ctx
err := e.encodeWithOption(rctx, v, optFuncs...)
encoder.ReleaseRuntimeContext(rctx)
return err
}
func (e *Encoder) encodeWithOption(ctx *encoder.RuntimeContext, v interface{}, optFuncs ...EncodeOptionFunc) error { func (e *Encoder) encodeWithOption(ctx *encoder.RuntimeContext, v interface{}, optFuncs ...EncodeOptionFunc) error {
ctx.Option.Flag = 0
if e.enabledHTMLEscape { if e.enabledHTMLEscape {
ctx.Option.Flag |= encoder.HTMLEscapeOption ctx.Option.Flag |= encoder.HTMLEscapeOption
} }
@ -94,6 +108,33 @@ func (e *Encoder) SetIndent(prefix, indent string) {
e.enabledIndent = true e.enabledIndent = true
} }
func marshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
rctx := encoder.TakeRuntimeContext()
rctx.Option.Flag = 0
rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.ContextOption
rctx.Option.Context = ctx
for _, optFunc := range optFuncs {
optFunc(rctx.Option)
}
buf, err := encode(rctx, v)
if err != nil {
encoder.ReleaseRuntimeContext(rctx)
return nil, err
}
// this line exists to escape call of `runtime.makeslicecopy` .
// if use `make([]byte, len(buf)-1)` and `copy(copied, buf)`,
// dst buffer size and src buffer size are differrent.
// in this case, compiler uses `runtime.makeslicecopy`, but it is slow.
buf = buf[:len(buf)-1]
copied := make([]byte, len(buf))
copy(copied, buf)
encoder.ReleaseRuntimeContext(rctx)
return copied, nil
}
func marshal(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) { func marshal(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
ctx := encoder.TakeRuntimeContext() ctx := encoder.TakeRuntimeContext()

View File

@ -2,6 +2,7 @@ package json_test
import ( import (
"bytes" "bytes"
"context"
"encoding" "encoding"
stdjson "encoding/json" stdjson "encoding/json"
"errors" "errors"
@ -1918,3 +1919,42 @@ func TestEncodeMapKeyTypeInterface(t *testing.T) {
t.Fatal("expected error") t.Fatal("expected error")
} }
} }
type marshalContextKey struct{}
type marshalContextStructType struct{}
func (t *marshalContextStructType) MarshalJSON(ctx context.Context) ([]byte, error) {
v := ctx.Value(marshalContextKey{})
s, ok := v.(string)
if !ok {
return nil, fmt.Errorf("failed to propagate parent context.Context")
}
if s != "hello" {
return nil, fmt.Errorf("failed to propagate parent context.Context")
}
return []byte(`"success"`), nil
}
func TestEncodeContextOption(t *testing.T) {
t.Run("MarshalContext", func(t *testing.T) {
ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello")
b, err := json.MarshalContext(ctx, &marshalContextStructType{})
if err != nil {
t.Fatal(err)
}
if string(b) != `"success"` {
t.Fatal("failed to encode with MarshalerContext")
}
})
t.Run("EncodeContext", func(t *testing.T) {
ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello")
buf := bytes.NewBuffer([]byte{})
if err := json.NewEncoder(buf).EncodeContext(ctx, &marshalContextStructType{}); err != nil {
t.Fatal(err)
}
if buf.String() != "\"success\"\n" {
t.Fatal("failed to encode with EncodeContext")
}
})
}

View File

@ -60,7 +60,7 @@ func compileToGetDecoderSlowPath(typeptr uintptr, typ *runtime.Type) (Decoder, e
func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
switch { switch {
case runtime.PtrTo(typ).Implements(unmarshalJSONType): case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
return newUnmarshalJSONDecoder(runtime.PtrTo(typ), "", ""), nil return newUnmarshalJSONDecoder(runtime.PtrTo(typ), "", ""), nil
case runtime.PtrTo(typ).Implements(unmarshalTextType): case runtime.PtrTo(typ).Implements(unmarshalTextType):
return newUnmarshalTextDecoder(runtime.PtrTo(typ), "", ""), nil return newUnmarshalTextDecoder(runtime.PtrTo(typ), "", ""), nil
@ -70,7 +70,7 @@ func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (De
func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
switch { switch {
case runtime.PtrTo(typ).Implements(unmarshalJSONType): case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
return newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName), nil return newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName), nil
case runtime.PtrTo(typ).Implements(unmarshalTextType): case runtime.PtrTo(typ).Implements(unmarshalTextType):
return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
@ -133,7 +133,7 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode
func isStringTagSupportedType(typ *runtime.Type) bool { func isStringTagSupportedType(typ *runtime.Type) bool {
switch { switch {
case runtime.PtrTo(typ).Implements(unmarshalJSONType): case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
return false return false
case runtime.PtrTo(typ).Implements(unmarshalTextType): case runtime.PtrTo(typ).Implements(unmarshalTextType):
return false return false
@ -494,3 +494,7 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
structDec.tryOptimize() structDec.tryOptimize()
return structDec, nil return structDec, nil
} }
func implementsUnmarshalJSONType(typ *runtime.Type) bool {
return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType)
}

View File

@ -117,6 +117,21 @@ func decodeStreamUnmarshaler(s *Stream, depth int64, unmarshaler json.Unmarshale
return nil return nil
} }
func decodeStreamUnmarshalerContext(s *Stream, depth int64, unmarshaler unmarshalerContext) error {
start := s.cursor
if err := s.skipValue(depth); err != nil {
return err
}
src := s.buf[start:s.cursor]
dst := make([]byte, len(src))
copy(dst, src)
if err := unmarshaler.UnmarshalJSON(s.Option.Context, dst); err != nil {
return err
}
return nil
}
func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarshaler) (int64, error) { func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarshaler) (int64, error) {
cursor = skipWhiteSpace(buf, cursor) cursor = skipWhiteSpace(buf, cursor)
start := cursor start := cursor
@ -134,6 +149,23 @@ func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarsh
return end, nil return end, nil
} }
func decodeUnmarshalerContext(ctx *RuntimeContext, buf []byte, cursor, depth int64, unmarshaler unmarshalerContext) (int64, error) {
cursor = skipWhiteSpace(buf, cursor)
start := cursor
end, err := skipValue(buf, cursor, depth)
if err != nil {
return 0, err
}
src := buf[start:end]
dst := make([]byte, len(src))
copy(dst, src)
if err := unmarshaler.UnmarshalJSON(ctx.Option.Context, dst); err != nil {
return 0, err
}
return end, nil
}
func decodeStreamTextUnmarshaler(s *Stream, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error { func decodeStreamTextUnmarshaler(s *Stream, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error {
start := s.cursor start := s.cursor
if err := s.skipValue(depth); err != nil { if err := s.skipValue(depth); err != nil {
@ -260,6 +292,9 @@ func (d *interfaceDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer
})) }))
rv := reflect.ValueOf(runtimeInterfaceValue) rv := reflect.ValueOf(runtimeInterfaceValue)
if rv.NumMethod() > 0 && rv.CanInterface() { if rv.NumMethod() > 0 && rv.CanInterface() {
if u, ok := rv.Interface().(unmarshalerContext); ok {
return decodeStreamUnmarshalerContext(s, depth, u)
}
if u, ok := rv.Interface().(json.Unmarshaler); ok { if u, ok := rv.Interface().(json.Unmarshaler); ok {
return decodeStreamUnmarshaler(s, depth, u) return decodeStreamUnmarshaler(s, depth, u)
} }
@ -317,6 +352,9 @@ func (d *interfaceDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p un
})) }))
rv := reflect.ValueOf(runtimeInterfaceValue) rv := reflect.ValueOf(runtimeInterfaceValue)
if rv.NumMethod() > 0 && rv.CanInterface() { if rv.NumMethod() > 0 && rv.CanInterface() {
if u, ok := rv.Interface().(unmarshalerContext); ok {
return decodeUnmarshalerContext(ctx, buf, cursor, depth, u)
}
if u, ok := rv.Interface().(json.Unmarshaler); ok { if u, ok := rv.Interface().(json.Unmarshaler); ok {
return decodeUnmarshaler(buf, cursor, depth, u) return decodeUnmarshaler(buf, cursor, depth, u)
} }

View File

@ -1,11 +1,15 @@
package decoder package decoder
type OptionFlag int import "context"
type OptionFlags uint8
const ( const (
FirstWinOption OptionFlag = 1 << iota FirstWinOption OptionFlags = 1 << iota
ContextOption
) )
type Option struct { type Option struct {
Flag OptionFlag Flags OptionFlags
Context context.Context
} }

View File

@ -665,7 +665,7 @@ func (d *structDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) e
seenFields map[int]struct{} seenFields map[int]struct{}
seenFieldNum int seenFieldNum int
) )
firstWin := (s.Option.Flag & FirstWinOption) != 0 firstWin := (s.Option.Flags & FirstWinOption) != 0
if firstWin { if firstWin {
seenFields = make(map[int]struct{}, d.fieldUniqueNameNum) seenFields = make(map[int]struct{}, d.fieldUniqueNameNum)
} }
@ -752,7 +752,7 @@ func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
seenFields map[int]struct{} seenFields map[int]struct{}
seenFieldNum int seenFieldNum int
) )
firstWin := (ctx.Option.Flag & FirstWinOption) != 0 firstWin := (ctx.Option.Flags & FirstWinOption) != 0
if firstWin { if firstWin {
seenFields = make(map[int]struct{}, d.fieldUniqueNameNum) seenFields = make(map[int]struct{}, d.fieldUniqueNameNum)
} }

View File

@ -1,6 +1,7 @@
package decoder package decoder
import ( import (
"context"
"encoding" "encoding"
"encoding/json" "encoding/json"
"reflect" "reflect"
@ -17,7 +18,12 @@ const (
maxDecodeNestingDepth = 10000 maxDecodeNestingDepth = 10000
) )
type unmarshalerContext interface {
UnmarshalJSON(context.Context, []byte) error
}
var ( var (
unmarshalJSONType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() unmarshalJSONType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
unmarshalTextType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() unmarshalJSONContextType = reflect.TypeOf((*unmarshalerContext)(nil)).Elem()
unmarshalTextType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
) )

View File

@ -46,9 +46,16 @@ func (d *unmarshalJSONDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Poi
typ: d.typ, typ: d.typ,
ptr: p, ptr: p,
})) }))
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil { if (s.Option.Flags & ContextOption) != 0 {
d.annotateError(s.cursor, err) if err := v.(unmarshalerContext).UnmarshalJSON(s.Option.Context, dst); err != nil {
return err d.annotateError(s.cursor, err)
return err
}
} else {
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil {
d.annotateError(s.cursor, err)
return err
}
} }
return nil return nil
} }
@ -69,9 +76,16 @@ func (d *unmarshalJSONDecoder) Decode(ctx *RuntimeContext, cursor, depth int64,
typ: d.typ, typ: d.typ,
ptr: p, ptr: p,
})) }))
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil { if (ctx.Option.Flags & ContextOption) != 0 {
d.annotateError(cursor, err) if err := v.(unmarshalerContext).UnmarshalJSON(ctx.Option.Context, dst); err != nil {
return 0, err d.annotateError(cursor, err)
return 0, err
}
} else {
if err := v.(json.Unmarshaler).UnmarshalJSON(dst); err != nil {
d.annotateError(cursor, err)
return 0, err
}
} }
return end, nil return end, nil
} }

View File

@ -1,6 +1,7 @@
package encoder package encoder
import ( import (
"context"
"encoding" "encoding"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -13,13 +14,18 @@ import (
"github.com/goccy/go-json/internal/runtime" "github.com/goccy/go-json/internal/runtime"
) )
type marshalerContext interface {
MarshalJSON(context.Context) ([]byte, error)
}
var ( var (
marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem()
jsonNumberType = reflect.TypeOf(json.Number("")) marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
cachedOpcodeSets []*OpcodeSet jsonNumberType = reflect.TypeOf(json.Number(""))
cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet cachedOpcodeSets []*OpcodeSet
typeAddr *runtime.TypeAddr cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
typeAddr *runtime.TypeAddr
) )
func init() { func init() {
@ -110,7 +116,7 @@ func compileHead(ctx *compileContext) (*Opcode, error) {
elem := typ.Elem() elem := typ.Elem()
if elem.Kind() == reflect.Uint8 { if elem.Kind() == reflect.Uint8 {
p := runtime.PtrTo(elem) p := runtime.PtrTo(elem)
if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) { if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
if isPtr { if isPtr {
return compileBytesPtr(ctx) return compileBytesPtr(ctx)
} }
@ -340,14 +346,14 @@ func optimizeStructEnd(c *Opcode) {
} }
func implementsMarshalJSON(typ *runtime.Type) bool { func implementsMarshalJSON(typ *runtime.Type) bool {
if !typ.Implements(marshalJSONType) { if !implementsMarshalJSONType(typ) {
return false return false
} }
if typ.Kind() != reflect.Ptr { if typ.Kind() != reflect.Ptr {
return true return true
} }
// type kind is reflect.Ptr // type kind is reflect.Ptr
if !typ.Elem().Implements(marshalJSONType) { if !implementsMarshalJSONType(typ.Elem()) {
return true return true
} }
// needs to dereference // needs to dereference
@ -384,7 +390,7 @@ func compile(ctx *compileContext, isPtr bool) (*Opcode, error) {
elem := typ.Elem() elem := typ.Elem()
if elem.Kind() == reflect.Uint8 { if elem.Kind() == reflect.Uint8 {
p := runtime.PtrTo(elem) p := runtime.PtrTo(elem)
if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) { if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
return compileBytes(ctx) return compileBytes(ctx)
} }
} }
@ -527,9 +533,12 @@ func compilePtr(ctx *compileContext) (*Opcode, error) {
func compileMarshalJSON(ctx *compileContext) (*Opcode, error) { func compileMarshalJSON(ctx *compileContext) (*Opcode, error) {
code := newOpCode(ctx, OpMarshalJSON) code := newOpCode(ctx, OpMarshalJSON)
typ := ctx.typ typ := ctx.typ
if !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) { if isPtrMarshalJSONType(typ) {
code.Flags |= AddrForMarshalerFlags code.Flags |= AddrForMarshalerFlags
} }
if typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType) {
code.Flags |= MarshalerContextFlags
}
if isNilableType(typ) { if isNilableType(typ) {
code.Flags |= IsNilableTypeFlags code.Flags |= IsNilableTypeFlags
} else { } else {
@ -920,7 +929,7 @@ func compileSlice(ctx *compileContext) (*Opcode, error) {
func compileListElem(ctx *compileContext) (*Opcode, error) { func compileListElem(ctx *compileContext) (*Opcode, error) {
typ := ctx.typ typ := ctx.typ
switch { switch {
case !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType): case isPtrMarshalJSONType(typ):
return compileMarshalJSON(ctx) return compileMarshalJSON(ctx)
case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
return compileMarshalText(ctx) return compileMarshalText(ctx)
@ -1534,8 +1543,12 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
return ret, nil return ret, nil
} }
func implementsMarshalJSONType(typ *runtime.Type) bool {
return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType)
}
func isPtrMarshalJSONType(typ *runtime.Type) bool { func isPtrMarshalJSONType(typ *runtime.Type) bool {
return !typ.Implements(marshalJSONType) && runtime.PtrTo(typ).Implements(marshalJSONType) return !implementsMarshalJSONType(typ) && implementsMarshalJSONType(runtime.PtrTo(typ))
} }
func isPtrMarshalTextType(typ *runtime.Type) bool { func isPtrMarshalTextType(typ *runtime.Type) bool {

View File

@ -1,6 +1,7 @@
package encoder package encoder
import ( import (
"context"
"sync" "sync"
"unsafe" "unsafe"
@ -104,6 +105,7 @@ var (
) )
type RuntimeContext struct { type RuntimeContext struct {
Context context.Context
Buf []byte Buf []byte
MarshalBuf []byte MarshalBuf []byte
Ptrs []uintptr Ptrs []uintptr

View File

@ -365,13 +365,27 @@ func AppendMarshalJSON(ctx *RuntimeContext, code *Opcode, b []byte, v interface{
} }
} }
v = rv.Interface() v = rv.Interface()
marshaler, ok := v.(json.Marshaler) var bb []byte
if !ok { if (code.Flags & MarshalerContextFlags) != 0 {
return AppendNull(ctx, b), nil marshaler, ok := v.(marshalerContext)
} if !ok {
bb, err := marshaler.MarshalJSON() return AppendNull(ctx, b), nil
if err != nil { }
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} b, err := marshaler.MarshalJSON(ctx.Option.Context)
if err != nil {
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
bb = b
} else {
marshaler, ok := v.(json.Marshaler)
if !ok {
return AppendNull(ctx, b), nil
}
b, err := marshaler.MarshalJSON()
if err != nil {
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
bb = b
} }
marshalBuf := ctx.MarshalBuf[:0] marshalBuf := ctx.MarshalBuf[:0]
marshalBuf = append(append(marshalBuf, bb...), nul) marshalBuf = append(append(marshalBuf, bb...), nul)
@ -395,13 +409,27 @@ func AppendMarshalJSONIndent(ctx *RuntimeContext, code *Opcode, b []byte, v inte
} }
} }
v = rv.Interface() v = rv.Interface()
marshaler, ok := v.(json.Marshaler) var bb []byte
if !ok { if (code.Flags & MarshalerContextFlags) != 0 {
return AppendNull(ctx, b), nil marshaler, ok := v.(marshalerContext)
} if !ok {
bb, err := marshaler.MarshalJSON() return AppendNull(ctx, b), nil
if err != nil { }
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} b, err := marshaler.MarshalJSON(ctx.Option.Context)
if err != nil {
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
bb = b
} else {
marshaler, ok := v.(json.Marshaler)
if !ok {
return AppendNull(ctx, b), nil
}
b, err := marshaler.MarshalJSON()
if err != nil {
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
bb = b
} }
marshalBuf := ctx.MarshalBuf[:0] marshalBuf := ctx.MarshalBuf[:0]
marshalBuf = append(append(marshalBuf, bb...), nul) marshalBuf = append(append(marshalBuf, bb...), nul)

View File

@ -10,7 +10,7 @@ import (
const uintptrSize = 4 << (^uintptr(0) >> 63) const uintptrSize = 4 << (^uintptr(0) >> 63)
type OpFlags uint8 type OpFlags uint16
const ( const (
AnonymousHeadFlags OpFlags = 1 << 0 AnonymousHeadFlags OpFlags = 1 << 0
@ -21,6 +21,7 @@ const (
AddrForMarshalerFlags OpFlags = 1 << 5 AddrForMarshalerFlags OpFlags = 1 << 5
IsNextOpPtrTypeFlags OpFlags = 1 << 6 IsNextOpPtrTypeFlags OpFlags = 1 << 6
IsNilableTypeFlags OpFlags = 1 << 7 IsNilableTypeFlags OpFlags = 1 << 7
MarshalerContextFlags OpFlags = 1 << 8
) )
type Opcode struct { type Opcode struct {
@ -32,9 +33,8 @@ type Opcode struct {
Key string // struct field key Key string // struct field key
Offset uint32 // offset size from struct header Offset uint32 // offset size from struct header
PtrNum uint8 // pointer number: e.g. double pointer is 2. PtrNum uint8 // pointer number: e.g. double pointer is 2.
Flags OpFlags
NumBitSize uint8 NumBitSize uint8
_ [1]uint8 // 1 Flags OpFlags
Type *runtime.Type // go type Type *runtime.Type // go type
PrevField *Opcode // prev struct field PrevField *Opcode // prev struct field

View File

@ -1,5 +1,7 @@
package encoder package encoder
import "context"
type OptionFlag uint8 type OptionFlag uint8
const ( const (
@ -8,11 +10,13 @@ const (
UnorderedMapOption UnorderedMapOption
DebugOption DebugOption
ColorizeOption ColorizeOption
ContextOption
) )
type Option struct { type Option struct {
Flag OptionFlag Flag OptionFlag
ColorScheme *ColorScheme ColorScheme *ColorScheme
Context context.Context
} }
type EncodeFormat struct { type EncodeFormat struct {

27
json.go
View File

@ -2,6 +2,7 @@ package json
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"github.com/goccy/go-json/internal/encoder" "github.com/goccy/go-json/internal/encoder"
@ -13,6 +14,12 @@ type Marshaler interface {
MarshalJSON() ([]byte, error) MarshalJSON() ([]byte, error)
} }
// MarshalerContext is the interface implemented by types that
// can marshal themselves into valid JSON with context.Context.
type MarshalerContext interface {
MarshalJSON(context.Context) ([]byte, error)
}
// Unmarshaler is the interface implemented by types // Unmarshaler is the interface implemented by types
// that can unmarshal a JSON description of themselves. // that can unmarshal a JSON description of themselves.
// The input can be assumed to be a valid encoding of // The input can be assumed to be a valid encoding of
@ -25,6 +32,12 @@ type Unmarshaler interface {
UnmarshalJSON([]byte) error UnmarshalJSON([]byte) error
} }
// UnmarshalerContext is the interface implemented by types
// that can unmarshal with context.Context a JSON description of themselves.
type UnmarshalerContext interface {
UnmarshalJSON(context.Context, []byte) error
}
// Marshal returns the JSON encoding of v. // Marshal returns the JSON encoding of v.
// //
// Marshal traverses the value v recursively. // Marshal traverses the value v recursively.
@ -158,11 +171,16 @@ func Marshal(v interface{}) ([]byte, error) {
return MarshalWithOption(v) return MarshalWithOption(v)
} }
// MarshalNoEscape // MarshalNoEscape returns the JSON encoding of v and doesn't escape v.
func MarshalNoEscape(v interface{}) ([]byte, error) { func MarshalNoEscape(v interface{}) ([]byte, error) {
return marshalNoEscape(v) return marshalNoEscape(v)
} }
// MarshalContext returns the JSON encoding of v with context.Context and EncodeOption.
func MarshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
return marshalContext(ctx, v, optFuncs...)
}
// MarshalWithOption returns the JSON encoding of v with EncodeOption. // MarshalWithOption returns the JSON encoding of v with EncodeOption.
func MarshalWithOption(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) { func MarshalWithOption(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
return marshal(v, optFuncs...) return marshal(v, optFuncs...)
@ -258,6 +276,13 @@ func Unmarshal(data []byte, v interface{}) error {
return unmarshal(data, v) return unmarshal(data, v)
} }
// UnmarshalContext parses the JSON-encoded data and stores the result
// in the value pointed to by v. If you implement the UnmarshalerContext interface,
// call it with ctx as an argument.
func UnmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
return unmarshalContext(ctx, data, v)
}
func UnmarshalWithOption(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { func UnmarshalWithOption(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
return unmarshal(data, v, optFuncs...) return unmarshal(data, v, optFuncs...)
} }

View File

@ -41,6 +41,6 @@ type DecodeOptionFunc func(*DecodeOption)
// This behavior has a performance advantage as it allows the subsequent strings to be skipped if all fields have been evaluated. // This behavior has a performance advantage as it allows the subsequent strings to be skipped if all fields have been evaluated.
func DecodeFieldPriorityFirstWin() DecodeOptionFunc { func DecodeFieldPriorityFirstWin() DecodeOptionFunc {
return func(opt *DecodeOption) { return func(opt *DecodeOption) {
opt.Flag |= decoder.FirstWinOption opt.Flags |= decoder.FirstWinOption
} }
} }