Fix interface and map operation

This commit is contained in:
Masaaki Goshima 2021-01-24 03:02:26 +09:00
parent 86ae7d931a
commit 1258224a26
6 changed files with 139 additions and 105 deletions

View File

@ -112,6 +112,7 @@ func (e *Encoder) EncodeWithOption(v interface{}, opts ...EncodeOption) error {
}
}
header := (*interfaceHeader)(unsafe.Pointer(&v))
e.ptr = header.ptr
buf, err := e.encode(header, v == nil)
if err != nil {
return err
@ -192,23 +193,31 @@ func (e *Encoder) encode(header *interfaceHeader, isNil bool) ([]byte, error) {
typ := header.typ
typeptr := uintptr(unsafe.Pointer(typ))
codeSet, err := e.compileToGetCodeSet(typeptr)
if err != nil {
return nil, err
}
ctx := e.ctx
p := uintptr(header.ptr)
ctx.init(p, codeSet.codeLength)
if e.enabledIndent {
if e.enabledHTMLEscape {
return e.runEscapedIndent(ctx, b, codeSet)
} else {
return e.runIndent(ctx, b, codeSet)
}
}
if e.enabledHTMLEscape {
return e.runEscaped(ctx, b, codeSet)
}
return e.run(ctx, b, codeSet)
}
func (e *Encoder) compileToGetCodeSet(typeptr uintptr) (*opcodeSet, error) {
opcodeMap := loadOpcodeMap()
if codeSet, exists := opcodeMap[typeptr]; exists {
ctx := e.ctx
p := uintptr(header.ptr)
ctx.init(p, codeSet.codeLength)
if e.enabledIndent {
if e.enabledHTMLEscape {
return e.runEscapedIndent(ctx, b, codeSet.code)
} else {
return e.runIndent(ctx, b, codeSet.code)
}
}
if e.enabledHTMLEscape {
return e.runEscaped(ctx, b, codeSet.code)
}
return e.run(ctx, b, codeSet.code)
return codeSet, nil
}
// noescape trick for header.typ ( reflect.*rtype )
@ -230,21 +239,7 @@ func (e *Encoder) encode(header *interfaceHeader, isNil bool) ([]byte, error) {
}
storeOpcodeSet(typeptr, codeSet, opcodeMap)
p := uintptr(header.ptr)
ctx := e.ctx
ctx.init(p, codeLength)
if e.enabledIndent {
if e.enabledHTMLEscape {
return e.runEscapedIndent(ctx, b, codeSet.code)
} else {
return e.runIndent(ctx, b, codeSet.code)
}
}
if e.enabledHTMLEscape {
return e.runEscaped(ctx, b, codeSet.code)
}
return e.run(ctx, b, codeSet.code)
return codeSet, nil
}
func encodeFloat32(b []byte, v float32) []byte {

View File

@ -1,9 +1,67 @@
package json
import (
"bytes"
"sync"
"unsafe"
)
type mapItem struct {
key []byte
value []byte
}
type mapslice struct {
items []mapItem
}
func (m *mapslice) Len() int {
return len(m.items)
}
func (m *mapslice) Less(i, j int) bool {
return bytes.Compare(m.items[i].key, m.items[j].key) < 0
}
func (m *mapslice) Swap(i, j int) {
m.items[i], m.items[j] = m.items[j], m.items[i]
}
type encodeMapContext struct {
iter unsafe.Pointer
pos []int
slice *mapslice
buf []byte
}
var mapContextPool = sync.Pool{
New: func() interface{} {
return &encodeMapContext{}
},
}
func newMapContext(mapLen int) *encodeMapContext {
ctx := mapContextPool.Get().(*encodeMapContext)
if ctx.slice == nil {
ctx.slice = &mapslice{
items: make([]mapItem, 0, mapLen),
}
}
if cap(ctx.pos) < mapLen*2 {
ctx.pos = make([]int, 0, mapLen*2)
ctx.slice.items = make([]mapItem, 0, mapLen)
} else {
ctx.pos = ctx.pos[:0]
ctx.slice.items = ctx.slice.items[:0]
}
ctx.buf = ctx.buf[:0]
return ctx
}
func releaseMapContext(c *encodeMapContext) {
mapContextPool.Put(c)
}
type encodeCompileContext struct {
typ *rtype
root bool

View File

@ -49,10 +49,11 @@ func errMarshaler(code *opcode, err error) *MarshalerError {
}
}
func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet) ([]byte, error) {
recursiveLevel := 0
ptrOffset := uintptr(0)
ctxptr := ctx.ptr()
code := codeSet.code
for {
switch code.op {

View File

@ -11,11 +11,13 @@ import (
"unsafe"
)
func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet) ([]byte, error) {
recursiveLevel := 0
ptrOffset := uintptr(0)
ctxptr := ctx.ptr()
code := codeSet.code
for {
switch code.op {
default:
@ -167,63 +169,51 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
return nil, errUnsupportedValue(code, ptr)
}
}
ctx.seenPtr = append(ctx.seenPtr, ptr)
v := e.ptrToInterface(code, ptr)
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
rv := reflect.ValueOf(v)
if rv.IsNil() {
iface := (*interfaceHeader)(e.ptrToUnsafePtr(ptr))
ctx.keepRefs = append(ctx.keepRefs, iface.ptr)
if iface.ptr == nil {
b = encodeNull(b)
b = encodeComma(b)
code = code.next
break
}
vv := rv.Interface()
header := (*interfaceHeader)(unsafe.Pointer(&vv))
if header.typ.Kind() == reflect.Ptr {
if rv.Elem().IsNil() {
b = encodeNull(b)
b = encodeComma(b)
code = code.next
break
}
}
c, err := e.compileHead(&encodeCompileContext{
typ: header.typ,
root: code.root,
indent: code.indent,
structTypeToCompiledCode: map[uintptr]*compiledCode{},
})
ifaceCodeSet, err := e.compileToGetCodeSet(uintptr(unsafe.Pointer(iface.typ)))
if err != nil {
return nil, err
}
beforeLastCode := c.beforeLastCode()
lastCode := beforeLastCode.next
lastCode.idx = beforeLastCode.idx + uintptrSize
totalLength := uintptr(code.totalLength())
nextTotalLength := uintptr(c.totalLength())
totalLength := uintptr(codeSet.codeLength)
nextTotalLength := uintptr(ifaceCodeSet.codeLength)
curlen := uintptr(len(ctx.ptrs))
offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset
ptrOffset += totalLength * uintptrSize
newLen := offsetNum + totalLength + nextTotalLength
if curlen < newLen {
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
}
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
store(ctxptr, 0, uintptr(header.ptr))
store(ctxptr, lastCode.idx, oldOffset)
newPtrs := ctx.ptrs[ptrOffset/uintptrSize:]
newPtrs[0] = uintptr(iface.ptr)
// link lastCode ( opInterfaceEnd ) => code.next
lastCode.op = opInterfaceEnd
lastCode.next = code.next
oldPtrs := ctx.ptrs
ctx.ptrs = newPtrs
ctx.seenPtr = append(ctx.seenPtr, uintptr(ptr))
bb, err := e.runEscaped(ctx, b, ifaceCodeSet)
if err != nil {
return nil, err
}
code = c
recursiveLevel++
ctx.ptrs = oldPtrs
ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
b = bb
code = code.next
case opInterfaceEnd:
recursiveLevel--
// restore ctxptr
ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
offset := load(ctxptr, code.idx)
ctxptr = ctx.ptr() + offset
ptrOffset = offset
@ -358,11 +348,9 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
store(ctxptr, code.length, uintptr(mlen))
store(ctxptr, code.mapIter, uintptr(iter))
if !e.unorderedMap {
pos := make([]int, 0, mlen)
pos = append(pos, len(b))
posPtr := unsafe.Pointer(&pos)
ctx.keepRefs = append(ctx.keepRefs, posPtr)
store(ctxptr, code.end.mapPos, uintptr(posPtr))
mapCtx := newMapContext(mlen)
mapCtx.pos = append(mapCtx.pos, len(b))
store(ctxptr, code.end.mapPos, uintptr(unsafe.Pointer(mapCtx)))
}
key := mapiterkey(iter)
store(ctxptr, code.next.idx, uintptr(key))
@ -399,11 +387,9 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
key := mapiterkey(iter)
store(ctxptr, code.next.idx, uintptr(key))
if !e.unorderedMap {
pos := make([]int, 0, mlen)
pos = append(pos, len(b))
posPtr := unsafe.Pointer(&pos)
ctx.keepRefs = append(ctx.keepRefs, posPtr)
store(ctxptr, code.end.mapPos, uintptr(posPtr))
mapCtx := newMapContext(mlen)
mapCtx.pos = append(mapCtx.pos, len(b))
store(ctxptr, code.end.mapPos, uintptr(unsafe.Pointer(mapCtx)))
}
code = code.next
} else {
@ -431,8 +417,8 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
}
} else {
ptr := load(ctxptr, code.end.mapPos)
posPtr := (*[]int)(*(*unsafe.Pointer)(unsafe.Pointer(&ptr)))
*posPtr = append(*posPtr, len(b))
mapCtx := (*encodeMapContext)(e.ptrToUnsafePtr(ptr))
mapCtx.pos = append(mapCtx.pos, len(b))
if idx < length {
ptr := load(ctxptr, code.mapIter)
iter := e.ptrToUnsafePtr(ptr)
@ -450,8 +436,8 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
b[last] = ':'
} else {
ptr := load(ctxptr, code.end.mapPos)
posPtr := (*[]int)(*(*unsafe.Pointer)(unsafe.Pointer(&ptr)))
*posPtr = append(*posPtr, len(b))
mapCtx := (*encodeMapContext)(e.ptrToUnsafePtr(ptr))
mapCtx.pos = append(mapCtx.pos, len(b))
}
ptr := load(ctxptr, code.mapIter)
iter := e.ptrToUnsafePtr(ptr)
@ -462,14 +448,9 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
case opMapEnd:
// this operation only used by sorted map.
length := int(load(ctxptr, code.length))
type mapKV struct {
key string
value string
}
kvs := make([]mapKV, 0, length)
ptr := load(ctxptr, code.mapPos)
posPtr := e.ptrToUnsafePtr(ptr)
pos := *(*[]int)(posPtr)
mapCtx := (*encodeMapContext)(e.ptrToUnsafePtr(ptr))
pos := mapCtx.pos
for i := 0; i < length; i++ {
startKey := pos[i*2]
startValue := pos[i*2+1]
@ -479,25 +460,22 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
} else {
endValue = len(b)
}
kvs = append(kvs, mapKV{
key: string(b[startKey:startValue]),
value: string(b[startValue:endValue]),
mapCtx.slice.items = append(mapCtx.slice.items, mapItem{
key: b[startKey:startValue],
value: b[startValue:endValue],
})
}
sort.Slice(kvs, func(i, j int) bool {
return kvs[i].key < kvs[j].key
})
buf := b[pos[0]:]
buf = buf[:0]
for _, kv := range kvs {
buf = append(buf, []byte(kv.key)...)
buf[len(buf)-1] = ':'
buf = append(buf, []byte(kv.value)...)
sort.Sort(mapCtx.slice)
for _, item := range mapCtx.slice.items {
mapCtx.buf = append(mapCtx.buf, item.key...)
mapCtx.buf[len(mapCtx.buf)-1] = ':'
mapCtx.buf = append(mapCtx.buf, item.value...)
}
buf[len(buf)-1] = '}'
buf = append(buf, ',')
mapCtx.buf[len(mapCtx.buf)-1] = '}'
mapCtx.buf = append(mapCtx.buf, ',')
b = b[:pos[0]]
b = append(b, buf...)
b = append(b, mapCtx.buf...)
releaseMapContext(mapCtx)
code = code.next
case opStructFieldPtrAnonymousHeadRecursive:
store(ctxptr, code.idx, e.ptrToPtr(load(ctxptr, code.idx)))

View File

@ -11,11 +11,12 @@ import (
"unsafe"
)
func (e *Encoder) runEscapedIndent(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
func (e *Encoder) runEscapedIndent(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet) ([]byte, error) {
recursiveLevel := 0
var seenPtr map[uintptr]struct{}
ptrOffset := uintptr(0)
ctxptr := ctx.ptr()
code := codeSet.code
for {
switch code.op {

View File

@ -11,11 +11,12 @@ import (
"unsafe"
)
func (e *Encoder) runIndent(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
func (e *Encoder) runIndent(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet) ([]byte, error) {
recursiveLevel := 0
var seenPtr map[uintptr]struct{}
ptrOffset := uintptr(0)
ctxptr := ctx.ptr()
code := codeSet.code
for {
switch code.op {