From 1258224a26ea0c67c8df2ecafe30edae671a399a Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sun, 24 Jan 2021 03:02:26 +0900 Subject: [PATCH] Fix interface and map operation --- encode.go | 55 ++++++++-------- encode_context.go | 58 +++++++++++++++++ encode_vm.go | 3 +- encode_vm_escaped.go | 122 +++++++++++++++--------------------- encode_vm_escaped_indent.go | 3 +- encode_vm_indent.go | 3 +- 6 files changed, 139 insertions(+), 105 deletions(-) diff --git a/encode.go b/encode.go index 50cb56a..cf4cdaa 100644 --- a/encode.go +++ b/encode.go @@ -112,6 +112,7 @@ func (e *Encoder) EncodeWithOption(v interface{}, opts ...EncodeOption) error { } } header := (*interfaceHeader)(unsafe.Pointer(&v)) + e.ptr = header.ptr buf, err := e.encode(header, v == nil) if err != nil { return err @@ -192,23 +193,31 @@ func (e *Encoder) encode(header *interfaceHeader, isNil bool) ([]byte, error) { typ := header.typ typeptr := uintptr(unsafe.Pointer(typ)) + codeSet, err := e.compileToGetCodeSet(typeptr) + if err != nil { + return nil, err + } + + ctx := e.ctx + p := uintptr(header.ptr) + ctx.init(p, codeSet.codeLength) + if e.enabledIndent { + if e.enabledHTMLEscape { + return e.runEscapedIndent(ctx, b, codeSet) + } else { + return e.runIndent(ctx, b, codeSet) + } + } + if e.enabledHTMLEscape { + return e.runEscaped(ctx, b, codeSet) + } + return e.run(ctx, b, codeSet) +} + +func (e *Encoder) compileToGetCodeSet(typeptr uintptr) (*opcodeSet, error) { opcodeMap := loadOpcodeMap() if codeSet, exists := opcodeMap[typeptr]; exists { - ctx := e.ctx - p := uintptr(header.ptr) - ctx.init(p, codeSet.codeLength) - - if e.enabledIndent { - if e.enabledHTMLEscape { - return e.runEscapedIndent(ctx, b, codeSet.code) - } else { - return e.runIndent(ctx, b, codeSet.code) - } - } - if e.enabledHTMLEscape { - return e.runEscaped(ctx, b, codeSet.code) - } - return e.run(ctx, b, codeSet.code) + return codeSet, nil } // noescape trick for header.typ ( reflect.*rtype ) @@ -230,21 +239,7 @@ func (e *Encoder) encode(header *interfaceHeader, isNil bool) ([]byte, error) { } storeOpcodeSet(typeptr, codeSet, opcodeMap) - p := uintptr(header.ptr) - ctx := e.ctx - ctx.init(p, codeLength) - - if e.enabledIndent { - if e.enabledHTMLEscape { - return e.runEscapedIndent(ctx, b, codeSet.code) - } else { - return e.runIndent(ctx, b, codeSet.code) - } - } - if e.enabledHTMLEscape { - return e.runEscaped(ctx, b, codeSet.code) - } - return e.run(ctx, b, codeSet.code) + return codeSet, nil } func encodeFloat32(b []byte, v float32) []byte { diff --git a/encode_context.go b/encode_context.go index 0713f90..67f68d1 100644 --- a/encode_context.go +++ b/encode_context.go @@ -1,9 +1,67 @@ package json import ( + "bytes" + "sync" "unsafe" ) +type mapItem struct { + key []byte + value []byte +} + +type mapslice struct { + items []mapItem +} + +func (m *mapslice) Len() int { + return len(m.items) +} + +func (m *mapslice) Less(i, j int) bool { + return bytes.Compare(m.items[i].key, m.items[j].key) < 0 +} + +func (m *mapslice) Swap(i, j int) { + m.items[i], m.items[j] = m.items[j], m.items[i] +} + +type encodeMapContext struct { + iter unsafe.Pointer + pos []int + slice *mapslice + buf []byte +} + +var mapContextPool = sync.Pool{ + New: func() interface{} { + return &encodeMapContext{} + }, +} + +func newMapContext(mapLen int) *encodeMapContext { + ctx := mapContextPool.Get().(*encodeMapContext) + if ctx.slice == nil { + ctx.slice = &mapslice{ + items: make([]mapItem, 0, mapLen), + } + } + if cap(ctx.pos) < mapLen*2 { + ctx.pos = make([]int, 0, mapLen*2) + ctx.slice.items = make([]mapItem, 0, mapLen) + } else { + ctx.pos = ctx.pos[:0] + ctx.slice.items = ctx.slice.items[:0] + } + ctx.buf = ctx.buf[:0] + return ctx +} + +func releaseMapContext(c *encodeMapContext) { + mapContextPool.Put(c) +} + type encodeCompileContext struct { typ *rtype root bool diff --git a/encode_vm.go b/encode_vm.go index 14a755c..fe05b32 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -49,10 +49,11 @@ func errMarshaler(code *opcode, err error) *MarshalerError { } } -func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) { +func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet) ([]byte, error) { recursiveLevel := 0 ptrOffset := uintptr(0) ctxptr := ctx.ptr() + code := codeSet.code for { switch code.op { diff --git a/encode_vm_escaped.go b/encode_vm_escaped.go index 619c21a..2492c8f 100644 --- a/encode_vm_escaped.go +++ b/encode_vm_escaped.go @@ -11,11 +11,13 @@ import ( "unsafe" ) -func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) { +func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet) ([]byte, error) { recursiveLevel := 0 ptrOffset := uintptr(0) ctxptr := ctx.ptr() + code := codeSet.code + for { switch code.op { default: @@ -167,63 +169,51 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) return nil, errUnsupportedValue(code, ptr) } } - ctx.seenPtr = append(ctx.seenPtr, ptr) - v := e.ptrToInterface(code, ptr) - ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v)) - rv := reflect.ValueOf(v) - if rv.IsNil() { + iface := (*interfaceHeader)(e.ptrToUnsafePtr(ptr)) + ctx.keepRefs = append(ctx.keepRefs, iface.ptr) + if iface.ptr == nil { b = encodeNull(b) b = encodeComma(b) code = code.next break } - vv := rv.Interface() - header := (*interfaceHeader)(unsafe.Pointer(&vv)) - if header.typ.Kind() == reflect.Ptr { - if rv.Elem().IsNil() { - b = encodeNull(b) - b = encodeComma(b) - code = code.next - break - } - } - c, err := e.compileHead(&encodeCompileContext{ - typ: header.typ, - root: code.root, - indent: code.indent, - structTypeToCompiledCode: map[uintptr]*compiledCode{}, - }) + ifaceCodeSet, err := e.compileToGetCodeSet(uintptr(unsafe.Pointer(iface.typ))) if err != nil { return nil, err } - beforeLastCode := c.beforeLastCode() - lastCode := beforeLastCode.next - lastCode.idx = beforeLastCode.idx + uintptrSize - totalLength := uintptr(code.totalLength()) - nextTotalLength := uintptr(c.totalLength()) + + totalLength := uintptr(codeSet.codeLength) + nextTotalLength := uintptr(ifaceCodeSet.codeLength) + curlen := uintptr(len(ctx.ptrs)) offsetNum := ptrOffset / uintptrSize - oldOffset := ptrOffset ptrOffset += totalLength * uintptrSize newLen := offsetNum + totalLength + nextTotalLength if curlen < newLen { ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...) } - ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr - store(ctxptr, 0, uintptr(header.ptr)) - store(ctxptr, lastCode.idx, oldOffset) + newPtrs := ctx.ptrs[ptrOffset/uintptrSize:] + newPtrs[0] = uintptr(iface.ptr) - // link lastCode ( opInterfaceEnd ) => code.next - lastCode.op = opInterfaceEnd - lastCode.next = code.next + oldPtrs := ctx.ptrs + ctx.ptrs = newPtrs + ctx.seenPtr = append(ctx.seenPtr, uintptr(ptr)) + bb, err := e.runEscaped(ctx, b, ifaceCodeSet) + if err != nil { + return nil, err + } - code = c - recursiveLevel++ + ctx.ptrs = oldPtrs + ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1] + + b = bb + code = code.next case opInterfaceEnd: recursiveLevel-- // restore ctxptr + ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1] offset := load(ctxptr, code.idx) ctxptr = ctx.ptr() + offset ptrOffset = offset @@ -358,11 +348,9 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) store(ctxptr, code.length, uintptr(mlen)) store(ctxptr, code.mapIter, uintptr(iter)) if !e.unorderedMap { - pos := make([]int, 0, mlen) - pos = append(pos, len(b)) - posPtr := unsafe.Pointer(&pos) - ctx.keepRefs = append(ctx.keepRefs, posPtr) - store(ctxptr, code.end.mapPos, uintptr(posPtr)) + mapCtx := newMapContext(mlen) + mapCtx.pos = append(mapCtx.pos, len(b)) + store(ctxptr, code.end.mapPos, uintptr(unsafe.Pointer(mapCtx))) } key := mapiterkey(iter) store(ctxptr, code.next.idx, uintptr(key)) @@ -399,11 +387,9 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) key := mapiterkey(iter) store(ctxptr, code.next.idx, uintptr(key)) if !e.unorderedMap { - pos := make([]int, 0, mlen) - pos = append(pos, len(b)) - posPtr := unsafe.Pointer(&pos) - ctx.keepRefs = append(ctx.keepRefs, posPtr) - store(ctxptr, code.end.mapPos, uintptr(posPtr)) + mapCtx := newMapContext(mlen) + mapCtx.pos = append(mapCtx.pos, len(b)) + store(ctxptr, code.end.mapPos, uintptr(unsafe.Pointer(mapCtx))) } code = code.next } else { @@ -431,8 +417,8 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) } } else { ptr := load(ctxptr, code.end.mapPos) - posPtr := (*[]int)(*(*unsafe.Pointer)(unsafe.Pointer(&ptr))) - *posPtr = append(*posPtr, len(b)) + mapCtx := (*encodeMapContext)(e.ptrToUnsafePtr(ptr)) + mapCtx.pos = append(mapCtx.pos, len(b)) if idx < length { ptr := load(ctxptr, code.mapIter) iter := e.ptrToUnsafePtr(ptr) @@ -450,8 +436,8 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) b[last] = ':' } else { ptr := load(ctxptr, code.end.mapPos) - posPtr := (*[]int)(*(*unsafe.Pointer)(unsafe.Pointer(&ptr))) - *posPtr = append(*posPtr, len(b)) + mapCtx := (*encodeMapContext)(e.ptrToUnsafePtr(ptr)) + mapCtx.pos = append(mapCtx.pos, len(b)) } ptr := load(ctxptr, code.mapIter) iter := e.ptrToUnsafePtr(ptr) @@ -462,14 +448,9 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) case opMapEnd: // this operation only used by sorted map. length := int(load(ctxptr, code.length)) - type mapKV struct { - key string - value string - } - kvs := make([]mapKV, 0, length) ptr := load(ctxptr, code.mapPos) - posPtr := e.ptrToUnsafePtr(ptr) - pos := *(*[]int)(posPtr) + mapCtx := (*encodeMapContext)(e.ptrToUnsafePtr(ptr)) + pos := mapCtx.pos for i := 0; i < length; i++ { startKey := pos[i*2] startValue := pos[i*2+1] @@ -479,25 +460,22 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) } else { endValue = len(b) } - kvs = append(kvs, mapKV{ - key: string(b[startKey:startValue]), - value: string(b[startValue:endValue]), + mapCtx.slice.items = append(mapCtx.slice.items, mapItem{ + key: b[startKey:startValue], + value: b[startValue:endValue], }) } - sort.Slice(kvs, func(i, j int) bool { - return kvs[i].key < kvs[j].key - }) - buf := b[pos[0]:] - buf = buf[:0] - for _, kv := range kvs { - buf = append(buf, []byte(kv.key)...) - buf[len(buf)-1] = ':' - buf = append(buf, []byte(kv.value)...) + sort.Sort(mapCtx.slice) + for _, item := range mapCtx.slice.items { + mapCtx.buf = append(mapCtx.buf, item.key...) + mapCtx.buf[len(mapCtx.buf)-1] = ':' + mapCtx.buf = append(mapCtx.buf, item.value...) } - buf[len(buf)-1] = '}' - buf = append(buf, ',') + mapCtx.buf[len(mapCtx.buf)-1] = '}' + mapCtx.buf = append(mapCtx.buf, ',') b = b[:pos[0]] - b = append(b, buf...) + b = append(b, mapCtx.buf...) + releaseMapContext(mapCtx) code = code.next case opStructFieldPtrAnonymousHeadRecursive: store(ctxptr, code.idx, e.ptrToPtr(load(ctxptr, code.idx))) diff --git a/encode_vm_escaped_indent.go b/encode_vm_escaped_indent.go index 319e812..2cfd745 100644 --- a/encode_vm_escaped_indent.go +++ b/encode_vm_escaped_indent.go @@ -11,11 +11,12 @@ import ( "unsafe" ) -func (e *Encoder) runEscapedIndent(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) { +func (e *Encoder) runEscapedIndent(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet) ([]byte, error) { recursiveLevel := 0 var seenPtr map[uintptr]struct{} ptrOffset := uintptr(0) ctxptr := ctx.ptr() + code := codeSet.code for { switch code.op { diff --git a/encode_vm_indent.go b/encode_vm_indent.go index 536e962..78a6379 100644 --- a/encode_vm_indent.go +++ b/encode_vm_indent.go @@ -11,11 +11,12 @@ import ( "unsafe" ) -func (e *Encoder) runIndent(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) { +func (e *Encoder) runIndent(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet) ([]byte, error) { recursiveLevel := 0 var seenPtr map[uintptr]struct{} ptrOffset := uintptr(0) ctxptr := ctx.ptr() + code := codeSet.code for { switch code.op {