Merge pull request #314 from goccy/feature/json-field-query

Supports dynamic filtering of struct fields
This commit is contained in:
Masaaki Goshima 2022-01-03 15:48:53 +09:00 committed by GitHub
commit 0707c2a188
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 700 additions and 20 deletions

View File

@ -23,7 +23,6 @@ Fast JSON encoder/decoder compatible with encoding/json for Go
We are accepting requests for features that will be implemented between v0.8.0 and v.1.0.0. We are accepting requests for features that will be implemented between v0.8.0 and v.1.0.0.
If you have the API you need, please submit your issue [here](https://github.com/goccy/go-json/issues). If you have the API you need, please submit your issue [here](https://github.com/goccy/go-json/issues).
For example, I'm thinking of supporting `context.Context` of `json.Marshaler` and decoding using JSON Path.
# Features # Features
@ -32,6 +31,7 @@ For example, I'm thinking of supporting `context.Context` of `json.Marshaler` an
- Flexible customization with options - Flexible customization with options
- Coloring the encoded string - Coloring the encoded string
- Can propagate context.Context to `MarshalJSON` or `UnmarshalJSON` - Can propagate context.Context to `MarshalJSON` or `UnmarshalJSON`
- Can dynamically filter the fields of the structure type-safely
# Installation # Installation

View File

@ -2,6 +2,7 @@ package benchmark
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"testing" "testing"
@ -835,3 +836,80 @@ func Benchmark_Encode_MarshalJSON_GoJson(b *testing.B) {
} }
} }
} }
type queryTestX struct {
XA int
XB string
XC *queryTestY
XD bool
XE float32
}
type queryTestY struct {
YA int
YB string
YC bool
YD float32
}
func Benchmark_Encode_FilterByMap(b *testing.B) {
v := &queryTestX{
XA: 1,
XB: "xb",
XC: &queryTestY{
YA: 2,
YB: "yb",
YC: true,
YD: 4,
},
XD: true,
XE: 5,
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
filteredMap := map[string]interface{}{
"XA": v.XA,
"XB": v.XB,
"XC": map[string]interface{}{
"YA": v.XC.YA,
"YB": v.XC.YB,
},
}
if _, err := gojson.Marshal(filteredMap); err != nil {
b.Fatal(err)
}
}
}
func Benchmark_Encode_FilterByFieldQuery(b *testing.B) {
query, err := gojson.BuildFieldQuery(
"XA",
"XB",
gojson.BuildSubFieldQuery("XC").Fields(
"YA",
"YB",
),
)
if err != nil {
b.Fatal(err)
}
v := &queryTestX{
XA: 1,
XB: "xb",
XC: &queryTestY{
YA: 2,
YB: "yb",
YC: true,
YD: 4,
},
XD: true,
XE: 5,
}
ctx := gojson.SetFieldQueryToContext(context.Background(), query)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := gojson.MarshalContext(ctx, v); err != nil {
b.Fatal(err)
}
}
}

View File

