Merge pull request #110 from goccy/feature/fix-marshal-json

Fix encoding of type which implemented MarshalJSON
This commit is contained in:
Masaaki Goshima 2021-02-01 22:40:17 +09:00 committed by GitHub
commit 0a5e990b17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 128 additions and 62 deletions

View File

@ -22,11 +22,12 @@ type opcodeSet struct {
} }
var ( var (
marshalJSONType = reflect.TypeOf((*Marshaler)(nil)).Elem() marshalJSONType = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
cachedOpcode unsafe.Pointer // map[uintptr]*opcodeSet cachedOpcode unsafe.Pointer // map[uintptr]*opcodeSet
baseTypeAddr uintptr baseTypeAddr uintptr
cachedOpcodeSets []*opcodeSet cachedOpcodeSets []*opcodeSet
existsCachedOpcodeSets bool
) )
const ( const (
@ -80,6 +81,7 @@ func setupOpcodeSets() error {
return fmt.Errorf("too big address range %d", addrRange) return fmt.Errorf("too big address range %d", addrRange)
} }
cachedOpcodeSets = make([]*opcodeSet, addrRange) cachedOpcodeSets = make([]*opcodeSet, addrRange)
existsCachedOpcodeSets = true
baseTypeAddr = min baseTypeAddr = min
return nil return nil
} }
@ -90,35 +92,6 @@ func init() {
} }
} }
func encodeCompileToGetCodeSet(typeptr uintptr) (*opcodeSet, error) {
if cachedOpcodeSets == nil {
return encodeCompileToGetCodeSetSlowPath(typeptr)
}
if codeSet := cachedOpcodeSets[typeptr-baseTypeAddr]; codeSet != nil {
return codeSet, nil
}
// noescape trick for header.typ ( reflect.*rtype )
copiedType := *(**rtype)(unsafe.Pointer(&typeptr))
code, err := encodeCompileHead(&encodeCompileContext{
typ: copiedType,
root: true,
structTypeToCompiledCode: map[uintptr]*compiledCode{},
})
if err != nil {
return nil, err
}
code = copyOpcode(code)
codeLength := code.totalLength()
codeSet := &opcodeSet{
code: code,
codeLength: codeLength,
}
cachedOpcodeSets[int(typeptr-baseTypeAddr)] = codeSet
return codeSet, nil
}
func encodeCompileToGetCodeSetSlowPath(typeptr uintptr) (*opcodeSet, error) { func encodeCompileToGetCodeSetSlowPath(typeptr uintptr) (*opcodeSet, error) {
opcodeMap := loadOpcodeMap() opcodeMap := loadOpcodeMap()
if codeSet, exists := opcodeMap[typeptr]; exists { if codeSet, exists := opcodeMap[typeptr]; exists {

34
encode_compile_norace.go Normal file
View File

@ -0,0 +1,34 @@
// +build !race
package json
import "unsafe"
func encodeCompileToGetCodeSet(typeptr uintptr) (*opcodeSet, error) {
if !existsCachedOpcodeSets {
return encodeCompileToGetCodeSetSlowPath(typeptr)
}
if codeSet := cachedOpcodeSets[typeptr-baseTypeAddr]; codeSet != nil {
return codeSet, nil
}
// noescape trick for header.typ ( reflect.*rtype )
copiedType := *(**rtype)(unsafe.Pointer(&typeptr))
code, err := encodeCompileHead(&encodeCompileContext{
typ: copiedType,
root: true,
structTypeToCompiledCode: map[uintptr]*compiledCode{},
})
if err != nil {
return nil, err
}
code = copyOpcode(code)
codeLength := code.totalLength()
codeSet := &opcodeSet{
code: code,
codeLength: codeLength,
}
cachedOpcodeSets[int(typeptr-baseTypeAddr)] = codeSet
return codeSet, nil
}

44
encode_compile_race.go Normal file
View File

@ -0,0 +1,44 @@
// +build race
package json
import (
"sync"
"unsafe"
)
var setsMu sync.RWMutex
func encodeCompileToGetCodeSet(typeptr uintptr) (*opcodeSet, error) {
if !existsCachedOpcodeSets {
return encodeCompileToGetCodeSetSlowPath(typeptr)
}
setsMu.RLock()
if codeSet := cachedOpcodeSets[typeptr-baseTypeAddr]; codeSet != nil {
setsMu.RUnlock()
return codeSet, nil
}
setsMu.RUnlock()
// noescape trick for header.typ ( reflect.*rtype )
copiedType := *(**rtype)(unsafe.Pointer(&typeptr))
code, err := encodeCompileHead(&encodeCompileContext{
typ: copiedType,
root: true,
structTypeToCompiledCode: map[uintptr]*compiledCode{},
})
if err != nil {
return nil, err
}
code = copyOpcode(code)
codeLength := code.totalLength()
codeSet := &opcodeSet{
code: code,
codeLength: codeLength,
}
setsMu.Lock()
cachedOpcodeSets[int(typeptr-baseTypeAddr)] = codeSet
setsMu.Unlock()
return codeSet, nil
}

View File

@ -7801,6 +7801,15 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o
ptr := load(ctxptr, code.headIdx) ptr := load(ctxptr, code.headIdx)
b = append(b, code.escapedKey...) b = append(b, code.escapedKey...)
p := ptr + code.offset p := ptr + code.offset
if code.typ.Kind() == reflect.Ptr {
p = ptrToPtr(p)
}
if p == 0 {
b = encodeNull(b)
b = encodeComma(b)
code = code.next
break
}
v := ptrToInterface(code, p) v := ptrToInterface(code, p)
bb, err := v.(Marshaler).MarshalJSON() bb, err := v.(Marshaler).MarshalJSON()
if err != nil { if err != nil {
@ -7816,23 +7825,25 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o
case opStructFieldOmitEmptyMarshalJSON: case opStructFieldOmitEmptyMarshalJSON:
ptr := load(ctxptr, code.headIdx) ptr := load(ctxptr, code.headIdx)
p := ptr + code.offset p := ptr + code.offset
if code.typ.Kind() == reflect.Ptr && code.typ.Elem().Implements(marshalJSONType) { if code.typ.Kind() == reflect.Ptr {
p = ptrToPtr(p) p = ptrToPtr(p)
} }
v := ptrToInterface(code, p) if p == 0 {
if v != nil && p != 0 { code = code.next
bb, err := v.(Marshaler).MarshalJSON() break
if err != nil {
return nil, errMarshaler(code, err)
}
b = append(b, code.escapedKey...)
buf := bytes.NewBuffer(b)
if err := compact(buf, bb, true); err != nil {
return nil, err
}
b = buf.Bytes()
b = encodeComma(b)
} }
v := ptrToInterface(code, p)
bb, err := v.(Marshaler).MarshalJSON()
if err != nil {
return nil, errMarshaler(code, err)
}
b = append(b, code.escapedKey...)
buf := bytes.NewBuffer(b)
if err := compact(buf, bb, true); err != nil {
return nil, err
}
b = buf.Bytes()
b = encodeComma(b)
code = code.next code = code.next
case opStructFieldStringTagMarshalJSON: case opStructFieldStringTagMarshalJSON:
ptr := load(ctxptr, code.headIdx) ptr := load(ctxptr, code.headIdx)
@ -9256,20 +9267,10 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o
case opStructEndOmitEmptyMarshalJSON: case opStructEndOmitEmptyMarshalJSON:
ptr := load(ctxptr, code.headIdx) ptr := load(ctxptr, code.headIdx)
p := ptr + code.offset p := ptr + code.offset
v := ptrToInterface(code, p) if code.typ.Kind() == reflect.Ptr {
if v != nil && (code.typ.Kind() != reflect.Ptr || ptrToPtr(p) != 0) { p = ptrToPtr(p)
bb, err := v.(Marshaler).MarshalJSON() }
if err != nil { if p == 0 {
return nil, errMarshaler(code, err)
}
b = append(b, code.escapedKey...)
buf := bytes.NewBuffer(b)
if err := compact(buf, bb, true); err != nil {
return nil, err
}
b = buf.Bytes()
b = appendStructEnd(b)
} else {
last := len(b) - 1 last := len(b) - 1
if b[last] == ',' { if b[last] == ',' {
b[last] = '}' b[last] = '}'
@ -9277,7 +9278,21 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o
} else { } else {
b = appendStructEnd(b) b = appendStructEnd(b)
} }
code = code.next
break
} }
v := ptrToInterface(code, p)
bb, err := v.(Marshaler).MarshalJSON()
if err != nil {
return nil, errMarshaler(code, err)
}
b = append(b, code.escapedKey...)
buf := bytes.NewBuffer(b)
if err := compact(buf, bb, true); err != nil {
return nil, err
}
b = buf.Bytes()
b = appendStructEnd(b)
code = code.next code = code.next
case opStructEndStringTagMarshalJSON: case opStructEndStringTagMarshalJSON:
ptr := load(ctxptr, code.headIdx) ptr := load(ctxptr, code.headIdx)