mirror of https://github.com/goccy/go-json.git
Fix recursive call
This commit is contained in:
parent
66ced55701
commit
ddfae9189e
|
@ -28,6 +28,9 @@ type Encoder struct {
|
||||||
|
|
||||||
type compiledCode struct {
|
type compiledCode struct {
|
||||||
code *opcode
|
code *opcode
|
||||||
|
linked bool // whether recursive code already have linked
|
||||||
|
curLen uintptr
|
||||||
|
nextLen uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -34,6 +34,7 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) {
|
||||||
}
|
}
|
||||||
e.convertHeadOnlyCode(code, isPtr)
|
e.convertHeadOnlyCode(code, isPtr)
|
||||||
e.optimizeStructEnd(code)
|
e.optimizeStructEnd(code)
|
||||||
|
e.linkRecursiveCode(code)
|
||||||
return code, nil
|
return code, nil
|
||||||
} else if isPtr && typ.Implements(marshalTextType) {
|
} else if isPtr && typ.Implements(marshalTextType) {
|
||||||
typ = orgType
|
typ = orgType
|
||||||
|
@ -46,9 +47,54 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) {
|
||||||
}
|
}
|
||||||
e.convertHeadOnlyCode(code, isPtr)
|
e.convertHeadOnlyCode(code, isPtr)
|
||||||
e.optimizeStructEnd(code)
|
e.optimizeStructEnd(code)
|
||||||
|
e.linkRecursiveCode(code)
|
||||||
return code, nil
|
return code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Encoder) linkRecursiveCode(c *opcode) {
|
||||||
|
for code := c; code.op != opEnd && code.op != opStructFieldRecursiveEnd; {
|
||||||
|
switch code.op {
|
||||||
|
case opStructFieldRecursive,
|
||||||
|
opStructFieldPtrAnonymousHeadRecursive,
|
||||||
|
opStructFieldAnonymousHeadRecursive:
|
||||||
|
if code.jmp.linked {
|
||||||
|
code = code.next
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
code.jmp.code = copyOpcode(code.jmp.code)
|
||||||
|
c := code.jmp.code
|
||||||
|
c.end.next = newEndOp(&encodeCompileContext{})
|
||||||
|
c.op = c.op.ptrHeadToHead()
|
||||||
|
|
||||||
|
beforeLastCode := c.end
|
||||||
|
lastCode := beforeLastCode.next
|
||||||
|
|
||||||
|
lastCode.idx = beforeLastCode.idx + uintptrSize
|
||||||
|
lastCode.elemIdx = lastCode.idx + uintptrSize
|
||||||
|
|
||||||
|
// extend length to alloc slot for elemIdx
|
||||||
|
totalLength := uintptr(code.totalLength() + 1)
|
||||||
|
nextTotalLength := uintptr(c.totalLength() + 1)
|
||||||
|
|
||||||
|
c.end.next.op = opStructFieldRecursiveEnd
|
||||||
|
|
||||||
|
code.jmp.curLen = totalLength
|
||||||
|
code.jmp.nextLen = nextTotalLength
|
||||||
|
code.jmp.linked = true
|
||||||
|
|
||||||
|
e.linkRecursiveCode(code.jmp.code)
|
||||||
|
code = code.next
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch code.op.codeType() {
|
||||||
|
case codeArrayElem, codeSliceElem, codeMapKey:
|
||||||
|
code = code.end
|
||||||
|
default:
|
||||||
|
code = code.next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Encoder) optimizeStructEnd(c *opcode) {
|
func (e *Encoder) optimizeStructEnd(c *opcode) {
|
||||||
for code := c; code.op != opEnd; {
|
for code := c; code.op != opEnd; {
|
||||||
if code.op == opStructFieldRecursive {
|
if code.op == opStructFieldRecursive {
|
||||||
|
|
|
@ -86,6 +86,7 @@ func (c *encodeCompileContext) decPtrIndex() {
|
||||||
type encodeRuntimeContext struct {
|
type encodeRuntimeContext struct {
|
||||||
ptrs []uintptr
|
ptrs []uintptr
|
||||||
keepRefs []unsafe.Pointer
|
keepRefs []unsafe.Pointer
|
||||||
|
seenPtr []uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
|
func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
|
||||||
|
@ -94,6 +95,7 @@ func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
|
||||||
}
|
}
|
||||||
c.ptrs[0] = p
|
c.ptrs[0] = p
|
||||||
c.keepRefs = c.keepRefs[:0]
|
c.keepRefs = c.keepRefs[:0]
|
||||||
|
c.seenPtr = c.seenPtr[:0]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *encodeRuntimeContext) ptr() uintptr {
|
func (c *encodeRuntimeContext) ptr() uintptr {
|
||||||
|
|
47
encode_vm.go
47
encode_vm.go
|
@ -50,7 +50,6 @@ func errMarshaler(code *opcode, err error) *MarshalerError {
|
||||||
|
|
||||||
func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
|
func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
|
||||||
recursiveLevel := 0
|
recursiveLevel := 0
|
||||||
var seenPtr map[uintptr]struct{}
|
|
||||||
ptrOffset := uintptr(0)
|
ptrOffset := uintptr(0)
|
||||||
ctxptr := ctx.ptr()
|
ctxptr := ctx.ptr()
|
||||||
|
|
||||||
|
@ -200,13 +199,12 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
|
||||||
code = code.next
|
code = code.next
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if seenPtr == nil {
|
for _, seen := range ctx.seenPtr {
|
||||||
seenPtr = map[uintptr]struct{}{}
|
if ptr == seen {
|
||||||
}
|
|
||||||
if _, exists := seenPtr[ptr]; exists {
|
|
||||||
return nil, errUnsupportedValue(code, ptr)
|
return nil, errUnsupportedValue(code, ptr)
|
||||||
}
|
}
|
||||||
seenPtr[ptr] = struct{}{}
|
}
|
||||||
|
ctx.seenPtr = append(ctx.seenPtr, ptr)
|
||||||
v := e.ptrToInterface(code, ptr)
|
v := e.ptrToInterface(code, ptr)
|
||||||
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
|
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
|
@ -540,46 +538,29 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
|
||||||
ptr := load(ctxptr, code.idx)
|
ptr := load(ctxptr, code.idx)
|
||||||
if ptr != 0 {
|
if ptr != 0 {
|
||||||
if recursiveLevel > startDetectingCyclesAfter {
|
if recursiveLevel > startDetectingCyclesAfter {
|
||||||
if _, exists := seenPtr[ptr]; exists {
|
for _, seen := range ctx.seenPtr {
|
||||||
|
if ptr == seen {
|
||||||
return nil, errUnsupportedValue(code, ptr)
|
return nil, errUnsupportedValue(code, ptr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if seenPtr == nil {
|
|
||||||
seenPtr = map[uintptr]struct{}{}
|
|
||||||
}
|
}
|
||||||
seenPtr[ptr] = struct{}{}
|
ctx.seenPtr = append(ctx.seenPtr, ptr)
|
||||||
c := code.jmp.code
|
c := code.jmp.code
|
||||||
c.end.next = newEndOp(&encodeCompileContext{})
|
|
||||||
c.op = c.op.ptrHeadToHead()
|
|
||||||
|
|
||||||
beforeLastCode := c.end
|
|
||||||
lastCode := beforeLastCode.next
|
|
||||||
|
|
||||||
lastCode.idx = beforeLastCode.idx + uintptrSize
|
|
||||||
lastCode.elemIdx = lastCode.idx + uintptrSize
|
|
||||||
|
|
||||||
// extend length to alloc slot for elemIdx
|
|
||||||
totalLength := uintptr(code.totalLength() + 1)
|
|
||||||
nextTotalLength := uintptr(c.totalLength() + 1)
|
|
||||||
|
|
||||||
curlen := uintptr(len(ctx.ptrs))
|
curlen := uintptr(len(ctx.ptrs))
|
||||||
offsetNum := ptrOffset / uintptrSize
|
offsetNum := ptrOffset / uintptrSize
|
||||||
oldOffset := ptrOffset
|
oldOffset := ptrOffset
|
||||||
ptrOffset += totalLength * uintptrSize
|
ptrOffset += code.jmp.curLen * uintptrSize
|
||||||
|
|
||||||
newLen := offsetNum + totalLength + nextTotalLength
|
newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen
|
||||||
if curlen < newLen {
|
if curlen < newLen {
|
||||||
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
|
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
|
||||||
}
|
}
|
||||||
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
|
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
|
||||||
|
|
||||||
store(ctxptr, c.idx, ptr)
|
store(ctxptr, c.idx, ptr)
|
||||||
store(ctxptr, lastCode.idx, oldOffset)
|
store(ctxptr, c.end.next.idx, oldOffset)
|
||||||
store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next)))
|
store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next)))
|
||||||
|
|
||||||
// link lastCode ( opStructFieldRecursiveEnd ) => code.next
|
|
||||||
lastCode.op = opStructFieldRecursiveEnd
|
|
||||||
code = c
|
code = c
|
||||||
recursiveLevel++
|
recursiveLevel++
|
||||||
case opStructFieldRecursiveEnd:
|
case opStructFieldRecursiveEnd:
|
||||||
|
@ -587,8 +568,10 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
|
||||||
|
|
||||||
// restore ctxptr
|
// restore ctxptr
|
||||||
offset := load(ctxptr, code.idx)
|
offset := load(ctxptr, code.idx)
|
||||||
ptr := load(ctxptr, code.elemIdx)
|
ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
|
||||||
code = (*opcode)(e.ptrToUnsafePtr(ptr))
|
|
||||||
|
codePtr := load(ctxptr, code.elemIdx)
|
||||||
|
code = (*opcode)(e.ptrToUnsafePtr(codePtr))
|
||||||
ctxptr = ctx.ptr() + offset
|
ctxptr = ctx.ptr() + offset
|
||||||
ptrOffset = offset
|
ptrOffset = offset
|
||||||
case opStructFieldPtrHead:
|
case opStructFieldPtrHead:
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
|
|
||||||
func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
|
func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
|
||||||
recursiveLevel := 0
|
recursiveLevel := 0
|
||||||
var seenPtr map[uintptr]struct{}
|
|
||||||
ptrOffset := uintptr(0)
|
ptrOffset := uintptr(0)
|
||||||
ctxptr := ctx.ptr()
|
ctxptr := ctx.ptr()
|
||||||
|
|
||||||
|
@ -162,13 +161,12 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
|
||||||
code = code.next
|
code = code.next
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if seenPtr == nil {
|
for _, seen := range ctx.seenPtr {
|
||||||
seenPtr = map[uintptr]struct{}{}
|
if ptr == seen {
|
||||||
}
|
|
||||||
if _, exists := seenPtr[ptr]; exists {
|
|
||||||
return nil, errUnsupportedValue(code, ptr)
|
return nil, errUnsupportedValue(code, ptr)
|
||||||
}
|
}
|
||||||
seenPtr[ptr] = struct{}{}
|
}
|
||||||
|
ctx.seenPtr = append(ctx.seenPtr, ptr)
|
||||||
v := e.ptrToInterface(code, ptr)
|
v := e.ptrToInterface(code, ptr)
|
||||||
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
|
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
|
@ -502,46 +500,29 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
|
||||||
ptr := load(ctxptr, code.idx)
|
ptr := load(ctxptr, code.idx)
|
||||||
if ptr != 0 {
|
if ptr != 0 {
|
||||||
if recursiveLevel > startDetectingCyclesAfter {
|
if recursiveLevel > startDetectingCyclesAfter {
|
||||||
if _, exists := seenPtr[ptr]; exists {
|
for _, seen := range ctx.seenPtr {
|
||||||
|
if ptr == seen {
|
||||||
return nil, errUnsupportedValue(code, ptr)
|
return nil, errUnsupportedValue(code, ptr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if seenPtr == nil {
|
|
||||||
seenPtr = map[uintptr]struct{}{}
|
|
||||||
}
|
}
|
||||||
seenPtr[ptr] = struct{}{}
|
ctx.seenPtr = append(ctx.seenPtr, ptr)
|
||||||
c := code.jmp.code
|
c := code.jmp.code
|
||||||
c.end.next = newEndOp(&encodeCompileContext{})
|
|
||||||
c.op = c.op.ptrHeadToHead()
|
|
||||||
|
|
||||||
beforeLastCode := c.end
|
|
||||||
lastCode := beforeLastCode.next
|
|
||||||
|
|
||||||
lastCode.idx = beforeLastCode.idx + uintptrSize
|
|
||||||
lastCode.elemIdx = lastCode.idx + uintptrSize
|
|
||||||
|
|
||||||
// extend length to alloc slot for elemIdx
|
|
||||||
totalLength := uintptr(code.totalLength() + 1)
|
|
||||||
nextTotalLength := uintptr(c.totalLength() + 1)
|
|
||||||
|
|
||||||
curlen := uintptr(len(ctx.ptrs))
|
curlen := uintptr(len(ctx.ptrs))
|
||||||
offsetNum := ptrOffset / uintptrSize
|
offsetNum := ptrOffset / uintptrSize
|
||||||
oldOffset := ptrOffset
|
oldOffset := ptrOffset
|
||||||
ptrOffset += totalLength * uintptrSize
|
ptrOffset += code.jmp.curLen * uintptrSize
|
||||||
|
|
||||||
newLen := offsetNum + totalLength + nextTotalLength
|
newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen
|
||||||
if curlen < newLen {
|
if curlen < newLen {
|
||||||
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
|
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
|
||||||
}
|
}
|
||||||
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
|
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
|
||||||
|
|
||||||
store(ctxptr, c.idx, ptr)
|
store(ctxptr, c.idx, ptr)
|
||||||
store(ctxptr, lastCode.idx, oldOffset)
|
store(ctxptr, c.end.next.idx, oldOffset)
|
||||||
store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next)))
|
store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next)))
|
||||||
|
|
||||||
// link lastCode ( opStructFieldRecursiveEnd ) => code.next
|
|
||||||
lastCode.op = opStructFieldRecursiveEnd
|
|
||||||
code = c
|
code = c
|
||||||
recursiveLevel++
|
recursiveLevel++
|
||||||
case opStructFieldRecursiveEnd:
|
case opStructFieldRecursiveEnd:
|
||||||
|
@ -549,8 +530,10 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
|
||||||
|
|
||||||
// restore ctxptr
|
// restore ctxptr
|
||||||
offset := load(ctxptr, code.idx)
|
offset := load(ctxptr, code.idx)
|
||||||
ptr := load(ctxptr, code.elemIdx)
|
ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
|
||||||
code = (*opcode)(e.ptrToUnsafePtr(ptr))
|
|
||||||
|
codePtr := load(ctxptr, code.elemIdx)
|
||||||
|
code = (*opcode)(e.ptrToUnsafePtr(codePtr))
|
||||||
ctxptr = ctx.ptr() + offset
|
ctxptr = ctx.ptr() + offset
|
||||||
ptrOffset = offset
|
ptrOffset = offset
|
||||||
case opStructFieldPtrHead:
|
case opStructFieldPtrHead:
|
||||||
|
|
Loading…
Reference in New Issue