From 0297427ef56e1b6b8aa672f0bb71f50e2155ce63 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Mon, 1 Feb 2021 18:43:28 +0900 Subject: [PATCH 1/2] Fix encoding of MarshalJSON type --- encode_vm_escaped.go | 71 +++++++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/encode_vm_escaped.go b/encode_vm_escaped.go index 7bc194f..47f6486 100644 --- a/encode_vm_escaped.go +++ b/encode_vm_escaped.go @@ -7801,6 +7801,15 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o ptr := load(ctxptr, code.headIdx) b = append(b, code.escapedKey...) p := ptr + code.offset + if code.typ.Kind() == reflect.Ptr { + p = ptrToPtr(p) + } + if p == 0 { + b = encodeNull(b) + b = encodeComma(b) + code = code.next + break + } v := ptrToInterface(code, p) bb, err := v.(Marshaler).MarshalJSON() if err != nil { @@ -7816,23 +7825,25 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o case opStructFieldOmitEmptyMarshalJSON: ptr := load(ctxptr, code.headIdx) p := ptr + code.offset - if code.typ.Kind() == reflect.Ptr && code.typ.Elem().Implements(marshalJSONType) { + if code.typ.Kind() == reflect.Ptr { p = ptrToPtr(p) } - v := ptrToInterface(code, p) - if v != nil && p != 0 { - bb, err := v.(Marshaler).MarshalJSON() - if err != nil { - return nil, errMarshaler(code, err) - } - b = append(b, code.escapedKey...) - buf := bytes.NewBuffer(b) - if err := compact(buf, bb, true); err != nil { - return nil, err - } - b = buf.Bytes() - b = encodeComma(b) + if p == 0 { + code = code.next + break } + v := ptrToInterface(code, p) + bb, err := v.(Marshaler).MarshalJSON() + if err != nil { + return nil, errMarshaler(code, err) + } + b = append(b, code.escapedKey...) + buf := bytes.NewBuffer(b) + if err := compact(buf, bb, true); err != nil { + return nil, err + } + b = buf.Bytes() + b = encodeComma(b) code = code.next case opStructFieldStringTagMarshalJSON: ptr := load(ctxptr, code.headIdx) @@ -9256,20 +9267,10 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o case opStructEndOmitEmptyMarshalJSON: ptr := load(ctxptr, code.headIdx) p := ptr + code.offset - v := ptrToInterface(code, p) - if v != nil && (code.typ.Kind() != reflect.Ptr || ptrToPtr(p) != 0) { - bb, err := v.(Marshaler).MarshalJSON() - if err != nil { - return nil, errMarshaler(code, err) - } - b = append(b, code.escapedKey...) - buf := bytes.NewBuffer(b) - if err := compact(buf, bb, true); err != nil { - return nil, err - } - b = buf.Bytes() - b = appendStructEnd(b) - } else { + if code.typ.Kind() == reflect.Ptr { + p = ptrToPtr(p) + } + if p == 0 { last := len(b) - 1 if b[last] == ',' { b[last] = '}' @@ -9277,7 +9278,21 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o } else { b = appendStructEnd(b) } + code = code.next + break } + v := ptrToInterface(code, p) + bb, err := v.(Marshaler).MarshalJSON() + if err != nil { + return nil, errMarshaler(code, err) + } + b = append(b, code.escapedKey...) + buf := bytes.NewBuffer(b) + if err := compact(buf, bb, true); err != nil { + return nil, err + } + b = buf.Bytes() + b = appendStructEnd(b) code = code.next case opStructEndStringTagMarshalJSON: ptr := load(ctxptr, code.headIdx) From b431a095d6ce12c9b119a327cbb919614464b923 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Mon, 1 Feb 2021 20:02:43 +0900 Subject: [PATCH 2/2] Fix error by race detector --- encode_compile.go | 41 +++++++------------------------------ encode_compile_norace.go | 34 +++++++++++++++++++++++++++++++ encode_compile_race.go | 44 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 34 deletions(-) create mode 100644 encode_compile_norace.go create mode 100644 encode_compile_race.go diff --git a/encode_compile.go b/encode_compile.go index 04e256c..06f21a0 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -22,11 +22,12 @@ type opcodeSet struct { } var ( - marshalJSONType = reflect.TypeOf((*Marshaler)(nil)).Elem() - marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() - cachedOpcode unsafe.Pointer // map[uintptr]*opcodeSet - baseTypeAddr uintptr - cachedOpcodeSets []*opcodeSet + marshalJSONType = reflect.TypeOf((*Marshaler)(nil)).Elem() + marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + cachedOpcode unsafe.Pointer // map[uintptr]*opcodeSet + baseTypeAddr uintptr + cachedOpcodeSets []*opcodeSet + existsCachedOpcodeSets bool ) const ( @@ -80,6 +81,7 @@ func setupOpcodeSets() error { return fmt.Errorf("too big address range %d", addrRange) } cachedOpcodeSets = make([]*opcodeSet, addrRange) + existsCachedOpcodeSets = true baseTypeAddr = min return nil } @@ -90,35 +92,6 @@ func init() { } } -func encodeCompileToGetCodeSet(typeptr uintptr) (*opcodeSet, error) { - if cachedOpcodeSets == nil { - return encodeCompileToGetCodeSetSlowPath(typeptr) - } - if codeSet := cachedOpcodeSets[typeptr-baseTypeAddr]; codeSet != nil { - return codeSet, nil - } - - // noescape trick for header.typ ( reflect.*rtype ) - copiedType := *(**rtype)(unsafe.Pointer(&typeptr)) - - code, err := encodeCompileHead(&encodeCompileContext{ - typ: copiedType, - root: true, - structTypeToCompiledCode: map[uintptr]*compiledCode{}, - }) - if err != nil { - return nil, err - } - code = copyOpcode(code) - codeLength := code.totalLength() - codeSet := &opcodeSet{ - code: code, - codeLength: codeLength, - } - cachedOpcodeSets[int(typeptr-baseTypeAddr)] = codeSet - return codeSet, nil -} - func encodeCompileToGetCodeSetSlowPath(typeptr uintptr) (*opcodeSet, error) { opcodeMap := loadOpcodeMap() if codeSet, exists := opcodeMap[typeptr]; exists { diff --git a/encode_compile_norace.go b/encode_compile_norace.go new file mode 100644 index 0000000..209f60c --- /dev/null +++ b/encode_compile_norace.go @@ -0,0 +1,34 @@ +// +build !race + +package json + +import "unsafe" + +func encodeCompileToGetCodeSet(typeptr uintptr) (*opcodeSet, error) { + if !existsCachedOpcodeSets { + return encodeCompileToGetCodeSetSlowPath(typeptr) + } + if codeSet := cachedOpcodeSets[typeptr-baseTypeAddr]; codeSet != nil { + return codeSet, nil + } + + // noescape trick for header.typ ( reflect.*rtype ) + copiedType := *(**rtype)(unsafe.Pointer(&typeptr)) + + code, err := encodeCompileHead(&encodeCompileContext{ + typ: copiedType, + root: true, + structTypeToCompiledCode: map[uintptr]*compiledCode{}, + }) + if err != nil { + return nil, err + } + code = copyOpcode(code) + codeLength := code.totalLength() + codeSet := &opcodeSet{ + code: code, + codeLength: codeLength, + } + cachedOpcodeSets[int(typeptr-baseTypeAddr)] = codeSet + return codeSet, nil +} diff --git a/encode_compile_race.go b/encode_compile_race.go new file mode 100644 index 0000000..a2fcb73 --- /dev/null +++ b/encode_compile_race.go @@ -0,0 +1,44 @@ +// +build race + +package json + +import ( + "sync" + "unsafe" +) + +var setsMu sync.RWMutex + +func encodeCompileToGetCodeSet(typeptr uintptr) (*opcodeSet, error) { + if !existsCachedOpcodeSets { + return encodeCompileToGetCodeSetSlowPath(typeptr) + } + setsMu.RLock() + if codeSet := cachedOpcodeSets[typeptr-baseTypeAddr]; codeSet != nil { + setsMu.RUnlock() + return codeSet, nil + } + setsMu.RUnlock() + + // noescape trick for header.typ ( reflect.*rtype ) + copiedType := *(**rtype)(unsafe.Pointer(&typeptr)) + + code, err := encodeCompileHead(&encodeCompileContext{ + typ: copiedType, + root: true, + structTypeToCompiledCode: map[uintptr]*compiledCode{}, + }) + if err != nil { + return nil, err + } + code = copyOpcode(code) + codeLength := code.totalLength() + codeSet := &opcodeSet{ + code: code, + codeLength: codeLength, + } + setsMu.Lock() + cachedOpcodeSets[int(typeptr-baseTypeAddr)] = codeSet + setsMu.Unlock() + return codeSet, nil +}