@ -221,7 +221,7 @@ func encode(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error) {
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(typeptr) codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -249,7 +249,7 @@ func encodeNoEscape(ctx *encoder.RuntimeContext, v interface{}) ([]byte, error)
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(typeptr) codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -276,7 +276,7 @@ func encodeIndent(ctx *encoder.RuntimeContext, v interface{}, prefix, indent str
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := encoder.CompileToGetCodeSet(typeptr) codeSet, err := encoder.CompileToGetCodeSet(ctx, typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -10,6 +10,7 @@ import (
type Code interface { type Code interface {
Kind() CodeKind Kind() CodeKind
ToOpcode(*compileContext) Opcodes ToOpcode(*compileContext) Opcodes
Filter(*FieldQuery) Code
} }
type AnonymousCode interface { type AnonymousCode interface {
@ -82,6 +83,10 @@ func (c *IntCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *IntCode) Filter(_ *FieldQuery) Code {
return c
}
type UintCode struct { type UintCode struct {
typ *runtime.Type typ *runtime.Type
bitSize uint8 bitSize uint8
@ -108,6 +113,10 @@ func (c *UintCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *UintCode) Filter(_ *FieldQuery) Code {
return c
}
type FloatCode struct { type FloatCode struct {
typ *runtime.Type typ *runtime.Type
bitSize uint8 bitSize uint8
@ -140,6 +149,10 @@ func (c *FloatCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *FloatCode) Filter(_ *FieldQuery) Code {
return c
}
type StringCode struct { type StringCode struct {
typ *runtime.Type typ *runtime.Type
isPtr bool isPtr bool
@ -169,6 +182,10 @@ func (c *StringCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *StringCode) Filter(_ *FieldQuery) Code {
return c
}
type BoolCode struct { type BoolCode struct {
typ *runtime.Type typ *runtime.Type
isPtr bool isPtr bool
@ -190,6 +207,10 @@ func (c *BoolCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *BoolCode) Filter(_ *FieldQuery) Code {
return c
}
type BytesCode struct { type BytesCode struct {
typ *runtime.Type typ *runtime.Type
isPtr bool isPtr bool
@ -211,6 +232,10 @@ func (c *BytesCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *BytesCode) Filter(_ *FieldQuery) Code {
return c
}
type SliceCode struct { type SliceCode struct {
typ *runtime.Type typ *runtime.Type
value Code value Code
@ -245,6 +270,10 @@ func (c *SliceCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{header}.Add(codes...).Add(elemCode).Add(end) return Opcodes{header}.Add(codes...).Add(elemCode).Add(end)
} }
func (c *SliceCode) Filter(_ *FieldQuery) Code {
return c
}
type ArrayCode struct { type ArrayCode struct {
typ *runtime.Type typ *runtime.Type
value Code value Code
@ -286,6 +315,10 @@ func (c *ArrayCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{header}.Add(codes...).Add(elemCode).Add(end) return Opcodes{header}.Add(codes...).Add(elemCode).Add(end)
} }
func (c *ArrayCode) Filter(_ *FieldQuery) Code {
return c
}
type MapCode struct { type MapCode struct {
typ *runtime.Type typ *runtime.Type
key Code key Code
@ -332,6 +365,10 @@ func (c *MapCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{header}.Add(keyCodes...).Add(value).Add(valueCodes...).Add(key).Add(end) return Opcodes{header}.Add(keyCodes...).Add(value).Add(valueCodes...).Add(key).Add(end)
} }
func (c *MapCode) Filter(_ *FieldQuery) Code {
return c
}
type StructCode struct { type StructCode struct {
typ *runtime.Type typ *runtime.Type
fields []*StructFieldCode fields []*StructFieldCode
@ -520,6 +557,45 @@ func (c *StructCode) enableIndirect() {
structCode.enableIndirect() structCode.enableIndirect()
} }
func (c *StructCode) Filter(query *FieldQuery) Code {
fieldMap := map[string]*FieldQuery{}
for _, field := range query.Fields {
fieldMap[field.Name] = field
}
fields := make([]*StructFieldCode, 0, len(c.fields))
for _, field := range c.fields {
query, exists := fieldMap[field.key]
if !exists {
continue
}
fieldCode := &StructFieldCode{
typ: field.typ,
key: field.key,
tag: field.tag,
value: field.value,
offset: field.offset,
isAnonymous: field.isAnonymous,
isTaggedKey: field.isTaggedKey,
isNilableType: field.isNilableType,
isNilCheck: field.isNilCheck,
isAddrForMarshaler: field.isAddrForMarshaler,
isNextOpPtrType: field.isNextOpPtrType,
}
if len(query.Fields) > 0 {
fieldCode.value = fieldCode.value.Filter(query)
}
fields = append(fields, fieldCode)
}
return &StructCode{
typ: c.typ,
fields: fields,
isPtr: c.isPtr,
disableIndirectConversion: c.disableIndirectConversion,
isIndirect: c.isIndirect,
isRecursive: c.isRecursive,
}
}
type StructFieldCode struct { type StructFieldCode struct {
typ *runtime.Type typ *runtime.Type
key string key string
@ -532,6 +608,7 @@ type StructFieldCode struct {
isNilCheck bool isNilCheck bool
isAddrForMarshaler bool isAddrForMarshaler bool
isNextOpPtrType bool isNextOpPtrType bool
isMarshalerContext bool
} }
func (c *StructFieldCode) getStruct() *StructCode { func (c *StructFieldCode) getStruct() *StructCode {
@ -574,8 +651,12 @@ func (c *StructFieldCode) headerOpcodes(ctx *compileContext, field *Opcode, valu
value := valueCodes.First() value := valueCodes.First()
op := optimizeStructHeader(value, c.tag) op := optimizeStructHeader(value, c.tag)
field.Op = op field.Op = op
if value.Flags&MarshalerContextFlags != 0 {
field.Flags |= MarshalerContextFlags
}
field.NumBitSize = value.NumBitSize field.NumBitSize = value.NumBitSize
field.PtrNum = value.PtrNum field.PtrNum = value.PtrNum
field.FieldQuery = value.FieldQuery
fieldCodes := Opcodes{field} fieldCodes := Opcodes{field}
if op.IsMultipleOpHead() { if op.IsMultipleOpHead() {
field.Next = value field.Next = value
@ -590,8 +671,12 @@ func (c *StructFieldCode) fieldOpcodes(ctx *compileContext, field *Opcode, value
value := valueCodes.First() value := valueCodes.First()
op := optimizeStructField(value, c.tag) op := optimizeStructField(value, c.tag)
field.Op = op field.Op = op
if value.Flags&MarshalerContextFlags != 0 {
field.Flags |= MarshalerContextFlags
}
field.NumBitSize = value.NumBitSize field.NumBitSize = value.NumBitSize
field.PtrNum = value.PtrNum field.PtrNum = value.PtrNum
field.FieldQuery = value.FieldQuery
fieldCodes := Opcodes{field} fieldCodes := Opcodes{field}
if op.IsMultipleOpField() { if op.IsMultipleOpField() {
@ -645,6 +730,9 @@ func (c *StructFieldCode) flags() OpFlags {
if c.isAnonymous { if c.isAnonymous {
flags |= AnonymousKeyFlags flags |= AnonymousKeyFlags
} }
if c.isMarshalerContext {
flags |= MarshalerContextFlags
}
return flags return flags
} }
@ -726,6 +814,7 @@ func isEnableStructEndOptimization(value Code) bool {
type InterfaceCode struct { type InterfaceCode struct {
typ *runtime.Type typ *runtime.Type
fieldQuery *FieldQuery
isPtr bool isPtr bool
} }
@ -741,6 +830,7 @@ func (c *InterfaceCode) ToOpcode(ctx *compileContext) Opcodes {
default: default:
code = newOpCode(ctx, c.typ, OpInterface) code = newOpCode(ctx, c.typ, OpInterface)
} }
code.FieldQuery = c.fieldQuery
if c.typ.NumMethod() > 0 { if c.typ.NumMethod() > 0 {
code.Flags |= NonEmptyInterfaceFlags code.Flags |= NonEmptyInterfaceFlags
} }
@ -748,8 +838,17 @@ func (c *InterfaceCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *InterfaceCode) Filter(query *FieldQuery) Code {
return &InterfaceCode{
typ: c.typ,
fieldQuery: query,
isPtr: c.isPtr,
}
}
type MarshalJSONCode struct { type MarshalJSONCode struct {
typ *runtime.Type typ *runtime.Type
fieldQuery *FieldQuery
isAddrForMarshaler bool isAddrForMarshaler bool
isNilableType bool isNilableType bool
isMarshalerContext bool isMarshalerContext bool
@ -761,6 +860,7 @@ func (c *MarshalJSONCode) Kind() CodeKind {
func (c *MarshalJSONCode) ToOpcode(ctx *compileContext) Opcodes { func (c *MarshalJSONCode) ToOpcode(ctx *compileContext) Opcodes {
code := newOpCode(ctx, c.typ, OpMarshalJSON) code := newOpCode(ctx, c.typ, OpMarshalJSON)
code.FieldQuery = c.fieldQuery
if c.isAddrForMarshaler { if c.isAddrForMarshaler {
code.Flags |= AddrForMarshalerFlags code.Flags |= AddrForMarshalerFlags
} }
@ -776,8 +876,19 @@ func (c *MarshalJSONCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *MarshalJSONCode) Filter(query *FieldQuery) Code {
return &MarshalJSONCode{
typ: c.typ,
fieldQuery: query,
isAddrForMarshaler: c.isAddrForMarshaler,
isNilableType: c.isNilableType,
isMarshalerContext: c.isMarshalerContext,
}
}
type MarshalTextCode struct { type MarshalTextCode struct {
typ *runtime.Type typ *runtime.Type
fieldQuery *FieldQuery
isAddrForMarshaler bool isAddrForMarshaler bool
isNilableType bool isNilableType bool
} }
@ -788,6 +899,7 @@ func (c *MarshalTextCode) Kind() CodeKind {
func (c *MarshalTextCode) ToOpcode(ctx *compileContext) Opcodes { func (c *MarshalTextCode) ToOpcode(ctx *compileContext) Opcodes {
code := newOpCode(ctx, c.typ, OpMarshalText) code := newOpCode(ctx, c.typ, OpMarshalText)
code.FieldQuery = c.fieldQuery
if c.isAddrForMarshaler { if c.isAddrForMarshaler {
code.Flags |= AddrForMarshalerFlags code.Flags |= AddrForMarshalerFlags
} }
@ -800,6 +912,15 @@ func (c *MarshalTextCode) ToOpcode(ctx *compileContext) Opcodes {
return Opcodes{code} return Opcodes{code}
} }
func (c *MarshalTextCode) Filter(query *FieldQuery) Code {
return &MarshalTextCode{
typ: c.typ,
fieldQuery: query,
isAddrForMarshaler: c.isAddrForMarshaler,
isNilableType: c.isNilableType,
}
}
type PtrCode struct { type PtrCode struct {
typ *runtime.Type typ *runtime.Type
value Code value Code
@ -830,6 +951,14 @@ func (c *PtrCode) ToAnonymousOpcode(ctx *compileContext) Opcodes {
return codes return codes
} }
func (c *PtrCode) Filter(query *FieldQuery) Code {
return &PtrCode{
typ: c.typ,
value: c.value.Filter(query),
ptrNum: c.ptrNum,
}
}
func convertPtrOp(code *Opcode) OpType { func convertPtrOp(code *Opcode) OpType {
ptrHeadOp := code.Op.HeadToPtrHead() ptrHeadOp := code.Op.HeadToPtrHead()
if code.Op != ptrHeadOp { if code.Op != ptrHeadOp {

View File

@ -63,6 +63,27 @@ func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
return codeSet, nil return codeSet, nil
} }
func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) {
if (ctx.Option.Flag & ContextOption) == 0 {
return codeSet, nil
}
query := FieldQueryFromContext(ctx.Option.Context)
if query == nil {
return codeSet, nil
}
ctx.Option.Flag |= FieldQueryOption
cacheCodeSet := codeSet.getQueryCache(query.Hash())
if cacheCodeSet != nil {
return cacheCodeSet, nil
}
queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query))
if err != nil {
return nil, err
}
codeSet.setQueryCache(query.Hash(), queryCodeSet)
return queryCodeSet, nil
}
type Compiler struct { type Compiler struct {
structTypeToCode map[uintptr]*StructCode structTypeToCode map[uintptr]*StructCode
} }
@ -80,6 +101,10 @@ func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.codeToOpcodeSet(typ, code)
}
func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) {
noescapeKeyCode := c.codeToOpcode(&compileContext{ noescapeKeyCode := c.codeToOpcode(&compileContext{
structTypeToCodes: map[uintptr]Opcodes{}, structTypeToCodes: map[uintptr]Opcodes{},
recursiveCodes: &Opcodes{}, recursiveCodes: &Opcodes{},
@ -107,6 +132,8 @@ func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
InterfaceEscapeKeyCode: interfaceEscapeKeyCode, InterfaceEscapeKeyCode: interfaceEscapeKeyCode,
CodeLength: codeLength, CodeLength: codeLength,
EndCode: ToEndCode(interfaceNoescapeKeyCode), EndCode: ToEndCode(interfaceNoescapeKeyCode),
Code: code,
QueryCache: map[string]*OpcodeSet{},
}, nil }, nil
} }

View File

@ -3,18 +3,26 @@
package encoder package encoder
func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) { func CompileToGetCodeSet(ctx *RuntimeContext, typeptr uintptr) (*OpcodeSet, error) {
if typeptr > typeAddr.MaxTypeAddr { if typeptr > typeAddr.MaxTypeAddr {
return compileToGetCodeSetSlowPath(typeptr) return compileToGetCodeSetSlowPath(typeptr)
} }
index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift
if codeSet := cachedOpcodeSets[index]; codeSet != nil { if codeSet := cachedOpcodeSets[index]; codeSet != nil {
return codeSet, nil filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet)
if err != nil {
return nil, err
}
return filtered, nil
} }
codeSet, err := newCompiler().compile(typeptr) codeSet, err := newCompiler().compile(typeptr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet)
if err != nil {
return nil, err
}
cachedOpcodeSets[index] = codeSet cachedOpcodeSets[index] = codeSet
return codeSet, nil return filtered, nil
} }

View File

@ -9,15 +9,20 @@ import (
var setsMu sync.RWMutex var setsMu sync.RWMutex
func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) { func CompileToGetCodeSet(ctx *RuntimeContext, typeptr uintptr) (*OpcodeSet, error) {
if typeptr > typeAddr.MaxTypeAddr { if typeptr > typeAddr.MaxTypeAddr {
return compileToGetCodeSetSlowPath(typeptr) return compileToGetCodeSetSlowPath(typeptr)
} }
index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift index := (typeptr - typeAddr.BaseTypeAddr) >> typeAddr.AddrShift
setsMu.RLock() setsMu.RLock()
if codeSet := cachedOpcodeSets[index]; codeSet != nil { if codeSet := cachedOpcodeSets[index]; codeSet != nil {
filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet)
if err != nil {
setsMu.RUnlock() setsMu.RUnlock()
return codeSet, nil return nil, err
}
setsMu.RUnlock()
return filtered, nil
} }
setsMu.RUnlock() setsMu.RUnlock()
@ -25,8 +30,12 @@ func CompileToGetCodeSet(typeptr uintptr) (*OpcodeSet, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
filtered, err := getFilteredCodeSetIfNeeded(ctx, codeSet)
if err != nil {
return nil, err
}
setsMu.Lock() setsMu.Lock()
cachedOpcodeSets[index] = codeSet cachedOpcodeSets[index] = codeSet
setsMu.Unlock() setsMu.Unlock()
return codeSet, nil return filtered, nil
} }

View File

@ -6,11 +6,13 @@ import (
) )
func TestDumpOpcode(t *testing.T) { func TestDumpOpcode(t *testing.T) {
ctx := TakeRuntimeContext()
defer ReleaseRuntimeContext(ctx)
var v interface{} = 1 var v interface{} = 1
header := (*emptyInterface)(unsafe.Pointer(&v)) header := (*emptyInterface)(unsafe.Pointer(&v))
typ := header.typ typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := CompileToGetCodeSet(typeptr) codeSet, err := CompileToGetCodeSet(ctx, typeptr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -101,6 +101,22 @@ type OpcodeSet struct {
InterfaceEscapeKeyCode *Opcode InterfaceEscapeKeyCode *Opcode
CodeLength int CodeLength int
EndCode *Opcode EndCode *Opcode
Code Code
QueryCache map[string]*OpcodeSet
cacheMu sync.RWMutex
}
func (s *OpcodeSet) getQueryCache(hash string) *OpcodeSet {
s.cacheMu.RLock()
codeSet := s.QueryCache[hash]
s.cacheMu.RUnlock()
return codeSet
}
func (s *OpcodeSet) setQueryCache(hash string, codeSet *OpcodeSet) {
s.cacheMu.Lock()
s.QueryCache[hash] = codeSet
s.cacheMu.Unlock()
} }
type CompiledCode struct { type CompiledCode struct {
@ -395,7 +411,11 @@ func AppendMarshalJSON(ctx *RuntimeContext, code *Opcode, b []byte, v interface{
if !ok { if !ok {
return AppendNull(ctx, b), nil return AppendNull(ctx, b), nil
} }
b, err := marshaler.MarshalJSON(ctx.Option.Context) stdctx := ctx.Option.Context
if ctx.Option.Flag&FieldQueryOption != 0 {
stdctx = SetFieldQueryToContext(stdctx, code.FieldQuery)
}
b, err := marshaler.MarshalJSON(stdctx)
if err != nil { if err != nil {
return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err} return nil, &errors.MarshalerError{Type: reflect.TypeOf(v), Err: err}
} }

View File

@ -39,6 +39,7 @@ type Opcode struct {
Type *runtime.Type // go type Type *runtime.Type // go type
Jmp *CompiledCode // for recursive call Jmp *CompiledCode // for recursive call
FieldQuery *FieldQuery // field query for Interface / MarshalJSON / MarshalText
ElemIdx uint32 // offset to access array/slice elem ElemIdx uint32 // offset to access array/slice elem
Length uint32 // offset to access slice length or array length Length uint32 // offset to access slice length or array length
Indent uint32 // indent number Indent uint32 // indent number
@ -333,6 +334,7 @@ func copyOpcode(code *Opcode) *Opcode {
Idx: c.Idx, Idx: c.Idx,
Offset: c.Offset, Offset: c.Offset,
Type: c.Type, Type: c.Type,
FieldQuery: c.FieldQuery,
DisplayIdx: c.DisplayIdx, DisplayIdx: c.DisplayIdx,
DisplayKey: c.DisplayKey, DisplayKey: c.DisplayKey,
ElemIdx: c.ElemIdx, ElemIdx: c.ElemIdx,

View File

@ -12,6 +12,7 @@ const (
ColorizeOption ColorizeOption
ContextOption ContextOption
NormalizeUTF8Option NormalizeUTF8Option
FieldQueryOption
) )
type Option struct { type Option struct {

135
internal/encoder/query.go Normal file
View File

@ -0,0 +1,135 @@
package encoder
import (
"context"
"fmt"
"reflect"
)
var (
Marshal func(interface{}) ([]byte, error)
Unmarshal func([]byte, interface{}) error
)
type FieldQuery struct {
Name string
Fields []*FieldQuery
hash string
}
func (q *FieldQuery) Hash() string {
if q.hash != "" {
return q.hash
}
b, _ := Marshal(q)
q.hash = string(b)
return q.hash
}
func (q *FieldQuery) MarshalJSON() ([]byte, error) {
if q.Name != "" {
if len(q.Fields) > 0 {
return Marshal(map[string][]*FieldQuery{q.Name: q.Fields})
}
return Marshal(q.Name)
}
return Marshal(q.Fields)
}
func (q *FieldQuery) QueryString() (FieldQueryString, error) {
b, err := Marshal(q)
if err != nil {
return "", err
}
return FieldQueryString(b), nil
}
type FieldQueryString string
func (s FieldQueryString) Build() (*FieldQuery, error) {
var query interface{}
if err := Unmarshal([]byte(s), &query); err != nil {
return nil, err
}
return s.build(reflect.ValueOf(query))
}
func (s FieldQueryString) build(v reflect.Value) (*FieldQuery, error) {
switch v.Type().Kind() {
case reflect.String:
return s.buildString(v)
case reflect.Map:
return s.buildMap(v)
case reflect.Slice:
return s.buildSlice(v)
case reflect.Interface:
return s.build(reflect.ValueOf(v.Interface()))
}
return nil, fmt.Errorf("failed to build field query")
}
func (s FieldQueryString) buildString(v reflect.Value) (*FieldQuery, error) {
b := []byte(v.String())
switch b[0] {
case '[', '{':
var query interface{}
if err := Unmarshal(b, &query); err != nil {
return nil, err
}
if str, ok := query.(string); ok {
return &FieldQuery{Name: str}, nil
}
return s.build(reflect.ValueOf(query))
}
return &FieldQuery{Name: string(b)}, nil
}
func (s FieldQueryString) buildSlice(v reflect.Value) (*FieldQuery, error) {
fields := make([]*FieldQuery, 0, v.Len())
for i := 0; i < v.Len(); i++ {
def, err := s.build(v.Index(i))
if err != nil {
return nil, err
}
fields = append(fields, def)
}
return &FieldQuery{Fields: fields}, nil
}
func (s FieldQueryString) buildMap(v reflect.Value) (*FieldQuery, error) {
keys := v.MapKeys()
if len(keys) != 1 {
return nil, fmt.Errorf("failed to build field query object")
}
key := keys[0]
if key.Type().Kind() != reflect.String {
return nil, fmt.Errorf("failed to build field query. invalid object key type")
}
name := key.String()
def, err := s.build(v.MapIndex(key))
if err != nil {
return nil, err
}
return &FieldQuery{
Name: name,
Fields: def.Fields,
}, nil
}
type queryKey struct{}
func FieldQueryFromContext(ctx context.Context) *FieldQuery {
query := ctx.Value(queryKey{})
if query == nil {
return nil
}
q, ok := query.(*FieldQuery)
if !ok {
return nil
}
return q
}
func SetFieldQueryToContext(ctx context.Context, query *FieldQuery) context.Context {
return context.WithValue(ctx, queryKey{}, query)
}

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -199,7 +199,7 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet) ([]b
break break
} }
ctx.KeepRefs = append(ctx.KeepRefs, up) ctx.KeepRefs = append(ctx.KeepRefs, up)
ifaceCodeSet, err := encoder.CompileToGetCodeSet(uintptr(unsafe.Pointer(typ))) ifaceCodeSet, err := encoder.CompileToGetCodeSet(ctx, uintptr(unsafe.Pointer(typ)))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -364,3 +364,8 @@ func Valid(data []byte) bool {
} }
return decoder.InputOffset() >= int64(len(data)) return decoder.InputOffset() >= int64(len(data))
} }
func init() {
encoder.Marshal = Marshal
encoder.Unmarshal = Unmarshal
}

47
query.go Normal file
View File

@ -0,0 +1,47 @@
package json
import (
"github.com/goccy/go-json/internal/encoder"
)
type (
// FieldQuery you can dynamically filter the fields in the structure by creating a FieldQuery,
// adding it to context.Context using SetFieldQueryToContext and then passing it to MarshalContext.
// This is a type-safe operation, so it is faster than filtering using map[string]interface{}.
FieldQuery = encoder.FieldQuery
FieldQueryString = encoder.FieldQueryString
)
var (
// FieldQueryFromContext get current FieldQuery from context.Context.
FieldQueryFromContext = encoder.FieldQueryFromContext
// SetFieldQueryToContext set current FieldQuery to context.Context.
SetFieldQueryToContext = encoder.SetFieldQueryToContext
)
// BuildFieldQuery builds FieldQuery by fieldName or sub field query.
// First, specify the field name that you want to keep in structure type.
// If the field you want to keep is a structure type, by creating a sub field query using BuildSubFieldQuery,
// you can select the fields you want to keep in the structure.
// This description can be written recursively.
func BuildFieldQuery(fields ...FieldQueryString) (*FieldQuery, error) {
query, err := Marshal(fields)
if err != nil {
return nil, err
}
return FieldQueryString(query).Build()
}
// BuildSubFieldQuery builds sub field query.
func BuildSubFieldQuery(name string) *SubFieldQuery {
return &SubFieldQuery{name: name}
}
type SubFieldQuery struct {
name string
}
func (q *SubFieldQuery) Fields(fields ...FieldQueryString) FieldQueryString {
query, _ := Marshal(map[string][]FieldQueryString{q.name: fields})
return FieldQueryString(query)
}

121
query_test.go Normal file
View File

@ -0,0 +1,121 @@
package json_test
import (
"context"
"reflect"
"testing"
"github.com/goccy/go-json"
)
type queryTestX struct {
XA int
XB string
XC *queryTestY
XD bool
XE float32
}
type queryTestY struct {
YA int
YB string
YC *queryTestZ
YD bool
YE float32
}
type queryTestZ struct {
ZA string
ZB bool
ZC int
}
func (z *queryTestZ) MarshalJSON(ctx context.Context) ([]byte, error) {
type _queryTestZ queryTestZ
return json.MarshalContext(ctx, (*_queryTestZ)(z))
}
func TestFieldQuery(t *testing.T) {
query, err := json.BuildFieldQuery(
"XA",
"XB",
json.BuildSubFieldQuery("XC").Fields(
"YA",
"YB",
json.BuildSubFieldQuery("YC").Fields(
"ZA",
"ZB",
),
),
)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(query, &json.FieldQuery{
Fields: []*json.FieldQuery{
{
Name: "XA",
},
{
Name: "XB",
},
{
Name: "XC",
Fields: []*json.FieldQuery{
{
Name: "YA",
},
{
Name: "YB",
},
{
Name: "YC",
Fields: []*json.FieldQuery{
{
Name: "ZA",
},
{
Name: "ZB",
},
},
},
},
},
},
}) {
t.Fatal("cannot get query")
}
queryStr, err := query.QueryString()
if err != nil {
t.Fatal(err)
}
if queryStr != `["XA","XB",{"XC":["YA","YB",{"YC":["ZA","ZB"]}]}]` {
t.Fatalf("failed to create query string. %s", queryStr)
}
ctx := json.SetFieldQueryToContext(context.Background(), query)
b, err := json.MarshalContext(ctx, &queryTestX{
XA: 1,
XB: "xb",
XC: &queryTestY{
YA: 2,
YB: "yb",
YC: &queryTestZ{
ZA: "za",
ZB: true,
ZC: 3,
},
YD: true,
YE: 4,
},
XD: true,
XE: 5,
})
if err != nil {
t.Fatal(err)
}
expected := `{"XA":1,"XB":"xb","XC":{"YA":2,"YB":"yb","YC":{"ZA":"za","ZB":true}}}`
got := string(b)
if expected != got {
t.Fatalf("failed to encode with field query: expected %q but got %q", expected, got)
}
}

View File

@ -11,7 +11,7 @@ func TestOpcodeSize(t *testing.T) {
const uintptrSize = 4 << (^uintptr(0) >> 63) const uintptrSize = 4 << (^uintptr(0) >> 63)
if uintptrSize == 8 { if uintptrSize == 8 {
size := unsafe.Sizeof(encoder.Opcode{}) size := unsafe.Sizeof(encoder.Opcode{})
if size != 112 { if size != 120 {
t.Fatalf("unexpected opcode size: expected 112bytes but got %dbytes", size) t.Fatalf("unexpected opcode size: expected 112bytes but got %dbytes", size)
} }
} }

View File

@ -0,0 +1,96 @@
package json_test
import (
"context"
"fmt"
"log"
"github.com/goccy/go-json"
)
type User struct {
ID int64
Name string
Age int
Address UserAddressResolver
}
type UserAddress struct {
UserID int64
PostCode string
City string
Address1 string
Address2 string
}
type UserRepository struct {
uaRepo *UserAddressRepository
}
func NewUserRepository() *UserRepository {
return &UserRepository{
uaRepo: NewUserAddressRepository(),
}
}
type UserAddressRepository struct{}
func NewUserAddressRepository() *UserAddressRepository {
return &UserAddressRepository{}
}
type UserAddressResolver func(context.Context) (*UserAddress, error)
func (resolver UserAddressResolver) MarshalJSON(ctx context.Context) ([]byte, error) {
address, err := resolver(ctx)
if err != nil {
return nil, err
}
return json.MarshalContext(ctx, address)
}
func (r *UserRepository) FindByID(ctx context.Context, id int64) (*User, error) {
user := &User{ID: id, Name: "Ken", Age: 20}
// resolve relation from User to UserAddress
user.Address = func(ctx context.Context) (*UserAddress, error) {
return r.uaRepo.FindByUserID(ctx, user.ID)
}
return user, nil
}
func (*UserAddressRepository) FindByUserID(ctx context.Context, id int64) (*UserAddress, error) {
return &UserAddress{
UserID: id,
City: "A",
Address1: "foo",
Address2: "bar",
}, nil
}
func Example_fieldQuery() {
ctx := context.Background()
userRepo := NewUserRepository()
user, err := userRepo.FindByID(ctx, 1)
if err != nil {
log.Fatal(err)
}
query, err := json.BuildFieldQuery(
"Name",
"Age",
json.BuildSubFieldQuery("Address").Fields(
"City",
),
)
if err != nil {
log.Fatal(err)
}
ctx = json.SetFieldQueryToContext(ctx, query)
b, err := json.MarshalContext(ctx, user)
if err != nil {
log.Fatal(err)
}
fmt.Println(string(b))
// Output:
// {"Name":"Ken","Age":20,"Address":{"City":"A"}}
}