Merge pull request #229 from goccy/feature/keep-original-slice-reference

Keep original reference of slice element
This commit is contained in:
Masaaki Goshima 2021-05-20 15:47:34 +09:00 committed by GitHub
commit 902fd6a1b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 19 deletions

View File

@ -49,9 +49,20 @@ func newSliceDecoder(dec decoder, elemType *rtype, size uintptr, structName, fie
} }
} }
func (d *sliceDecoder) newSlice() *sliceHeader { func (d *sliceDecoder) newSlice(src *sliceHeader) *sliceHeader {
slice := d.arrayPool.Get().(*sliceHeader) slice := d.arrayPool.Get().(*sliceHeader)
if src.len > 0 {
// copy original elem
if slice.cap < src.cap {
data := newArray(d.elemType, src.cap)
slice = &sliceHeader{data: data, len: src.len, cap: src.cap}
} else {
slice.len = src.len
}
copySlice(d.elemType, *slice, *src)
} else {
slice.len = 0 slice.len = 0
}
return slice return slice
} }
@ -109,7 +120,8 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er
return nil return nil
} }
idx := 0 idx := 0
slice := d.newSlice() slice := d.newSlice((*sliceHeader)(p))
srcLen := slice.len
capacity := slice.cap capacity := slice.cap
data := slice.data data := slice.data
for { for {
@ -121,12 +133,17 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er
copySlice(d.elemType, dst, src) copySlice(d.elemType, dst, src)
} }
ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size)
// if srcLen is greater than idx, keep the original reference
if srcLen <= idx {
if d.isElemPointerType { if d.isElemPointerType {
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
} else { } else {
// assign new element to the slice // assign new element to the slice
typedmemmove(d.elemType, ep, unsafe_New(d.elemType)) typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
} }
}
if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil {
return err return err
} }
@ -212,7 +229,8 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
return cursor, nil return cursor, nil
} }
idx := 0 idx := 0
slice := d.newSlice() slice := d.newSlice((*sliceHeader)(p))
srcLen := slice.len
capacity := slice.cap capacity := slice.cap
data := slice.data data := slice.data
for { for {
@ -224,12 +242,15 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
copySlice(d.elemType, dst, src) copySlice(d.elemType, dst, src)
} }
ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size)
// if srcLen is greater than idx, keep the original reference
if srcLen <= idx {
if d.isElemPointerType { if d.isElemPointerType {
**(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer
} else { } else {
// assign new element to the slice // assign new element to the slice
typedmemmove(d.elemType, ep, unsafe_New(d.elemType)) typedmemmove(d.elemType, ep, unsafe_New(d.elemType))
} }
}
c, err := d.valueDecoder.decode(buf, cursor, depth, ep) c, err := d.valueDecoder.decode(buf, cursor, depth, ep)
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -3052,7 +3052,7 @@ func TestMultipleDecodeWithRawMessage(t *testing.T) {
type intUnmarshaler int type intUnmarshaler int
func (u *intUnmarshaler) UnmarshalJSON(b []byte) error { func (u *intUnmarshaler) UnmarshalJSON(b []byte) error {
if *u != 0 { if *u != 0 && *u != 10 {
return fmt.Errorf("failed to decode of slice with int unmarshaler") return fmt.Errorf("failed to decode of slice with int unmarshaler")
} }
*u = 10 *u = 10
@ -3062,7 +3062,7 @@ func (u *intUnmarshaler) UnmarshalJSON(b []byte) error {
type arrayUnmarshaler [5]int type arrayUnmarshaler [5]int
func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error { func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error {
if (*u)[0] != 0 { if (*u)[0] != 0 && (*u)[0] != 10 {
return fmt.Errorf("failed to decode of slice with array unmarshaler") return fmt.Errorf("failed to decode of slice with array unmarshaler")
} }
(*u)[0] = 10 (*u)[0] = 10
@ -3072,7 +3072,7 @@ func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error {
type mapUnmarshaler map[string]int type mapUnmarshaler map[string]int
func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error { func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error {
if len(*u) != 0 { if len(*u) != 0 && len(*u) != 1 {
return fmt.Errorf("failed to decode of slice with map unmarshaler") return fmt.Errorf("failed to decode of slice with map unmarshaler")
} }
*u = map[string]int{"a": 10} *u = map[string]int{"a": 10}
@ -3081,13 +3081,15 @@ func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error {
type structUnmarshaler struct { type structUnmarshaler struct {
A int A int
notFirst bool
} }
func (u *structUnmarshaler) UnmarshalJSON(b []byte) error { func (u *structUnmarshaler) UnmarshalJSON(b []byte) error {
if u.A != 0 { if !u.notFirst && u.A != 0 {
return fmt.Errorf("failed to decode of slice with struct unmarshaler") return fmt.Errorf("failed to decode of slice with struct unmarshaler")
} }
u.A = 10 u.A = 10
u.notFirst = true
return nil return nil
} }
@ -3199,6 +3201,29 @@ func TestSliceElemUnmarshaler(t *testing.T) {
}) })
} }
type keepRefTest struct {
A int
B string
}
func (t *keepRefTest) UnmarshalJSON(data []byte) error {
v := []interface{}{&t.A, &t.B}
return json.Unmarshal(data, &v)
}
func TestKeepReferenceSlice(t *testing.T) {
var v keepRefTest
if err := json.Unmarshal([]byte(`[54,"hello"]`), &v); err != nil {
t.Fatal(err)
}
if v.A != 54 {
t.Fatal("failed to keep reference for slice")
}
if v.B != "hello" {
t.Fatal("failed to keep reference for slice")
}
}
func TestInvalidTopLevelValue(t *testing.T) { func TestInvalidTopLevelValue(t *testing.T) {
t.Run("invalid end of buffer", func(t *testing.T) { t.Run("invalid end of buffer", func(t *testing.T) {
var v struct{} var v struct{}