Supports dynamic filtering of struct fields

This commit is contained in:
Masaaki Goshima 2022-01-03 12:33:51 +09:00
parent b0f4ac6d83
commit 89bcc3be86
No known key found for this signature in database
GPG Key ID: 6A53785055537153
21 changed files with 688 additions and 19 deletions

View File

@ -2,6 +2,7 @@ package benchmark
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"testing" "testing"
@ -835,3 +836,80 @@ func Benchmark_Encode_MarshalJSON_GoJson(b *testing.B) {
} }
} }
} }
type queryTestX struct {
XA int
XB string
XC *queryTestY
XD bool
XE float32
}
type queryTestY struct {
YA int
YB string
YC bool
YD float32
}
func Benchmark_Encode_FilterByMap(b *testing.B) {
v := &queryTestX{
XA: 1,
XB: "xb",
XC: &queryTestY{
YA: 2,
YB: "yb",
YC: true,
YD: 4,
},
XD: true,
XE: 5,
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
filteredMap := map[string]interface{}{
"XA": v.XA,
"XB": v.XB,
"XC": map[string]interface{}{
"YA": v.XC.YA,
"YB": v.XC.YB,
},
}
if _, err := gojson.Marshal(filteredMap); err != nil {
b.Fatal(err)
}
}
}
func Benchmark_Encode_FilterByFieldQuery(b *testing.B) {
query, err := gojson.BuildFieldQuery(
"XA",
"XB",
gojson.BuildSubFieldQuery("XC").Fields(
"YA",
"YB",
),
)
if err != nil {
b.Fatal(err)
}
v := &queryTestX{
XA: 1,
XB: "xb",
XC: &queryTestY{
YA: 2,
YB: "yb",
YC: true,
YD: 4,
},
XD: true,
XE: 5,
}
ctx := gojson.SetFieldQueryToContext(context.Background(), query)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := gojson.MarshalContext(ctx, v); err != nil {
b.Fatal(err)
}
}
}

View File

