diff --git a/decode.go b/decode.go index 86cacbb..8419d82 100644 --- a/decode.go +++ b/decode.go @@ -37,30 +37,11 @@ func unmarshal(data []byte, v interface{}) error { copy(src, data) header := (*interfaceHeader)(unsafe.Pointer(&v)) - header.typ.escape() - return decode(src, header) -} -func unmarshalNoEscape(data []byte, v interface{}) error { - src := make([]byte, len(data)+1) // append nul byte to the end - copy(src, data) - - header := (*interfaceHeader)(unsafe.Pointer(&v)) - return decode(src, header) -} - -func decode(src []byte, header *interfaceHeader) error { - typ := header.typ - typeptr := uintptr(unsafe.Pointer(typ)) - - // noescape trick for header.typ ( *reflect.rtype ) - copiedType := *(**rtype)(unsafe.Pointer(&typeptr)) - ptr := uintptr(header.ptr) - - if err := validateType(copiedType, ptr); err != nil { + if err := validateType(header.typ, uintptr(header.ptr)); err != nil { return err } - dec, err := decodeCompileToGetDecoder(typeptr, typ) + dec, err := decodeCompileToGetDecoder(header.typ) if err != nil { return err } @@ -70,6 +51,31 @@ func decode(src []byte, header *interfaceHeader) error { return nil } +func unmarshalNoEscape(data []byte, v interface{}) error { + src := make([]byte, len(data)+1) // append nul byte to the end + copy(src, data) + + header := (*interfaceHeader)(unsafe.Pointer(&v)) + + if err := validateType(header.typ, uintptr(header.ptr)); err != nil { + return err + } + dec, err := decodeCompileToGetDecoder(header.typ) + if err != nil { + return err + } + if _, err := dec.decode(src, 0, noescape(header.ptr)); err != nil { + return err + } + return nil +} + +//go:nosplit +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + func validateType(typ *rtype, p uintptr) error { if typ.Kind() != reflect.Ptr || p == 0 { return &InvalidUnmarshalError{Type: rtype2type(typ)} @@ -132,7 +138,7 @@ func (d *Decoder) Decode(v interface{}) error { return err } - dec, err := decodeCompileToGetDecoder(typeptr, typ) + dec, err := decodeCompileToGetDecoder(typ) if err != nil { return err } diff --git a/decode_compile_norace.go b/decode_compile_norace.go index 3ad763c..a20f38d 100644 --- a/decode_compile_norace.go +++ b/decode_compile_norace.go @@ -2,7 +2,10 @@ package json -func decodeCompileToGetDecoder(typeptr uintptr, typ *rtype) (decoder, error) { +import "unsafe" + +func decodeCompileToGetDecoder(typ *rtype) (decoder, error) { + typeptr := uintptr(unsafe.Pointer(typ)) if typeptr > maxTypeAddr { return decodeCompileToGetDecoderSlowPath(typeptr, typ) } diff --git a/decode_compile_race.go b/decode_compile_race.go index 921af83..fb601c4 100644 --- a/decode_compile_race.go +++ b/decode_compile_race.go @@ -4,11 +4,13 @@ package json import ( "sync" + "unsafe" ) var decMu sync.RWMutex -func decodeCompileToGetDecoder(typeptr uintptr, typ *rtype) (decoder, error) { +func decodeCompileToGetDecoder(typ *rtype) (decoder, error) { + typeptr := uintptr(unsafe.Pointer(typ)) if typeptr > maxTypeAddr { return decodeCompileToGetDecoderSlowPath(typeptr, typ) } diff --git a/decode_interface.go b/decode_interface.go index bf26954..bc4dadd 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -188,7 +188,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { *(*interface{})(p) = nil return nil } - decoder, err := decodeCompileToGetDecoder(uintptr(unsafe.Pointer(typ)), typ) + decoder, err := decodeCompileToGetDecoder(typ) if err != nil { return err } @@ -228,7 +228,7 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i **(**interface{})(unsafe.Pointer(&p)) = nil return cursor, nil } - decoder, err := decodeCompileToGetDecoder(uintptr(unsafe.Pointer(typ)), typ) + decoder, err := decodeCompileToGetDecoder(typ) if err != nil { return 0, err }