From 90d4d18dbf2502bed3a4836036b536f02252d334 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 20 May 2021 02:04:52 +0900 Subject: [PATCH] Keep original reference of slice element - If the entered slice length is greater than zero, copy all references to the shared slice. --- decode_slice.go | 49 +++++++++++++++++++++++++++++++++++-------------- decode_test.go | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/decode_slice.go b/decode_slice.go index b546c49..01f9dfe 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -49,9 +49,20 @@ func newSliceDecoder(dec decoder, elemType *rtype, size uintptr, structName, fie } } -func (d *sliceDecoder) newSlice() *sliceHeader { +func (d *sliceDecoder) newSlice(src *sliceHeader) *sliceHeader { slice := d.arrayPool.Get().(*sliceHeader) - slice.len = 0 + if src.len > 0 { + // copy original elem + if slice.cap < src.cap { + data := newArray(d.elemType, src.cap) + slice = &sliceHeader{data: data, len: src.len, cap: src.cap} + } else { + slice.len = src.len + } + copySlice(d.elemType, *slice, *src) + } else { + slice.len = 0 + } return slice } @@ -109,7 +120,8 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er return nil } idx := 0 - slice := d.newSlice() + slice := d.newSlice((*sliceHeader)(p)) + srcLen := slice.len capacity := slice.cap data := slice.data for { @@ -121,12 +133,17 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er copySlice(d.elemType, dst, src) } ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) - if d.isElemPointerType { - **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer - } else { - // assign new element to the slice - typedmemmove(d.elemType, ep, unsafe_New(d.elemType)) + + // if srcLen is greater than idx, keep the original reference + if srcLen <= idx { + if d.isElemPointerType { + **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer + } else { + // assign new element to the slice + typedmemmove(d.elemType, ep, unsafe_New(d.elemType)) + } } + if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { return err } @@ -212,7 +229,8 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) return cursor, nil } idx := 0 - slice := d.newSlice() + slice := d.newSlice((*sliceHeader)(p)) + srcLen := slice.len capacity := slice.cap data := slice.data for { @@ -224,11 +242,14 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) copySlice(d.elemType, dst, src) } ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) - if d.isElemPointerType { - **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer - } else { - // assign new element to the slice - typedmemmove(d.elemType, ep, unsafe_New(d.elemType)) + // if srcLen is greater than idx, keep the original reference + if srcLen <= idx { + if d.isElemPointerType { + **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer + } else { + // assign new element to the slice + typedmemmove(d.elemType, ep, unsafe_New(d.elemType)) + } } c, err := d.valueDecoder.decode(buf, cursor, depth, ep) if err != nil { diff --git a/decode_test.go b/decode_test.go index cdacbc5..7dab500 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3052,7 +3052,7 @@ func TestMultipleDecodeWithRawMessage(t *testing.T) { type intUnmarshaler int func (u *intUnmarshaler) UnmarshalJSON(b []byte) error { - if *u != 0 { + if *u != 0 && *u != 10 { return fmt.Errorf("failed to decode of slice with int unmarshaler") } *u = 10 @@ -3062,7 +3062,7 @@ func (u *intUnmarshaler) UnmarshalJSON(b []byte) error { type arrayUnmarshaler [5]int func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error { - if (*u)[0] != 0 { + if (*u)[0] != 0 && (*u)[0] != 10 { return fmt.Errorf("failed to decode of slice with array unmarshaler") } (*u)[0] = 10 @@ -3072,7 +3072,7 @@ func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error { type mapUnmarshaler map[string]int func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error { - if len(*u) != 0 { + if len(*u) != 0 && len(*u) != 1 { return fmt.Errorf("failed to decode of slice with map unmarshaler") } *u = map[string]int{"a": 10} @@ -3080,14 +3080,16 @@ func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error { } type structUnmarshaler struct { - A int + A int + notFirst bool } func (u *structUnmarshaler) UnmarshalJSON(b []byte) error { - if u.A != 0 { + if !u.notFirst && u.A != 0 { return fmt.Errorf("failed to decode of slice with struct unmarshaler") } u.A = 10 + u.notFirst = true return nil } @@ -3199,6 +3201,29 @@ func TestSliceElemUnmarshaler(t *testing.T) { }) } +type keepRefTest struct { + A int + B string +} + +func (t *keepRefTest) UnmarshalJSON(data []byte) error { + v := []interface{}{&t.A, &t.B} + return json.Unmarshal(data, &v) +} + +func TestKeepReferenceSlice(t *testing.T) { + var v keepRefTest + if err := json.Unmarshal([]byte(`[54,"hello"]`), &v); err != nil { + t.Fatal(err) + } + if v.A != 54 { + t.Fatal("failed to keep reference for slice") + } + if v.B != "hello" { + t.Fatal("failed to keep reference for slice") + } +} + func TestInvalidTopLevelValue(t *testing.T) { t.Run("invalid end of buffer", func(t *testing.T) { var v struct{}