diff --git a/decode_ptr.go b/decode_ptr.go index 900af57..90a038a 100644 --- a/decode_ptr.go +++ b/decode_ptr.go @@ -1,12 +1,15 @@ package json import ( + "sync" "unsafe" ) type ptrDecoder struct { - dec decoder - typ *rtype + dec decoder + typ *rtype + mu sync.Mutex + keepRefs []unsafe.Pointer } func newPtrDecoder(dec decoder, typ *rtype) *ptrDecoder { @@ -14,24 +17,27 @@ func newPtrDecoder(dec decoder, typ *rtype) *ptrDecoder { } //go:linkname unsafe_New reflect.unsafe_New -func unsafe_New(*rtype) uintptr +func unsafe_New(*rtype) unsafe.Pointer func (d *ptrDecoder) decodeStream(s *stream, p uintptr) error { newptr := unsafe_New(d.typ) - if err := d.dec.decodeStream(s, newptr); err != nil { + if err := d.dec.decodeStream(s, uintptr(newptr)); err != nil { return err } - **(**uintptr)(unsafe.Pointer(&p)) = newptr + **(**unsafe.Pointer)(unsafe.Pointer(&p)) = newptr return nil } func (d *ptrDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, error) { + d.mu.Lock() newptr := unsafe_New(d.typ) - c, err := d.dec.decode(buf, cursor, newptr) + d.keepRefs = append(d.keepRefs, newptr) + **(**unsafe.Pointer)(unsafe.Pointer(&p)) = newptr + d.mu.Unlock() + c, err := d.dec.decode(buf, cursor, uintptr(newptr)) if err != nil { return 0, err } cursor = c - **(**uintptr)(unsafe.Pointer(&p)) = newptr return cursor, nil }