@ -221,7 +221,7 @@ func encode(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error) {
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(typeptr) codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -249,7 +249,7 @@ func encodeNoEscape(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error)
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(typeptr) codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -276,7 +276,7 @@ func encodeIndent(ctx *encoder.RuntimeContext, v interface{}, prefix, indent str
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(typeptr) codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -10,6 +10,7 @@ import (
type Code interface { type Code interface {
Kind() CodeKind Kind() CodeKind
ToOpcode(*compileContext) Opcodes ToOpcode(*compileContext) Opcodes
Filter(*FieldQuery) Code
} }
type AnonymousCode interface { type AnonymousCode interface {
@ -82,6 +83,10 @@ func (c *IntCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *IntCode) Filter(_ *FieldQuery) Code {
return c
}
type UintCode struct { type UintCode struct {
typ *runtime.Type typ *runtime.Type
bitSize uint8 bitSize uint8
@ -108,6 +113,10 @@ func (c *UintCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *UintCode) Filter(_ *FieldQuery) Code {
return c
}
type FloatCode struct { type FloatCode struct {
typ *runtime.Type typ *runtime.Type
bitSize uint8 bitSize uint8
@ -140,6 +149,10 @@ func (c *FloatCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *FloatCode) Filter(_ *FieldQuery) Code {
return c
}
type StringCode struct { type StringCode struct {
typ *runtime.Type typ *runtime.Type
isPtr bool isPtr bool
@ -169,6 +182,10 @@ func (c *StringCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *StringCode) Filter(_ *FieldQuery) Code {
return c
}
type BoolCode struct { type BoolCode struct {
typ *runtime.Type typ *runtime.Type
isPtr bool isPtr bool
@ -190,6 +207,10 @@ func (c *BoolCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *BoolCode) Filter(_ *FieldQuery) Code {
return c
}
type BytesCode struct { type BytesCode struct {
typ *runtime.Type typ *runtime.Type
isPtr bool isPtr bool
@ -211,6 +232,10 @@ func (c *BytesCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *BytesCode) Filter(_ *FieldQuery) Code {
return c
}
type SliceCode struct { type SliceCode struct {
typ *runtime.Type typ *runtime.Type
value Code value Code
@ -245,6 +270,10 @@ func (c *SliceCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{header}.Add(codes...).Add(elemCode).Add(end) return Opcodes{header}.Add(codes...).Add(elemCode).Add(end)
} }
func (c *SliceCode) Filter(_ *FieldQuery) Code {
return c
}
type ArrayCode struct { type ArrayCode struct {
typ *runtime.Type typ *runtime.Type
value Code value Code
@ -286,6 +315,10 @@ func (c *ArrayCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{header}.Add(codes...).Add(elemCode).Add(end) return Opcodes{header}.Add(codes...).Add(elemCode).Add(end)
} }
func (c *ArrayCode) Filter(_ *FieldQuery) Code {
return c
}
type MapCode struct { type MapCode struct {
typ *runtime.Type typ *runtime.Type
key Code key Code
@ -332,6 +365,10 @@ func (c *MapCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{header}.Add(keyCodes...).Add(value).Add(valueCodes...).Add(key).Add(end) return Opcodes{header}.Add(keyCodes...).Add(value).Add(valueCodes...).Add(key).Add(end)
} }
func (c *MapCode) Filter(_ *FieldQuery) Code {
return c
}
type StructCode struct { type StructCode struct {
typ *runtime.Type typ *runtime.Type
fields []*StructFieldCode fields []*StructFieldCode
@ -520,6 +557,45 @@ func (c *StructCode) enableIndirect() {
structCode.enableIndirect() structCode.enableIndirect()
} }
func (c *StructCode) Filter(query *FieldQuery) Code {
fieldMap := map[string]*FieldQuery{}
for _, field := range query.Fields {
fieldMap[field.Name] = field
}
fields := make([]*StructFieldCode, 0, len(c.fields))
for _, field := range c.fields {
query, exists := fieldMap[field.key]
if !exists {
continue
}
fieldCode := &StructFieldCode{
typ: field.typ,
key: field.key,
tag: field.tag,
value: field.value,
offset: field.offset,
isAnonymous: field.isAnonymous,
isTaggedKey: field.isTaggedKey,
isNilableType: field.isNilableType,
isNilCheck: field.isNilCheck,
isAddrForMarshaler: field.isAddrForMarshaler,
isNextOpPtrType: field.isNextOpPtrType,
}
if len(query.Fields) > 0 {
fieldCode.value = fieldCode.value.Filter(query)
}
fields = append(fields, fieldCode)
}
return &StructCode{
typ: c.typ,
fields: fields,
isPtr: c.isPtr,
disableIndirectConversion: c.disableIndirectConversion,
isIndirect: c.isIndirect,
isRecursive: c.isRecursive,
}
}
type StructFieldCode struct { type StructFieldCode struct {
typ *runtime.Type typ *runtime.Type
key string key string
@ -532,6 +608,7 @@ type StructFieldCode struct {
isNilCheck bool isNilCheck bool
isAddrForMarshaler bool isAddrForMarshaler bool
isNextOpPtrType bool isNextOpPtrType bool
isMarshalerContext bool
} }
func (c *StructFieldCode) getStruct() *StructCode { func (c *StructFieldCode) getStruct() *StructCode {
@ -574,8 +651,12 @@ func (c *StructFieldCode) headerOpcodes(ctx *compileContext, field *Opcode, valu
value := valueCodes.First() value := valueCodes.First()
op := optimizeStructHeader(value, c.tag) op := optimizeStructHeader(value, c.tag)
field.Op = op field.Op = op
if value.Flags&MarshalerContextFlags != 0 {
field.Flags |= MarshalerContextFlags
}
field.NumBitSize = value.NumBitSize field.NumBitSize = value.NumBitSize
field.PtrNum = value.PtrNum field.PtrNum = value.PtrNum
field.FieldQuery = value.FieldQuery
fieldCodes := Opcodes{field} fieldCodes := Opcodes{field}
if op.IsMultipleOpHead() { if op.IsMultipleOpHead() {
field.Next = value field.Next = value
@ -590,8 +671,12 @@ func (c *StructFieldCode) fieldOpcodes(ctx *compileContext, field *Opcode, value
value := valueCodes.First() value := valueCodes.First()
op := optimizeStructField(value, c.tag) op := optimizeStructField(value, c.tag)
field.Op = op field.Op = op
if value.Flags&MarshalerContextFlags != 0 {
field.Flags |= MarshalerContextFlags
}
field.NumBitSize = value.NumBitSize field.NumBitSize = value.NumBitSize
field.PtrNum = value.PtrNum field.PtrNum = value.PtrNum
field.FieldQuery = value.FieldQuery
fieldCodes := Opcodes{field} fieldCodes := Opcodes{field}
if op.IsMultipleOpField() { if op.IsMultipleOpField() {
@ -645,6 +730,9 @@ func (c *StructFieldCode) flags() OpFlags {
if c.isAnonymous { if c.isAnonymous {
flags |= AnonymousKeyFlags flags |= AnonymousKeyFlags
} }
if c.isMarshalerContext {
flags |= MarshalerContextFlags
}
return flags return flags
} }
@ -726,6 +814,7 @@ func isEnableStructEndOptimization(value Code) bool {
type InterfaceCode struct { type InterfaceCode struct {
typ *runtime.Type typ *runtime.Type
fieldQuery *FieldQuery
isPtr bool isPtr bool
} }
@ -741,6 +830,7 @@ func (c *InterfaceCode) ToOpcode(ctx *compileContext) Opcodes {
default: default:
code = newOpCode(ctx, c.typ, OpInterface) code = newOpCode(ctx, c.typ, OpInterface)
} }
code.FieldQuery = c.fieldQuery
if c.typ.NumMethod() > 0 { if c.typ.NumMethod() > 0 {
code.Flags |= NonEmptyInterfaceFlags code.Flags |= NonEmptyInterfaceFlags
} }
@ -748,8 +838,17 @@ func (c *InterfaceCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *InterfaceCode) Filter(query *FieldQuery) Code {
return &InterfaceCode{
typ: c.typ,
fieldQuery: query,
isPtr: c.isPtr,
}
}
type MarshalJSONCode struct { type MarshalJSONCode struct {
typ *runtime.Type typ *runtime.Type
fieldQuery *FieldQuery
isAddrForMarshaler bool isAddrForMarshaler bool
isNilableType bool isNilableType bool
isMarshalerContext bool isMarshalerContext bool
@ -761,6 +860,7 @@ func (c *MarshalJSONCode) Kind() CodeKind {
func (c *MarshalJSONCode) ToOpcode(ctx *compileContext) Opcodes { func (c *MarshalJSONCode) ToOpcode(ctx *compileContext) Opcodes {
code := newOpCode(ctx, c.typ, OpMarshalJSON) code := newOpCode(ctx, c.typ, OpMarshalJSON)
code.FieldQuery = c.fieldQuery
if c.isAddrForMarshaler { if c.isAddrForMarshaler {
code.Flags |= AddrForMarshalerFlags code.Flags |= AddrForMarshalerFlags
} }
@ -776,8 +876,19 @@ func (c *MarshalJSONCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *MarshalJSONCode) Filter(query *FieldQuery) Code {
return &MarshalJSONCode{
typ: c.typ,
fieldQuery: query,
isAddrForMarshaler: c.isAddrForMarshaler,
isNilableType: c.isNilableType,
isMarshalerContext: c.isMarshalerContext,
}
}
type MarshalTextCode struct { type MarshalTextCode struct {
typ *runtime.Type typ *runtime.Type
fieldQuery *FieldQuery
isAddrForMarshaler bool isAddrForMarshaler bool
isNilableType bool isNilableType bool
} }
@ -788,6 +899,7 @@ func (c *MarshalTextCode) Kind() CodeKind {
func (c *MarshalTextCode) ToOpcode(ctx *compileContext) Opcodes { func (c *MarshalTextCode) ToOpcode(ctx *compileContext) Opcodes {
code := newOpCode(ctx, c.typ, OpMarshalText) code := newOpCode(ctx, c.typ, OpMarshalText)
code.FieldQuery = c.fieldQuery
if c.isAddrForMarshaler { if c.isAddrForMarshaler {
code.Flags |= AddrForMarshalerFlags code.Flags |= AddrForMarshalerFlags
} }
@ -800,6 +912,15 @@ func (c *MarshalTextCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *MarshalTextCode) Filter(query *FieldQuery) Code {
return &MarshalTextCode{
typ: c.typ,
fieldQuery: query,
isAddrForMarshaler: c.isAddrForMarshaler,
isNilableType: c.isNilableType,
}
}
type PtrCode struct { type PtrCode struct {
typ *runtime.Type typ *runtime.Type
value Code value Code
@ -830,6 +951,14 @@ func (c *PtrCode) ToAnonymousOpcode(ctx *compileContext) Opcodes {
return codes return codes
} }
func (c *PtrCode) Filter(query *FieldQuery) Code {
return &PtrCode{
typ: c.typ,
value: c.value.Filter(query),
ptrNum: c.ptrNum,
}
}
func convertPtrOp(code *Opcode) OpType { func convertPtrOp(code *Opcode) OpType {
ptrHeadOp := code.Op.HeadToPtrHead() ptrHeadOp := code.Op.HeadToPtrHead()
if code.Op != ptrHeadOp { if code.Op != ptrHeadOp {

View File

@ -63,6 +63,27 @@ func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
return codeSet, nil return codeSet, nil
} }
func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) {
if (ctx.Option.Flag & ContextOption) == 0 {
return codeSet, nil
}
query := FieldQueryFromContext(ctx.Option.Context)
if query == nil {
return codeSet, nil
}
ctx.Option.Flag |= FieldQueryOption
cacheCodeSet := codeSet.getQueryCache(query.Hash())
if cacheCodeSet != nil {
return cacheCodeSet, nil
}
queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query))
if err != nil {
return nil, err
}
codeSet.setQueryCache(query.Hash(), queryCodeSet)
return queryCodeSet, nil
}
type Compiler struct { type Compiler struct {
structTypeToCode map[uintptr]*StructCode structTypeToCode map[uintptr]*StructCode
} }
@ -80,6 +101,10 @@ func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.codeToOpcodeSet(typ, code)
}
func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) {
noescapeKeyCode := c.codeToOpcode(&compileContext{ noescapeKeyCode := c.codeToOpcode(&compileContext{
structTypeToCodes: map[uintptr]Opcodes{}, structTypeToCodes: map[uintptr]Opcodes{},
recursiveCodes: &Opcodes{}, recursiveCodes: &Opcodes{},
@ -107,6 +132,8 @@ func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
InterfaceEscapeKeyCode: interfaceEscapeKeyCode, InterfaceEscapeKeyCode: interfaceEscapeKeyCode,
CodeLength: codeLength, CodeLength: codeLength,
EndCode: ToEndCode(interfaceNoescapeKeyCode), EndCode: ToEndCode(interfaceNoescapeKeyCode),
Code: code,
QueryCache: map[string]*OpcodeSet{},
}, nil }, nil
} }

View File

@ -3,18 +3,26 @@
package encoder package encoder
func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) { func CompileToGetCodeSet(ctx *RuntimeContext, typeptr uintptr) (*OpcodeSet, error) {
if typeptr > typeAddr.MaxTypeAddr { if typeptr > typeAddr.MaxTypeAddr {
return compileToGetCodeSetSlowPath(typeptr) return compileToGetCodeSetSlowPath(typeptr)
} }
index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift
if codeSet := cachedOpcodeSets[index]; codeSet != nil { if codeSet := cachedOpcodeSets[index]; codeSet != nil {
return codeSet, nil filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet)
if err != nil {
return nil, err
}
return filtered, nil
} }
codeSet, err := newCompiler().compile(typeptr) codeSet, err := newCompiler().compile(typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cachedOpcodeSets[index] = codeSet filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet)
return codeSet, nil if err != nil {
return nil, err
}
cachedOpcodeSets[index] = codeSet
return filtered, nil
} }

View File

@ -9,15 +9,20 @@ import (
var setsMu sync.RWMutex var setsMu sync.RWMutex
func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) { func CompileToGetCodeSet(ctx *RuntimeContext, typeptr uintptr) (*OpcodeSet, error) {
if typeptr > typeAddr.MaxTypeAddr { if typeptr > typeAddr.MaxTypeAddr {
return compileToGetCodeSetSlowPath(typeptr) return compileToGetCodeSetSlowPath(typeptr)
} }
index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift
setsMu.RLock() setsMu.RLock()
if codeSet := cachedOpcodeSets[index]; codeSet != nil { if codeSet := cachedOpcodeSets[index]; codeSet != nil {
filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet)
if err != nil {
setsMu.RUnlock() setsMu.RUnlock()
return codeSet, nil return nil, err
}
setsMu.RUnlock()
return filtered, nil
} }
setsMu.RUnlock() setsMu.RUnlock()
@ -25,8 +30,12 @@ func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet)
if err != nil {
return nil, err
}
setsMu.Lock() setsMu.Lock()
cachedOpcodeSets[index] = codeSet cachedOpcodeSets[index] = codeSet
setsMu.Unlock() setsMu.Unlock()
return codeSet, nil return filtered, nil
} }

View File

@ -6,11 +6,13 @@ import (
) )
func TestDumpOpcode(t *testing.T) { func TestDumpOpcode(t *testing.T) {
ctx := TakeRuntimeContext()
defer ReleaseRuntimeContext(ctx)
var v interface{} = 1 var v interface{} = 1
header := (*emptyInterface)(unsafe.Pointer(&v)) header := (*emptyInterface)(unsafe.Pointer(&v))
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := CompileToGetCodeSet(typeptr) codeSet, err := CompileToGetCodeSet(ctx, typeptr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -101,6 +101,22 @@ type OpcodeSet struct {
InterfaceEscapeKeyCode *Opcode InterfaceEscapeKeyCode *Opcode
CodeLength int CodeLength int
EndCode *Opcode EndCode *Opcode
Code Code
QueryCache map[string]*OpcodeSet
cacheMu sync.RWMutex
}
func (s *OpcodeSet) getQueryCache(hash string) *OpcodeSet {
s.cacheMu.RLock()
codeSet := s.QueryCache[hash]
s.cacheMu.RUnlock()
return codeSet
}
func (s *OpcodeSet) setQueryCache(hash string, codeSet *OpcodeSet) {
s.cacheMu.Lock()
s.QueryCache[hash] = codeSet
s.cacheMu.Unlock()
} }
type CompiledCode struct { type CompiledCode struct {
@ -395,7 +411,11 @@ func AppendMarshalJSON(ctx *RuntimeContext, code *Opcode, b []byte, v interface{
if !ok { if !ok {
return AppendNull(ctx, b), nil return AppendNull(ctx, b), nil
} }
b, err := marshaler.MarshalJSON(ctx.Option.Context) stdctx := ctx.Option.Context
if ctx.Option.Flag&FieldQueryOption != 0 {
stdctx = SetFieldQueryToContext(stdctx, code.FieldQuery)
}
b, err := marshaler.MarshalJSON(stdctx)
if err != nil { if err != nil {
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
} }

View File

@ -39,6 +39,7 @@ type Opcode struct {
Type *runtime.Type // go type Type *runtime.Type // go type
Jmp *CompiledCode // for recursive call Jmp *CompiledCode // for recursive call
FieldQuery *FieldQuery // field query for Interface / MarshalJSON / MarshalText
ElemIdx uint32 // offset to access array/slice elem ElemIdx uint32 // offset to access array/slice elem
Length uint32 // offset to access slice length or array length Length uint32 // offset to access slice length or array length
Indent uint32 // indent number Indent uint32 // indent number
@ -333,6 +334,7 @@ func copyOpcode(code *Opcode) *Opcode {
Idx: c.Idx, Idx: c.Idx,
Offset: c.Offset, Offset: c.Offset,
Type: c.Type, Type: c.Type,
FieldQuery: c.FieldQuery,
DisplayIdx: c.DisplayIdx, DisplayIdx: c.DisplayIdx,
DisplayKey: c.DisplayKey, DisplayKey: c.DisplayKey,
ElemIdx: c.ElemIdx, ElemIdx: c.ElemIdx,

View File

@ -12,6 +12,7 @@ const (
ColorizeOption ColorizeOption
ContextOption ContextOption
NormalizeUTF8Option NormalizeUTF8Option
FieldQueryOption
) )
type Option struct { type Option struct {

135
internal/encoder/query.go Normal file
View File

@ -0,0 +1,135 @@
package encoder
import (
"context"
"fmt"
"reflect"
)
var (
Marshal func(interface{}) ([]byte, error)
Unmarshal func([]byte, interface{}) error
)
type FieldQuery struct {
Name string
Fields []*FieldQuery
hash string
}
func (q *FieldQuery) Hash() string {
if q.hash != "" {
return q.hash
}
b, _ := Marshal(q)
q.hash = string(b)
return q.hash
}
func (q *FieldQuery) MarshalJSON() ([]byte, error) {
if q.Name != "" {
if len(q.Fields) > 0 {
return Marshal(map[string][]*FieldQuery{q.Name: q.Fields})
}
return Marshal(q.Name)
}
return Marshal(q.Fields)
}
func (q *FieldQuery) QueryString() (FieldQueryString, error) {
b, err := Marshal(q)
if err != nil {
return "", err
}
return FieldQueryString(b), nil
}
type FieldQueryString string
func (s FieldQueryString) Build() (*FieldQuery, error) {
var query interface{}
if err := Unmarshal([]byte(s), &query); err != nil {
return nil, err
}
return s.build(reflect.ValueOf(query))
}
func (s FieldQueryString) build(v reflect.Value) (*FieldQuery, error) {
switch v.Type().Kind() {
case reflect.String:
return s.buildString(v)
case reflect.Map:
return s.buildMap(v)
case reflect.Slice:
return s.buildSlice(v)
case reflect.Interface:
return s.build(reflect.ValueOf(v.Interface()))
}
return nil, fmt.Errorf("failed to build field query")
}
func (s FieldQueryString) buildString(v reflect.Value) (*FieldQuery, error) {
b := []byte(v.String())
switch b[0] {
case '[', '{':
var query interface{}
if err := Unmarshal(b, &query); err != nil {
return nil, err
}
if str, ok := query.(string); ok {
return &FieldQuery{Name: str}, nil
}
return s.build(reflect.ValueOf(query))
}
return &FieldQuery{Name: string(b)}, nil
}
func (s FieldQueryString) buildSlice(v reflect.Value) (*FieldQuery, error) {
fields := make([]*FieldQuery, 0, v.Len())
for i := 0; i < v.Len(); i++ {
def, err := s.build(v.Index(i))
if err != nil {
return nil, err
}
fields = append(fields, def)
}
return &FieldQuery{Fields: fields}, nil
}
func (s FieldQueryString) buildMap(v reflect.Value) (*FieldQuery, error) {
keys := v.MapKeys()
if len(keys) != 1 {
return nil, fmt.Errorf("failed to build field query object")
}
key := keys[0]
if key.Type().Kind() != reflect.String {
return nil, fmt.Errorf("failed to build field query. invalid object key type")
}
name := key.String()
def, err := s.build(v.MapIndex(key))
if err != nil {
return nil, err
}
return &FieldQuery{
Name: name,
Fields: def.Fields,
}, nil
}
type queryKey struct{}
func FieldQueryFromContext(ctx context.Context) *FieldQuery {
query := ctx.Value(queryKey{})
if query == nil {
return nil
}
q, ok := query.(*FieldQuery)
if !ok {
return nil
}
return q
}
func SetFieldQueryToContext(ctx context.Context, query *FieldQuery) context.Context {
return context.WithValue(ctx, queryKey{}, query)
}

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -364,3 +364,8 @@ func Valid(data []byte) bool {
} }
return decoder.InputOffset() >= int64(len(data)) return decoder.InputOffset() >= int64(len(data))
} }
func init() {
encoder.Marshal = Marshal
encoder.Unmarshal = Unmarshal
}

36
query.go Normal file
View File

@ -0,0 +1,36 @@
package json
import (
"github.com/goccy/go-json/internal/encoder"
)
type (
FieldQuery = encoder.FieldQuery
FieldQueryString = encoder.FieldQueryString
)
var (
FieldQueryFromContext = encoder.FieldQueryFromContext
SetFieldQueryToContext = encoder.SetFieldQueryToContext
)
func BuildFieldQuery(fields ...FieldQueryString) (*FieldQuery, error) {
query, err := Marshal(fields)
if err != nil {
return nil, err
}
return FieldQueryString(query).Build()
}
func BuildSubFieldQuery(name string) *SubFieldQuery {
return &SubFieldQuery{name: name}
}
type SubFieldQuery struct {
name string
}
func (q *SubFieldQuery) Fields(fields ...FieldQueryString) FieldQueryString {
query, _ := Marshal(map[string][]FieldQueryString{q.name: fields})
return FieldQueryString(query)
}

121
query_test.go Normal file
View File

@ -0,0 +1,121 @@
package json_test
import (
"context"
"reflect"
"testing"
"github.com/goccy/go-json"
)
type queryTestX struct {
XA int
XB string
XC *queryTestY
XD bool
XE float32
}
type queryTestY struct {
YA int
YB string
YC *queryTestZ
YD bool
YE float32
}
type queryTestZ struct {
ZA string
ZB bool
ZC int
}
func (z *queryTestZ) MarshalJSON(ctx context.Context) ([]byte, error) {
type _queryTestZ queryTestZ
return json.MarshalContext(ctx, (*_queryTestZ)(z))
}
func TestFieldQuery(t *testing.T) {
query, err := json.BuildFieldQuery(
"XA",
"XB",
json.BuildSubFieldQuery("XC").Fields(
"YA",
"YB",
json.BuildSubFieldQuery("YC").Fields(
"ZA",
"ZB",
),
),
)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(query, &json.FieldQuery{
Fields: []*json.FieldQuery{
{
Name: "XA",
},
{
Name: "XB",
},
{
Name: "XC",
Fields: []*json.FieldQuery{
{
Name: "YA",
},
{
Name: "YB",
},
{
Name: "YC",
Fields: []*json.FieldQuery{
{
Name: "ZA",
},
{
Name: "ZB",
},
},
},
},
},
},
}) {
t.Fatal("cannot get query")
}
queryStr, err := query.QueryString()
if err != nil {
t.Fatal(err)
}
if queryStr != `["XA","XB",{"XC":["YA","YB",{"YC":["ZA","ZB"]}]}]` {
t.Fatalf("failed to create query string. %s", queryStr)
}
ctx := json.SetFieldQueryToContext(context.Background(), query)
b, err := json.MarshalContext(ctx, &queryTestX{
XA: 1,
XB: "xb",
XC: &queryTestY{
YA: 2,
YB: "yb",
YC: &queryTestZ{
ZA: "za",
ZB: true,
ZC: 3,
},
YD: true,
YE: 4,
},
XD: true,
XE: 5,
})
if err != nil {
t.Fatal(err)
}
expected := `{"XA":1,"XB":"xb","XC":{"YA":2,"YB":"yb","YC":{"ZA":"za","ZB":true}}}`
got := string(b)
if expected != got {
t.Fatalf("failed to encode with field query: expected %q but got %q", expected, got)
}
}

View File

@ -11,7 +11,7 @@ func TestOpcodeSize(t *testing.T) {
const uintptrSize = 4 << (^uintptr(0) >> 63) const uintptrSize = 4 << (^uintptr(0) >> 63)
if uintptrSize == 8 { if uintptrSize == 8 {
size := unsafe.Sizeof(encoder.Opcode{}) size := unsafe.Sizeof(encoder.Opcode{})
if size != 112 { if size != 120 {
t.Fatalf("unexpected opcode size: expected 112bytes but got %dbytes", size) t.Fatalf("unexpected opcode size: expected 112bytes but got %dbytes", size)
} }
} }

View File

@ -0,0 +1,96 @@
package json_test
import (
"context"
"fmt"
"log"
"github.com/goccy/go-json"
)
type User struct {
ID int64
Name string
Age int
Address UserAddressResolver
}
type UserAddress struct {
UserID int64
PostCode string
City string
Address1 string
Address2 string
}
type UserRepository struct {
uaRepo *UserAddressRepository
}
func NewUserRepository() *UserRepository {
return &UserRepository{
uaRepo: NewUserAddressRepository(),
}
}
type UserAddressRepository struct{}
func NewUserAddressRepository() *UserAddressRepository {
return &UserAddressRepository{}
}
type UserAddressResolver func(context.Context) (*UserAddress, error)
func (resolver UserAddressResolver) MarshalJSON(ctx context.Context) ([]byte, error) {
address, err := resolver(ctx)
if err != nil {
return nil, err
}
return json.MarshalContext(ctx, address)
}
func (r *UserRepository) FindByID(ctx context.Context, id int64) (*User, error) {
user := &User{ID: id, Name: "Ken", Age: 20}
// resolve relation from User to UserAddress
user.Address = func(ctx context.Context) (*UserAddress, error) {
return r.uaRepo.FindByUserID(ctx, user.ID)
}
return user, nil
}
func (*UserAddressRepository) FindByUserID(ctx context.Context, id int64) (*UserAddress, error) {
return &UserAddress{
UserID: id,
City: "A",
Address1: "foo",
Address2: "bar",
}, nil
}
func Example_fieldQuery() {
ctx := context.Background()
userRepo := NewUserRepository()
user, err := userRepo.FindByID(ctx, 1)
if err != nil {
log.Fatal(err)
}
query, err := json.BuildFieldQuery(
"Name",
"Age",
json.BuildSubFieldQuery("Address").Fields(
"City",
),
)
if err != nil {
log.Fatal(err)
}
ctx = json.SetFieldQueryToContext(ctx, query)
b, err := json.MarshalContext(ctx, user)
if err != nil {
log.Fatal(err)
}
fmt.Println(string(b))
// Output:
// {"Name":"Ken","Age":20,"Address":{"City":"A"}}
}