From ddfae9189eb364ca0f0c69b85639a0ac2e47a4fc Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 15 Jan 2021 16:25:00 +0900 Subject: [PATCH] Fix recursive call --- encode.go | 5 ++++- encode_compile.go | 46 ++++++++++++++++++++++++++++++++++++++ encode_context.go | 2 ++ encode_vm.go | 53 +++++++++++++++----------------------------- encode_vm_escaped.go | 53 +++++++++++++++----------------------------- 5 files changed, 88 insertions(+), 71 deletions(-) diff --git a/encode.go b/encode.go index 94a1108..50cb56a 100644 --- a/encode.go +++ b/encode.go @@ -27,7 +27,10 @@ type Encoder struct { } type compiledCode struct { - code *opcode + code *opcode + linked bool // whether recursive code already have linked + curLen uintptr + nextLen uintptr } const ( diff --git a/encode_compile.go b/encode_compile.go index cece1e9..1bfb3d3 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -34,6 +34,7 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) { } e.convertHeadOnlyCode(code, isPtr) e.optimizeStructEnd(code) + e.linkRecursiveCode(code) return code, nil } else if isPtr && typ.Implements(marshalTextType) { typ = orgType @@ -46,9 +47,54 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) { } e.convertHeadOnlyCode(code, isPtr) e.optimizeStructEnd(code) + e.linkRecursiveCode(code) return code, nil } +func (e *Encoder) linkRecursiveCode(c *opcode) { + for code := c; code.op != opEnd && code.op != opStructFieldRecursiveEnd; { + switch code.op { + case opStructFieldRecursive, + opStructFieldPtrAnonymousHeadRecursive, + opStructFieldAnonymousHeadRecursive: + if code.jmp.linked { + code = code.next + continue + } + code.jmp.code = copyOpcode(code.jmp.code) + c := code.jmp.code + c.end.next = newEndOp(&encodeCompileContext{}) + c.op = c.op.ptrHeadToHead() + + beforeLastCode := c.end + lastCode := beforeLastCode.next + + lastCode.idx = beforeLastCode.idx + uintptrSize + lastCode.elemIdx = lastCode.idx + uintptrSize + + // extend length to alloc slot for elemIdx + totalLength := uintptr(code.totalLength() + 1) + nextTotalLength := uintptr(c.totalLength() + 1) + + c.end.next.op = opStructFieldRecursiveEnd + + code.jmp.curLen = totalLength + code.jmp.nextLen = nextTotalLength + code.jmp.linked = true + + e.linkRecursiveCode(code.jmp.code) + code = code.next + continue + } + switch code.op.codeType() { + case codeArrayElem, codeSliceElem, codeMapKey: + code = code.end + default: + code = code.next + } + } +} + func (e *Encoder) optimizeStructEnd(c *opcode) { for code := c; code.op != opEnd; { if code.op == opStructFieldRecursive { diff --git a/encode_context.go b/encode_context.go index 52f1e93..0713f90 100644 --- a/encode_context.go +++ b/encode_context.go @@ -86,6 +86,7 @@ func (c *encodeCompileContext) decPtrIndex() { type encodeRuntimeContext struct { ptrs []uintptr keepRefs []unsafe.Pointer + seenPtr []uintptr } func (c *encodeRuntimeContext) init(p uintptr, codelen int) { @@ -94,6 +95,7 @@ func (c *encodeRuntimeContext) init(p uintptr, codelen int) { } c.ptrs[0] = p c.keepRefs = c.keepRefs[:0] + c.seenPtr = c.seenPtr[:0] } func (c *encodeRuntimeContext) ptr() uintptr { diff --git a/encode_vm.go b/encode_vm.go index 58956ff..c68caf6 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -50,7 +50,6 @@ func errMarshaler(code *opcode, err error) *MarshalerError { func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) { recursiveLevel := 0 - var seenPtr map[uintptr]struct{} ptrOffset := uintptr(0) ctxptr := ctx.ptr() @@ -200,13 +199,12 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte code = code.next break } - if seenPtr == nil { - seenPtr = map[uintptr]struct{}{} + for _, seen := range ctx.seenPtr { + if ptr == seen { + return nil, errUnsupportedValue(code, ptr) + } } - if _, exists := seenPtr[ptr]; exists { - return nil, errUnsupportedValue(code, ptr) - } - seenPtr[ptr] = struct{}{} + ctx.seenPtr = append(ctx.seenPtr, ptr) v := e.ptrToInterface(code, ptr) ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v)) rv := reflect.ValueOf(v) @@ -540,46 +538,29 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte ptr := load(ctxptr, code.idx) if ptr != 0 { if recursiveLevel > startDetectingCyclesAfter { - if _, exists := seenPtr[ptr]; exists { - return nil, errUnsupportedValue(code, ptr) + for _, seen := range ctx.seenPtr { + if ptr == seen { + return nil, errUnsupportedValue(code, ptr) + } } } } - if seenPtr == nil { - seenPtr = map[uintptr]struct{}{} - } - seenPtr[ptr] = struct{}{} + ctx.seenPtr = append(ctx.seenPtr, ptr) c := code.jmp.code - c.end.next = newEndOp(&encodeCompileContext{}) - c.op = c.op.ptrHeadToHead() - - beforeLastCode := c.end - lastCode := beforeLastCode.next - - lastCode.idx = beforeLastCode.idx + uintptrSize - lastCode.elemIdx = lastCode.idx + uintptrSize - - // extend length to alloc slot for elemIdx - totalLength := uintptr(code.totalLength() + 1) - nextTotalLength := uintptr(c.totalLength() + 1) - curlen := uintptr(len(ctx.ptrs)) offsetNum := ptrOffset / uintptrSize oldOffset := ptrOffset - ptrOffset += totalLength * uintptrSize + ptrOffset += code.jmp.curLen * uintptrSize - newLen := offsetNum + totalLength + nextTotalLength + newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen if curlen < newLen { ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...) } ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr store(ctxptr, c.idx, ptr) - store(ctxptr, lastCode.idx, oldOffset) - store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next))) - - // link lastCode ( opStructFieldRecursiveEnd ) => code.next - lastCode.op = opStructFieldRecursiveEnd + store(ctxptr, c.end.next.idx, oldOffset) + store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next))) code = c recursiveLevel++ case opStructFieldRecursiveEnd: @@ -587,8 +568,10 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte // restore ctxptr offset := load(ctxptr, code.idx) - ptr := load(ctxptr, code.elemIdx) - code = (*opcode)(e.ptrToUnsafePtr(ptr)) + ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1] + + codePtr := load(ctxptr, code.elemIdx) + code = (*opcode)(e.ptrToUnsafePtr(codePtr)) ctxptr = ctx.ptr() + offset ptrOffset = offset case opStructFieldPtrHead: diff --git a/encode_vm_escaped.go b/encode_vm_escaped.go index e08fde6..862b923 100644 --- a/encode_vm_escaped.go +++ b/encode_vm_escaped.go @@ -12,7 +12,6 @@ import ( func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) { recursiveLevel := 0 - var seenPtr map[uintptr]struct{} ptrOffset := uintptr(0) ctxptr := ctx.ptr() @@ -162,13 +161,12 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) code = code.next break } - if seenPtr == nil { - seenPtr = map[uintptr]struct{}{} + for _, seen := range ctx.seenPtr { + if ptr == seen { + return nil, errUnsupportedValue(code, ptr) + } } - if _, exists := seenPtr[ptr]; exists { - return nil, errUnsupportedValue(code, ptr) - } - seenPtr[ptr] = struct{}{} + ctx.seenPtr = append(ctx.seenPtr, ptr) v := e.ptrToInterface(code, ptr) ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v)) rv := reflect.ValueOf(v) @@ -502,46 +500,29 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ptr := load(ctxptr, code.idx) if ptr != 0 { if recursiveLevel > startDetectingCyclesAfter { - if _, exists := seenPtr[ptr]; exists { - return nil, errUnsupportedValue(code, ptr) + for _, seen := range ctx.seenPtr { + if ptr == seen { + return nil, errUnsupportedValue(code, ptr) + } } } } - if seenPtr == nil { - seenPtr = map[uintptr]struct{}{} - } - seenPtr[ptr] = struct{}{} + ctx.seenPtr = append(ctx.seenPtr, ptr) c := code.jmp.code - c.end.next = newEndOp(&encodeCompileContext{}) - c.op = c.op.ptrHeadToHead() - - beforeLastCode := c.end - lastCode := beforeLastCode.next - - lastCode.idx = beforeLastCode.idx + uintptrSize - lastCode.elemIdx = lastCode.idx + uintptrSize - - // extend length to alloc slot for elemIdx - totalLength := uintptr(code.totalLength() + 1) - nextTotalLength := uintptr(c.totalLength() + 1) - curlen := uintptr(len(ctx.ptrs)) offsetNum := ptrOffset / uintptrSize oldOffset := ptrOffset - ptrOffset += totalLength * uintptrSize + ptrOffset += code.jmp.curLen * uintptrSize - newLen := offsetNum + totalLength + nextTotalLength + newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen if curlen < newLen { ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...) } ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr store(ctxptr, c.idx, ptr) - store(ctxptr, lastCode.idx, oldOffset) - store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next))) - - // link lastCode ( opStructFieldRecursiveEnd ) => code.next - lastCode.op = opStructFieldRecursiveEnd + store(ctxptr, c.end.next.idx, oldOffset) + store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next))) code = c recursiveLevel++ case opStructFieldRecursiveEnd: @@ -549,8 +530,10 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) // restore ctxptr offset := load(ctxptr, code.idx) - ptr := load(ctxptr, code.elemIdx) - code = (*opcode)(e.ptrToUnsafePtr(ptr)) + ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1] + + codePtr := load(ctxptr, code.elemIdx) + code = (*opcode)(e.ptrToUnsafePtr(codePtr)) ctxptr = ctx.ptr() + offset ptrOffset = offset case opStructFieldPtrHead: