From 89bcc3be8660e58363ab30d68658a5ac1488f100 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Mon, 3 Jan 2022 12:33:51 +0900 Subject: [PATCH] Supports dynamic filtering of struct fields --- benchmarks/encode_test.go | 78 ++++++++++++++ encode.go | 6 +- internal/cmd/generator/vm.go.tmpl | 2 +- internal/encoder/code.go | 133 +++++++++++++++++++++++- internal/encoder/compiler.go | 27 +++++ internal/encoder/compiler_norace.go | 14 ++- internal/encoder/compiler_race.go | 15 ++- internal/encoder/encode_opcode_test.go | 4 +- internal/encoder/encoder.go | 22 +++- internal/encoder/opcode.go | 2 + internal/encoder/option.go | 1 + internal/encoder/query.go | 135 +++++++++++++++++++++++++ internal/encoder/vm/vm.go | 2 +- internal/encoder/vm_color/vm.go | 2 +- internal/encoder/vm_color_indent/vm.go | 2 +- internal/encoder/vm_indent/vm.go | 2 +- json.go | 5 + query.go | 36 +++++++ query_test.go | 121 ++++++++++++++++++++++ size_test.go | 2 +- test/example/example_query_test.go | 96 ++++++++++++++++++ 21 files changed, 688 insertions(+), 19 deletions(-) create mode 100644 internal/encoder/query.go create mode 100644 query.go create mode 100644 query_test.go create mode 100644 test/example/example_query_test.go diff --git a/benchmarks/encode_test.go b/benchmarks/encode_test.go index 4da8034..1e8938c 100644 --- a/benchmarks/encode_test.go +++ b/benchmarks/encode_test.go @@ -2,6 +2,7 @@ package benchmark import ( "bytes" + "context" "encoding/json" "testing" @@ -835,3 +836,80 @@ func Benchmark_Encode_MarshalJSON_GoJson(b *testing.B) { } } } + +type queryTestX struct { + XA int + XB string + XC *queryTestY + XD bool + XE float32 +} + +type queryTestY struct { + YA int + YB string + YC bool + YD float32 +} + +func Benchmark_Encode_FilterByMap(b *testing.B) { + v := &queryTestX{ + XA: 1, + XB: "xb", + XC: &queryTestY{ + YA: 2, + YB: "yb", + YC: true, + YD: 4, + }, + XD: true, + XE: 5, + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + filteredMap := map[string]interface{}{ + "XA": v.XA, + "XB": v.XB, + "XC": map[string]interface{}{ + "YA": v.XC.YA, + "YB": v.XC.YB, + }, + } + if _, err := gojson.Marshal(filteredMap); err != nil { + b.Fatal(err) + } + } +} + +func Benchmark_Encode_FilterByFieldQuery(b *testing.B) { + query, err := gojson.BuildFieldQuery( + "XA", + "XB", + gojson.BuildSubFieldQuery("XC").Fields( + "YA", + "YB", + ), + ) + if err != nil { + b.Fatal(err) + } + v := &queryTestX{ + XA: 1, + XB: "xb", + XC: &queryTestY{ + YA: 2, + YB: "yb", + YC: true, + YD: 4, + }, + XD: true, + XE: 5, + } + ctx := gojson.SetFieldQueryToContext(context.Background(), query) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if _, err := gojson.MarshalContext(ctx, v); err != nil { + b.Fatal(err) + } + } +} diff --git a/encode.go b/encode.go index b55ec04..c9527c0 100644 --- a/encode.go +++ b/encode.go @@ -221,7 +221,7 @@ func encode(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error) { typ := header.typ typeptr := uintptr(unsafe.Pointer(typ)) - codeSet, err := encoder.CompileToGetCodeSet(typeptr) + codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr) if err != nil { return nil, err } @@ -249,7 +249,7 @@ func encodeNoEscape(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error) typ := header.typ typeptr := uintptr(unsafe.Pointer(typ)) - codeSet, err := encoder.CompileToGetCodeSet(typeptr) + codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr) if err != nil { return nil, err } @@ -276,7 +276,7 @@ func encodeIndent(ctx *encoder.RuntimeContext, v interface{}, prefix, indent str typ := header.typ typeptr := uintptr(unsafe.Pointer(typ)) - codeSet, err := encoder.CompileToGetCodeSet(typeptr) + codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr) if err != nil { return nil, err } diff --git a/internal/cmd/generator/vm.go.tmpl b/internal/cmd/generator/vm.go.tmpl index 4be6b80..f45a593 100644 --- a/internal/cmd/generator/vm.go.tmpl +++ b/internal/cmd/generator/vm.go.tmpl @@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b break } ctx.KeepRefs = append(ctx.KeepRefs, up) - ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) + ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ))) if err != nil { return nil, err } diff --git a/internal/encoder/code.go b/internal/encoder/code.go index aee0101..4715fad 100644 --- a/internal/encoder/code.go +++ b/internal/encoder/code.go @@ -10,6 +10,7 @@ import ( type Code interface { Kind() CodeKind ToOpcode(*compileContext) Opcodes + Filter(*FieldQuery) Code } type AnonymousCode interface { @@ -82,6 +83,10 @@ func (c *IntCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *IntCode) Filter(_ *FieldQuery) Code { + return c +} + type UintCode struct { typ *runtime.Type bitSize uint8 @@ -108,6 +113,10 @@ func (c *UintCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *UintCode) Filter(_ *FieldQuery) Code { + return c +} + type FloatCode struct { typ *runtime.Type bitSize uint8 @@ -140,6 +149,10 @@ func (c *FloatCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *FloatCode) Filter(_ *FieldQuery) Code { + return c +} + type StringCode struct { typ *runtime.Type isPtr bool @@ -169,6 +182,10 @@ func (c *StringCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *StringCode) Filter(_ *FieldQuery) Code { + return c +} + type BoolCode struct { typ *runtime.Type isPtr bool @@ -190,6 +207,10 @@ func (c *BoolCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *BoolCode) Filter(_ *FieldQuery) Code { + return c +} + type BytesCode struct { typ *runtime.Type isPtr bool @@ -211,6 +232,10 @@ func (c *BytesCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *BytesCode) Filter(_ *FieldQuery) Code { + return c +} + type SliceCode struct { typ *runtime.Type value Code @@ -245,6 +270,10 @@ func (c *SliceCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{header}.Add(codes...).Add(elemCode).Add(end) } +func (c *SliceCode) Filter(_ *FieldQuery) Code { + return c +} + type ArrayCode struct { typ *runtime.Type value Code @@ -286,6 +315,10 @@ func (c *ArrayCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{header}.Add(codes...).Add(elemCode).Add(end) } +func (c *ArrayCode) Filter(_ *FieldQuery) Code { + return c +} + type MapCode struct { typ *runtime.Type key Code @@ -332,6 +365,10 @@ func (c *MapCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{header}.Add(keyCodes...).Add(value).Add(valueCodes...).Add(key).Add(end) } +func (c *MapCode) Filter(_ *FieldQuery) Code { + return c +} + type StructCode struct { typ *runtime.Type fields []*StructFieldCode @@ -520,6 +557,45 @@ func (c *StructCode) enableIndirect() { structCode.enableIndirect() } +func (c *StructCode) Filter(query *FieldQuery) Code { + fieldMap := map[string]*FieldQuery{} + for _, field := range query.Fields { + fieldMap[field.Name] = field + } + fields := make([]*StructFieldCode, 0, len(c.fields)) + for _, field := range c.fields { + query, exists := fieldMap[field.key] + if !exists { + continue + } + fieldCode := &StructFieldCode{ + typ: field.typ, + key: field.key, + tag: field.tag, + value: field.value, + offset: field.offset, + isAnonymous: field.isAnonymous, + isTaggedKey: field.isTaggedKey, + isNilableType: field.isNilableType, + isNilCheck: field.isNilCheck, + isAddrForMarshaler: field.isAddrForMarshaler, + isNextOpPtrType: field.isNextOpPtrType, + } + if len(query.Fields) > 0 { + fieldCode.value = fieldCode.value.Filter(query) + } + fields = append(fields, fieldCode) + } + return &StructCode{ + typ: c.typ, + fields: fields, + isPtr: c.isPtr, + disableIndirectConversion: c.disableIndirectConversion, + isIndirect: c.isIndirect, + isRecursive: c.isRecursive, + } +} + type StructFieldCode struct { typ *runtime.Type key string @@ -532,6 +608,7 @@ type StructFieldCode struct { isNilCheck bool isAddrForMarshaler bool isNextOpPtrType bool + isMarshalerContext bool } func (c *StructFieldCode) getStruct() *StructCode { @@ -574,8 +651,12 @@ func (c *StructFieldCode) headerOpcodes(ctx *compileContext, field *Opcode, valu value := valueCodes.First() op := optimizeStructHeader(value, c.tag) field.Op = op + if value.Flags&MarshalerContextFlags != 0 { + field.Flags |= MarshalerContextFlags + } field.NumBitSize = value.NumBitSize field.PtrNum = value.PtrNum + field.FieldQuery = value.FieldQuery fieldCodes := Opcodes{field} if op.IsMultipleOpHead() { field.Next = value @@ -590,8 +671,12 @@ func (c *StructFieldCode) fieldOpcodes(ctx *compileContext, field *Opcode, value value := valueCodes.First() op := optimizeStructField(value, c.tag) field.Op = op + if value.Flags&MarshalerContextFlags != 0 { + field.Flags |= MarshalerContextFlags + } field.NumBitSize = value.NumBitSize field.PtrNum = value.PtrNum + field.FieldQuery = value.FieldQuery fieldCodes := Opcodes{field} if op.IsMultipleOpField() { @@ -645,6 +730,9 @@ func (c *StructFieldCode) flags() OpFlags { if c.isAnonymous { flags |= AnonymousKeyFlags } + if c.isMarshalerContext { + flags |= MarshalerContextFlags + } return flags } @@ -725,8 +813,9 @@ func isEnableStructEndOptimization(value Code) bool { } type InterfaceCode struct { - typ *runtime.Type - isPtr bool + typ *runtime.Type + fieldQuery *FieldQuery + isPtr bool } func (c *InterfaceCode) Kind() CodeKind { @@ -741,6 +830,7 @@ func (c *InterfaceCode) ToOpcode(ctx *compileContext) Opcodes { default: code = newOpCode(ctx, c.typ, OpInterface) } + code.FieldQuery = c.fieldQuery if c.typ.NumMethod() > 0 { code.Flags |= NonEmptyInterfaceFlags } @@ -748,8 +838,17 @@ func (c *InterfaceCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *InterfaceCode) Filter(query *FieldQuery) Code { + return &InterfaceCode{ + typ: c.typ, + fieldQuery: query, + isPtr: c.isPtr, + } +} + type MarshalJSONCode struct { typ *runtime.Type + fieldQuery *FieldQuery isAddrForMarshaler bool isNilableType bool isMarshalerContext bool @@ -761,6 +860,7 @@ func (c *MarshalJSONCode) Kind() CodeKind { func (c *MarshalJSONCode) ToOpcode(ctx *compileContext) Opcodes { code := newOpCode(ctx, c.typ, OpMarshalJSON) + code.FieldQuery = c.fieldQuery if c.isAddrForMarshaler { code.Flags |= AddrForMarshalerFlags } @@ -776,8 +876,19 @@ func (c *MarshalJSONCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *MarshalJSONCode) Filter(query *FieldQuery) Code { + return &MarshalJSONCode{ + typ: c.typ, + fieldQuery: query, + isAddrForMarshaler: c.isAddrForMarshaler, + isNilableType: c.isNilableType, + isMarshalerContext: c.isMarshalerContext, + } +} + type MarshalTextCode struct { typ *runtime.Type + fieldQuery *FieldQuery isAddrForMarshaler bool isNilableType bool } @@ -788,6 +899,7 @@ func (c *MarshalTextCode) Kind() CodeKind { func (c *MarshalTextCode) ToOpcode(ctx *compileContext) Opcodes { code := newOpCode(ctx, c.typ, OpMarshalText) + code.FieldQuery = c.fieldQuery if c.isAddrForMarshaler { code.Flags |= AddrForMarshalerFlags } @@ -800,6 +912,15 @@ func (c *MarshalTextCode) ToOpcode(ctx *compileContext) Opcodes { return Opcodes{code} } +func (c *MarshalTextCode) Filter(query *FieldQuery) Code { + return &MarshalTextCode{ + typ: c.typ, + fieldQuery: query, + isAddrForMarshaler: c.isAddrForMarshaler, + isNilableType: c.isNilableType, + } +} + type PtrCode struct { typ *runtime.Type value Code @@ -830,6 +951,14 @@ func (c *PtrCode) ToAnonymousOpcode(ctx *compileContext) Opcodes { return codes } +func (c *PtrCode) Filter(query *FieldQuery) Code { + return &PtrCode{ + typ: c.typ, + value: c.value.Filter(query), + ptrNum: c.ptrNum, + } +} + func convertPtrOp(code *Opcode) OpType { ptrHeadOp := code.Op.HeadToPtrHead() if code.Op != ptrHeadOp { diff --git a/internal/encoder/compiler.go b/internal/encoder/compiler.go index cbcdb65..0eb3d90 100644 --- a/internal/encoder/compiler.go +++ b/internal/encoder/compiler.go @@ -63,6 +63,27 @@ func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) { return codeSet, nil } +func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) { + if (ctx.Option.Flag & ContextOption) == 0 { + return codeSet, nil + } + query := FieldQueryFromContext(ctx.Option.Context) + if query == nil { + return codeSet, nil + } + ctx.Option.Flag |= FieldQueryOption + cacheCodeSet := codeSet.getQueryCache(query.Hash()) + if cacheCodeSet != nil { + return cacheCodeSet, nil + } + queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query)) + if err != nil { + return nil, err + } + codeSet.setQueryCache(query.Hash(), queryCodeSet) + return queryCodeSet, nil +} + type Compiler struct { structTypeToCode map[uintptr]*StructCode } @@ -80,6 +101,10 @@ func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) { if err != nil { return nil, err } + return c.codeToOpcodeSet(typ, code) +} + +func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) { noescapeKeyCode := c.codeToOpcode(&compileContext{ structTypeToCodes: map[uintptr]Opcodes{}, recursiveCodes: &Opcodes{}, @@ -107,6 +132,8 @@ func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) { InterfaceEscapeKeyCode: interfaceEscapeKeyCode, CodeLength: codeLength, EndCode: ToEndCode(interfaceNoescapeKeyCode), + Code: code, + QueryCache: map[string]*OpcodeSet{}, }, nil } diff --git a/internal/encoder/compiler_norace.go b/internal/encoder/compiler_norace.go index afc5b66..2576419 100644 --- a/internal/encoder/compiler_norace.go +++ b/internal/encoder/compiler_norace.go @@ -3,18 +3,26 @@ package encoder -func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) { +func CompileToGetCodeSet(ctx *RuntimeContext, typeptr uintptr) (*OpcodeSet, error) { if typeptr > typeAddr.MaxTypeAddr { return compileToGetCodeSetSlowPath(typeptr) } index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift if codeSet := cachedOpcodeSets[index]; codeSet != nil { - return codeSet, nil + filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet) + if err != nil { + return nil, err + } + return filtered, nil } codeSet, err := newCompiler().compile(typeptr) if err != nil { return nil, err } + filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet) + if err != nil { + return nil, err + } cachedOpcodeSets[index] = codeSet - return codeSet, nil + return filtered, nil } diff --git a/internal/encoder/compiler_race.go b/internal/encoder/compiler_race.go index 846a898..c744511 100644 --- a/internal/encoder/compiler_race.go +++ b/internal/encoder/compiler_race.go @@ -9,15 +9,20 @@ import ( var setsMu sync.RWMutex -func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) { +func CompileToGetCodeSet(ctx *RuntimeContext, typeptr uintptr) (*OpcodeSet, error) { if typeptr > typeAddr.MaxTypeAddr { return compileToGetCodeSetSlowPath(typeptr) } index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift setsMu.RLock() if codeSet := cachedOpcodeSets[index]; codeSet != nil { + filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet) + if err != nil { + setsMu.RUnlock() + return nil, err + } setsMu.RUnlock() - return codeSet, nil + return filtered, nil } setsMu.RUnlock() @@ -25,8 +30,12 @@ func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) { if err != nil { return nil, err } + filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet) + if err != nil { + return nil, err + } setsMu.Lock() cachedOpcodeSets[index] = codeSet setsMu.Unlock() - return codeSet, nil + return filtered, nil } diff --git a/internal/encoder/encode_opcode_test.go b/internal/encoder/encode_opcode_test.go index b4260ba..e5893d3 100644 --- a/internal/encoder/encode_opcode_test.go +++ b/internal/encoder/encode_opcode_test.go @@ -6,11 +6,13 @@ import ( ) func TestDumpOpcode(t *testing.T) { + ctx := TakeRuntimeContext() + defer ReleaseRuntimeContext(ctx) var v interface{} = 1 header := (*emptyInterface)(unsafe.Pointer(&v)) typ := header.typ typeptr := uintptr(unsafe.Pointer(typ)) - codeSet, err := CompileToGetCodeSet(typeptr) + codeSet, err := CompileToGetCodeSet(ctx, typeptr) if err != nil { t.Fatal(err) } diff --git a/internal/encoder/encoder.go b/internal/encoder/encoder.go index 79a3f64..4ba3158 100644 --- a/internal/encoder/encoder.go +++ b/internal/encoder/encoder.go @@ -101,6 +101,22 @@ type OpcodeSet struct { InterfaceEscapeKeyCode *Opcode CodeLength int EndCode *Opcode + Code Code + QueryCache map[string]*OpcodeSet + cacheMu sync.RWMutex +} + +func (s *OpcodeSet) getQueryCache(hash string) *OpcodeSet { + s.cacheMu.RLock() + codeSet := s.QueryCache[hash] + s.cacheMu.RUnlock() + return codeSet +} + +func (s *OpcodeSet) setQueryCache(hash string, codeSet *OpcodeSet) { + s.cacheMu.Lock() + s.QueryCache[hash] = codeSet + s.cacheMu.Unlock() } type CompiledCode struct { @@ -395,7 +411,11 @@ func AppendMarshalJSON(ctx *RuntimeContext, code *Opcode, b []byte, v interface{ if !ok { return AppendNull(ctx, b), nil } - b, err := marshaler.MarshalJSON(ctx.Option.Context) + stdctx := ctx.Option.Context + if ctx.Option.Flag&FieldQueryOption != 0 { + stdctx = SetFieldQueryToContext(stdctx, code.FieldQuery) + } + b, err := marshaler.MarshalJSON(stdctx) if err != nil { return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} } diff --git a/internal/encoder/opcode.go b/internal/encoder/opcode.go index 903dd6f..b02ae35 100644 --- a/internal/encoder/opcode.go +++ b/internal/encoder/opcode.go @@ -39,6 +39,7 @@ type Opcode struct { Type *runtime.Type // go type Jmp *CompiledCode // for recursive call + FieldQuery *FieldQuery // field query for Interface / MarshalJSON / MarshalText ElemIdx uint32 // offset to access array/slice elem Length uint32 // offset to access slice length or array length Indent uint32 // indent number @@ -333,6 +334,7 @@ func copyOpcode(code *Opcode) *Opcode { Idx: c.Idx, Offset: c.Offset, Type: c.Type, + FieldQuery: c.FieldQuery, DisplayIdx: c.DisplayIdx, DisplayKey: c.DisplayKey, ElemIdx: c.ElemIdx, diff --git a/internal/encoder/option.go b/internal/encoder/option.go index b30964c..dcec8f2 100644 --- a/internal/encoder/option.go +++ b/internal/encoder/option.go @@ -12,6 +12,7 @@ const ( ColorizeOption ContextOption NormalizeUTF8Option + FieldQueryOption ) type Option struct { diff --git a/internal/encoder/query.go b/internal/encoder/query.go new file mode 100644 index 0000000..1e1850c --- /dev/null +++ b/internal/encoder/query.go @@ -0,0 +1,135 @@ +package encoder + +import ( + "context" + "fmt" + "reflect" +) + +var ( + Marshal func(interface{}) ([]byte, error) + Unmarshal func([]byte, interface{}) error +) + +type FieldQuery struct { + Name string + Fields []*FieldQuery + hash string +} + +func (q *FieldQuery) Hash() string { + if q.hash != "" { + return q.hash + } + b, _ := Marshal(q) + q.hash = string(b) + return q.hash +} + +func (q *FieldQuery) MarshalJSON() ([]byte, error) { + if q.Name != "" { + if len(q.Fields) > 0 { + return Marshal(map[string][]*FieldQuery{q.Name: q.Fields}) + } + return Marshal(q.Name) + } + return Marshal(q.Fields) +} + +func (q *FieldQuery) QueryString() (FieldQueryString, error) { + b, err := Marshal(q) + if err != nil { + return "", err + } + return FieldQueryString(b), nil +} + +type FieldQueryString string + +func (s FieldQueryString) Build() (*FieldQuery, error) { + var query interface{} + if err := Unmarshal([]byte(s), &query); err != nil { + return nil, err + } + return s.build(reflect.ValueOf(query)) +} + +func (s FieldQueryString) build(v reflect.Value) (*FieldQuery, error) { + switch v.Type().Kind() { + case reflect.String: + return s.buildString(v) + case reflect.Map: + return s.buildMap(v) + case reflect.Slice: + return s.buildSlice(v) + case reflect.Interface: + return s.build(reflect.ValueOf(v.Interface())) + } + return nil, fmt.Errorf("failed to build field query") +} + +func (s FieldQueryString) buildString(v reflect.Value) (*FieldQuery, error) { + b := []byte(v.String()) + switch b[0] { + case '[', '{': + var query interface{} + if err := Unmarshal(b, &query); err != nil { + return nil, err + } + if str, ok := query.(string); ok { + return &FieldQuery{Name: str}, nil + } + return s.build(reflect.ValueOf(query)) + } + return &FieldQuery{Name: string(b)}, nil +} + +func (s FieldQueryString) buildSlice(v reflect.Value) (*FieldQuery, error) { + fields := make([]*FieldQuery, 0, v.Len()) + for i := 0; i < v.Len(); i++ { + def, err := s.build(v.Index(i)) + if err != nil { + return nil, err + } + fields = append(fields, def) + } + return &FieldQuery{Fields: fields}, nil +} + +func (s FieldQueryString) buildMap(v reflect.Value) (*FieldQuery, error) { + keys := v.MapKeys() + if len(keys) != 1 { + return nil, fmt.Errorf("failed to build field query object") + } + key := keys[0] + if key.Type().Kind() != reflect.String { + return nil, fmt.Errorf("failed to build field query. invalid object key type") + } + name := key.String() + def, err := s.build(v.MapIndex(key)) + if err != nil { + return nil, err + } + return &FieldQuery{ + Name: name, + Fields: def.Fields, + }, nil +} + +type queryKey struct{} + +func FieldQueryFromContext(ctx context.Context) *FieldQuery { + query := ctx.Value(queryKey{}) + if query == nil { + return nil + } + q, ok := query.(*FieldQuery) + if !ok { + return nil + } + return q +} + +func SetFieldQueryToContext(ctx context.Context, query *FieldQuery) context.Context { + return context.WithValue(ctx, queryKey{}, query) +} diff --git a/internal/encoder/vm/vm.go b/internal/encoder/vm/vm.go index 4be6b80..f45a593 100644 --- a/internal/encoder/vm/vm.go +++ b/internal/encoder/vm/vm.go @@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b break } ctx.KeepRefs = append(ctx.KeepRefs, up) - ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) + ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ))) if err != nil { return nil, err } diff --git a/internal/encoder/vm_color/vm.go b/internal/encoder/vm_color/vm.go index b13abe8..4a2e3c7 100644 --- a/internal/encoder/vm_color/vm.go +++ b/internal/encoder/vm_color/vm.go @@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b break } ctx.KeepRefs = append(ctx.KeepRefs, up) - ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) + ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ))) if err != nil { return nil, err } diff --git a/internal/encoder/vm_color_indent/vm.go b/internal/encoder/vm_color_indent/vm.go index a45aa54..a9fb725 100644 --- a/internal/encoder/vm_color_indent/vm.go +++ b/internal/encoder/vm_color_indent/vm.go @@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b break } ctx.KeepRefs = append(ctx.KeepRefs, up) - ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) + ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ))) if err != nil { return nil, err } diff --git a/internal/encoder/vm_indent/vm.go b/internal/encoder/vm_indent/vm.go index d1e0b45..a023472 100644 --- a/internal/encoder/vm_indent/vm.go +++ b/internal/encoder/vm_indent/vm.go @@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b break } ctx.KeepRefs = append(ctx.KeepRefs, up) - ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) + ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ))) if err != nil { return nil, err } diff --git a/json.go b/json.go index 5c9448d..413cb20 100644 --- a/json.go +++ b/json.go @@ -364,3 +364,8 @@ func Valid(data []byte) bool { } return decoder.InputOffset() >= int64(len(data)) } + +func init() { + encoder.Marshal = Marshal + encoder.Unmarshal = Unmarshal +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..137e663 --- /dev/null +++ b/query.go @@ -0,0 +1,36 @@ +package json + +import ( + "github.com/goccy/go-json/internal/encoder" +) + +type ( + FieldQuery = encoder.FieldQuery + FieldQueryString = encoder.FieldQueryString +) + +var ( + FieldQueryFromContext = encoder.FieldQueryFromContext + SetFieldQueryToContext = encoder.SetFieldQueryToContext +) + +func BuildFieldQuery(fields ...FieldQueryString) (*FieldQuery, error) { + query, err := Marshal(fields) + if err != nil { + return nil, err + } + return FieldQueryString(query).Build() +} + +func BuildSubFieldQuery(name string) *SubFieldQuery { + return &SubFieldQuery{name: name} +} + +type SubFieldQuery struct { + name string +} + +func (q *SubFieldQuery) Fields(fields ...FieldQueryString) FieldQueryString { + query, _ := Marshal(map[string][]FieldQueryString{q.name: fields}) + return FieldQueryString(query) +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..d8cac61 --- /dev/null +++ b/query_test.go @@ -0,0 +1,121 @@ +package json_test + +import ( + "context" + "reflect" + "testing" + + "github.com/goccy/go-json" +) + +type queryTestX struct { + XA int + XB string + XC *queryTestY + XD bool + XE float32 +} + +type queryTestY struct { + YA int + YB string + YC *queryTestZ + YD bool + YE float32 +} + +type queryTestZ struct { + ZA string + ZB bool + ZC int +} + +func (z *queryTestZ) MarshalJSON(ctx context.Context) ([]byte, error) { + type _queryTestZ queryTestZ + return json.MarshalContext(ctx, (*_queryTestZ)(z)) +} + +func TestFieldQuery(t *testing.T) { + query, err := json.BuildFieldQuery( + "XA", + "XB", + json.BuildSubFieldQuery("XC").Fields( + "YA", + "YB", + json.BuildSubFieldQuery("YC").Fields( + "ZA", + "ZB", + ), + ), + ) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(query, &json.FieldQuery{ + Fields: []*json.FieldQuery{ + { + Name: "XA", + }, + { + Name: "XB", + }, + { + Name: "XC", + Fields: []*json.FieldQuery{ + { + Name: "YA", + }, + { + Name: "YB", + }, + { + Name: "YC", + Fields: []*json.FieldQuery{ + { + Name: "ZA", + }, + { + Name: "ZB", + }, + }, + }, + }, + }, + }, + }) { + t.Fatal("cannot get query") + } + queryStr, err := query.QueryString() + if err != nil { + t.Fatal(err) + } + if queryStr != `["XA","XB",{"XC":["YA","YB",{"YC":["ZA","ZB"]}]}]` { + t.Fatalf("failed to create query string. %s", queryStr) + } + ctx := json.SetFieldQueryToContext(context.Background(), query) + b, err := json.MarshalContext(ctx, &queryTestX{ + XA: 1, + XB: "xb", + XC: &queryTestY{ + YA: 2, + YB: "yb", + YC: &queryTestZ{ + ZA: "za", + ZB: true, + ZC: 3, + }, + YD: true, + YE: 4, + }, + XD: true, + XE: 5, + }) + if err != nil { + t.Fatal(err) + } + expected := `{"XA":1,"XB":"xb","XC":{"YA":2,"YB":"yb","YC":{"ZA":"za","ZB":true}}}` + got := string(b) + if expected != got { + t.Fatalf("failed to encode with field query: expected %q but got %q", expected, got) + } +} diff --git a/size_test.go b/size_test.go index 9804cab..00c3a2b 100644 --- a/size_test.go +++ b/size_test.go @@ -11,7 +11,7 @@ func TestOpcodeSize(t *testing.T) { const uintptrSize = 4 << (^uintptr(0) >> 63) if uintptrSize == 8 { size := unsafe.Sizeof(encoder.Opcode{}) - if size != 112 { + if size != 120 { t.Fatalf("unexpected opcode size: expected 112bytes but got %dbytes", size) } } diff --git a/test/example/example_query_test.go b/test/example/example_query_test.go new file mode 100644 index 0000000..cca22ba --- /dev/null +++ b/test/example/example_query_test.go @@ -0,0 +1,96 @@ +package json_test + +import ( + "context" + "fmt" + "log" + + "github.com/goccy/go-json" +) + +type User struct { + ID int64 + Name string + Age int + Address UserAddressResolver +} + +type UserAddress struct { + UserID int64 + PostCode string + City string + Address1 string + Address2 string +} + +type UserRepository struct { + uaRepo *UserAddressRepository +} + +func NewUserRepository() *UserRepository { + return &UserRepository{ + uaRepo: NewUserAddressRepository(), + } +} + +type UserAddressRepository struct{} + +func NewUserAddressRepository() *UserAddressRepository { + return &UserAddressRepository{} +} + +type UserAddressResolver func(context.Context) (*UserAddress, error) + +func (resolver UserAddressResolver) MarshalJSON(ctx context.Context) ([]byte, error) { + address, err := resolver(ctx) + if err != nil { + return nil, err + } + return json.MarshalContext(ctx, address) +} + +func (r *UserRepository) FindByID(ctx context.Context, id int64) (*User, error) { + user := &User{ID: id, Name: "Ken", Age: 20} + // resolve relation from User to UserAddress + user.Address = func(ctx context.Context) (*UserAddress, error) { + return r.uaRepo.FindByUserID(ctx, user.ID) + } + return user, nil +} + +func (*UserAddressRepository) FindByUserID(ctx context.Context, id int64) (*UserAddress, error) { + return &UserAddress{ + UserID: id, + City: "A", + Address1: "foo", + Address2: "bar", + }, nil +} + +func Example_fieldQuery() { + ctx := context.Background() + userRepo := NewUserRepository() + user, err := userRepo.FindByID(ctx, 1) + if err != nil { + log.Fatal(err) + } + query, err := json.BuildFieldQuery( + "Name", + "Age", + json.BuildSubFieldQuery("Address").Fields( + "City", + ), + ) + if err != nil { + log.Fatal(err) + } + ctx = json.SetFieldQueryToContext(ctx, query) + b, err := json.MarshalContext(ctx, user) + if err != nil { + log.Fatal(err) + } + fmt.Println(string(b)) + + // Output: + // {"Name":"Ken","Age":20,"Address":{"City":"A"}} +}