From dc2d025d2a8643120e40ea526a3093d9905b0c01 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 30 Apr 2021 04:01:51 +0900 Subject: [PATCH 1/7] Add test case --- decode_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/decode_test.go b/decode_test.go index 574341f..a7a7bec 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2931,6 +2931,71 @@ func TestDecodeSlice(t *testing.T) { } } +func TestDecodeMultipleUnmarshal(t *testing.T) { + data := []byte(`[{"AA":{"X":[{"a": "A"},{"b": "B"}],"Y":"y","Z":"z"},"BB":"bb"},{"AA":{"X":[],"Y":"y","Z":"z"},"BB":"bb"}]`) + var a []json.RawMessage + if err := json.Unmarshal(data, &a); err != nil { + t.Fatal(err) + } + if len(a) != 2 { + t.Fatalf("failed to decode: got %v", a) + } + t.Run("first", func(t *testing.T) { + data := a[0] + var v map[string]json.RawMessage + if err := json.Unmarshal(data, &v); err != nil { + t.Fatal(err) + } + if string(v["AA"]) != `{"X":[{"a": "A"},{"b": "B"}],"Y":"y","Z":"z"}` { + t.Fatalf("failed to decode. got %q", v["AA"]) + } + var aa map[string]json.RawMessage + if err := json.Unmarshal(v["AA"], &aa); err != nil { + t.Fatal(err) + } + if string(aa["X"]) != `[{"a": "A"},{"b": "B"}]` { + t.Fatalf("failed to decode. got %q", v["X"]) + } + var x []json.RawMessage + if err := json.Unmarshal(aa["X"], &x); err != nil { + t.Fatal(err) + } + if len(x) != 2 { + t.Fatalf("failed to decode: %v", x) + } + if string(x[0]) != `{"a": "A"}` { + t.Fatal("failed to decode") + } + if string(x[1]) != `{"b": "B"}` { + t.Fatal("failed to decode") + } + }) + t.Run("second", func(t *testing.T) { + data := a[1] + var v map[string]json.RawMessage + if err := json.Unmarshal(data, &v); err != nil { + t.Fatal(err) + } + if string(v["AA"]) != `{"X":[],"Y":"y","Z":"z"}` { + t.Fatalf("failed to decode. got %q", v["AA"]) + } + var aa map[string]json.RawMessage + if err := json.Unmarshal(v["AA"], &aa); err != nil { + t.Fatal(err) + } + if string(aa["X"]) != `[]` { + t.Fatalf("failed to decode. got %q", v["X"]) + } + var x []json.RawMessage + if err := json.Unmarshal(aa["X"], &x); err != nil { + t.Fatal(err) + } + if len(x) != 0 { + t.Fatalf("failed to decode: %v", x) + } + }) +} + func TestInvalidTopLevelValue(t *testing.T) { t.Run("invalid end of buffer", func(t *testing.T) { var v struct{} From 16a358048e466a72ae05b593731f81ad81509747 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 30 Apr 2021 04:02:06 +0900 Subject: [PATCH 2/7] if elem type is slice, clear it --- decode_slice.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/decode_slice.go b/decode_slice.go index ea09f88..603d34b 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -9,6 +9,7 @@ import ( type sliceDecoder struct { elemType *rtype isElemPointerType bool + isElemSliceType bool valueDecoder decoder size uintptr arrayPool sync.Pool @@ -34,6 +35,7 @@ func newSliceDecoder(dec decoder, elemType *rtype, size uintptr, structName, fie valueDecoder: dec, elemType: elemType, isElemPointerType: elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Map, + isElemSliceType: elemType.Kind() == reflect.Slice, size: size, arrayPool: sync.Pool{ New: func() interface{} { @@ -233,6 +235,12 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) if d.isElemPointerType { *(*unsafe.Pointer)(ep) = nil // initialize elem pointer + } else if d.isElemSliceType { + *(*sliceHeader)(ep) = sliceHeader{ + data: newArray(d.elemType, 0), + len: 0, + cap: 0, + } } c, err := d.valueDecoder.decode(buf, cursor, depth, ep) if err != nil { From 17392ab7167e4c2eb7db12c1359a174c431070a5 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 30 Apr 2021 04:06:23 +0900 Subject: [PATCH 3/7] Fix stream decoder --- decode_slice.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/decode_slice.go b/decode_slice.go index 603d34b..aa67ef2 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -121,6 +121,12 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) if d.isElemPointerType { *(*unsafe.Pointer)(ep) = nil // initialize elem pointer + } else if d.isElemSliceType { + *(*sliceHeader)(ep) = sliceHeader{ + data: newArray(d.elemType, 0), + len: 0, + cap: 0, + } } if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { return err From e647dafb4138dffc29101cbd4ea83be8100bb9ea Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 30 Apr 2021 22:55:08 +0900 Subject: [PATCH 4/7] Fix decoding of slice with unmarshal json type --- decode_slice.go | 80 +++++++++++++++----------- decode_test.go | 150 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+), 35 deletions(-) diff --git a/decode_slice.go b/decode_slice.go index aa67ef2..98f24d5 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -7,14 +7,14 @@ import ( ) type sliceDecoder struct { - elemType *rtype - isElemPointerType bool - isElemSliceType bool - valueDecoder decoder - size uintptr - arrayPool sync.Pool - structName string - fieldName string + elemType *rtype + isElemPointerType bool + isElemUnmarshalJSONType bool + valueDecoder decoder + size uintptr + arrayPool sync.Pool + structName string + fieldName string } // If use reflect.SliceHeader, data type is uintptr. @@ -32,11 +32,11 @@ const ( func newSliceDecoder(dec decoder, elemType *rtype, size uintptr, structName, fieldName string) *sliceDecoder { return &sliceDecoder{ - valueDecoder: dec, - elemType: elemType, - isElemPointerType: elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Map, - isElemSliceType: elemType.Kind() == reflect.Slice, - size: size, + valueDecoder: dec, + elemType: elemType, + isElemPointerType: elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Map, + isElemUnmarshalJSONType: rtype_ptrTo(elemType).Implements(unmarshalJSONType), + size: size, arrayPool: sync.Pool{ New: func() interface{} { return &sliceHeader{ @@ -119,17 +119,19 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er copySlice(d.elemType, dst, src) } ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) - if d.isElemPointerType { - *(*unsafe.Pointer)(ep) = nil // initialize elem pointer - } else if d.isElemSliceType { - *(*sliceHeader)(ep) = sliceHeader{ - data: newArray(d.elemType, 0), - len: 0, - cap: 0, + if d.isElemUnmarshalJSONType { + receiver := unsafe_New(d.elemType) + if err := d.valueDecoder.decodeStream(s, depth, receiver); err != nil { + return err + } + *(*unsafe.Pointer)(ep) = **(**unsafe.Pointer)(unsafe.Pointer(&receiver)) + } else { + if d.isElemPointerType { + *(*unsafe.Pointer)(ep) = nil // initialize elem pointer + } + if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { + return err } - } - if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { - return err } s.skipWhiteSpace() RETRY: @@ -239,20 +241,28 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) copySlice(d.elemType, dst, src) } ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) - if d.isElemPointerType { - *(*unsafe.Pointer)(ep) = nil // initialize elem pointer - } else if d.isElemSliceType { - *(*sliceHeader)(ep) = sliceHeader{ - data: newArray(d.elemType, 0), - len: 0, - cap: 0, + if d.isElemUnmarshalJSONType { + receiver := unsafe_New(d.elemType) + if d.elemType.Kind() == reflect.Slice { + *(*unsafe.Pointer)(ep) = receiver + } else { + *(*unsafe.Pointer)(ep) = *(*unsafe.Pointer)(receiver) } + c, err := d.valueDecoder.decode(buf, cursor, depth, ep) + if err != nil { + return 0, err + } + cursor = c + } else { + if d.isElemPointerType { + *(*unsafe.Pointer)(ep) = nil // initialize elem pointer + } + c, err := d.valueDecoder.decode(buf, cursor, depth, ep) + if err != nil { + return 0, err + } + cursor = c } - c, err := d.valueDecoder.decode(buf, cursor, depth, ep) - if err != nil { - return 0, err - } - cursor = c cursor = skipWhiteSpace(buf, cursor) switch buf[cursor] { case ']': diff --git a/decode_test.go b/decode_test.go index a7a7bec..3073553 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2996,6 +2996,156 @@ func TestDecodeMultipleUnmarshal(t *testing.T) { }) } +type intUnmarshaler int + +func (u *intUnmarshaler) UnmarshalJSON(b []byte) error { + if *u != 0 { + return fmt.Errorf("failed to decode of slice with int unmarshaler") + } + *u = 10 + return nil +} + +type arrayUnmarshaler [5]int + +func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error { + if (*u)[0] != 0 { + return fmt.Errorf("failed to decode of slice with array unmarshaler") + } + (*u)[0] = 10 + return nil +} + +type mapUnmarshaler map[string]int + +func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error { + if len(*u) != 0 { + return fmt.Errorf("failed to decode of slice with map unmarshaler") + } + *u = map[string]int{"a": 10} + return nil +} + +type structUnmarshaler struct { + A int +} + +func (u *structUnmarshaler) UnmarshalJSON(b []byte) error { + if u.A != 0 { + return fmt.Errorf("failed to decode of slice with struct unmarshaler") + } + u.A = 10 + return nil +} + +func TestSliceElemUnmarshaler(t *testing.T) { + t.Run("int", func(t *testing.T) { + var v []intUnmarshaler + if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with int unmarshaler: %v", v) + } + if v[0] != 10 { + t.Fatalf("failed to decode of slice with int unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with int unmarshaler: %v", v) + } + if v[0] != 10 { + t.Fatalf("failed to decode of slice with int unmarshaler: %v", v) + } + }) + t.Run("slice", func(t *testing.T) { + var v []json.RawMessage + if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v) + } + if len(v[0]) != 1 { + t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v) + } + if len(v[0]) != 1 { + t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v) + } + }) + t.Run("array", func(t *testing.T) { + var v []arrayUnmarshaler + if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with array unmarshaler: %v", v) + } + if v[0][0] != 10 { + t.Fatalf("failed to decode of slice with array unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with array unmarshaler: %v", v) + } + if v[0][0] != 10 { + t.Fatalf("failed to decode of slice with array unmarshaler: %v", v) + } + }) + t.Run("map", func(t *testing.T) { + var v []mapUnmarshaler + if err := json.Unmarshal([]byte(`[{"a":1},{"b":2},{"c":3},{"d":4},{"e":5}]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with map unmarshaler: %v", v) + } + if v[0]["a"] != 10 { + t.Fatalf("failed to decode of slice with map unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with map unmarshaler: %v", v) + } + if v[0]["a"] != 10 { + t.Fatalf("failed to decode of slice with map unmarshaler: %v", v) + } + }) + t.Run("struct", func(t *testing.T) { + var v []structUnmarshaler + if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v) + } + if v[0].A != 10 { + t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v) + } + if v[0].A != 10 { + t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v) + } + }) +} + func TestInvalidTopLevelValue(t *testing.T) { t.Run("invalid end of buffer", func(t *testing.T) { var v struct{} From 6444a1b0579b3710ee015674edf8b3bcb3f97d3f Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 30 Apr 2021 23:49:04 +0900 Subject: [PATCH 5/7] Fix checkptr error --- decode_slice.go | 44 ++++++++++++++++++-------------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/decode_slice.go b/decode_slice.go index 98f24d5..f8ab536 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -121,17 +121,16 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) if d.isElemUnmarshalJSONType { receiver := unsafe_New(d.elemType) - if err := d.valueDecoder.decodeStream(s, depth, receiver); err != nil { - return err - } - *(*unsafe.Pointer)(ep) = **(**unsafe.Pointer)(unsafe.Pointer(&receiver)) - } else { - if d.isElemPointerType { - *(*unsafe.Pointer)(ep) = nil // initialize elem pointer - } - if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { - return err + if d.elemType.Kind() == reflect.Slice { + **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = receiver + } else { + **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = **(**unsafe.Pointer)(unsafe.Pointer(&receiver)) } + } else if d.isElemPointerType { + **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer + } + if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { + return err } s.skipWhiteSpace() RETRY: @@ -244,25 +243,18 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) if d.isElemUnmarshalJSONType { receiver := unsafe_New(d.elemType) if d.elemType.Kind() == reflect.Slice { - *(*unsafe.Pointer)(ep) = receiver + **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = receiver } else { - *(*unsafe.Pointer)(ep) = *(*unsafe.Pointer)(receiver) + **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = **(**unsafe.Pointer)(unsafe.Pointer(&receiver)) } - c, err := d.valueDecoder.decode(buf, cursor, depth, ep) - if err != nil { - return 0, err - } - cursor = c - } else { - if d.isElemPointerType { - *(*unsafe.Pointer)(ep) = nil // initialize elem pointer - } - c, err := d.valueDecoder.decode(buf, cursor, depth, ep) - if err != nil { - return 0, err - } - cursor = c + } else if d.isElemPointerType { + **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer } + c, err := d.valueDecoder.decode(buf, cursor, depth, ep) + if err != nil { + return 0, err + } + cursor = c cursor = skipWhiteSpace(buf, cursor) switch buf[cursor] { case ']': From 75b72584a5e57085076402863d141e5099daa9ae Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 1 May 2021 14:53:48 +0900 Subject: [PATCH 6/7] Use typedmemmove for copying element of slice --- decode_slice.go | 22 ++++++++++------------ decode_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/decode_slice.go b/decode_slice.go index f8ab536..da8350c 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -67,6 +67,12 @@ func copySlice(elemType *rtype, dst, src sliceHeader) int //go:linkname newArray reflect.unsafe_NewArray func newArray(*rtype, int) unsafe.Pointer +//go:linkname typedmemmovepartial reflect.typedmemmovepartial +func typedmemmovepartial(typ *rtype, dst, src unsafe.Pointer, off, size uintptr) + +//go:linkname typedmemmove reflect.typedmemmove +func typedmemmove(t *rtype, dst, src unsafe.Pointer) + func (d *sliceDecoder) errNumber(offset int64) *UnmarshalTypeError { return &UnmarshalTypeError{ Value: "number", @@ -120,12 +126,8 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er } ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) if d.isElemUnmarshalJSONType { - receiver := unsafe_New(d.elemType) - if d.elemType.Kind() == reflect.Slice { - **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = receiver - } else { - **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = **(**unsafe.Pointer)(unsafe.Pointer(&receiver)) - } + // assign new element to the slice + typedmemmove(d.elemType, ep, unsafe_New(d.elemType)) } else if d.isElemPointerType { **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer } @@ -241,12 +243,8 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) } ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) if d.isElemUnmarshalJSONType { - receiver := unsafe_New(d.elemType) - if d.elemType.Kind() == reflect.Slice { - **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = receiver - } else { - **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = **(**unsafe.Pointer)(unsafe.Pointer(&receiver)) - } + // assign new element to the slice + typedmemmove(d.elemType, ep, unsafe_New(d.elemType)) } else if d.isElemPointerType { **(**unsafe.Pointer)(unsafe.Pointer(&ep)) = nil // initialize elem pointer } diff --git a/decode_test.go b/decode_test.go index 3073553..4fbf1ec 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2996,6 +2996,53 @@ func TestDecodeMultipleUnmarshal(t *testing.T) { }) } +func TestMultipleDecodeWithRawMessage(t *testing.T) { + original := []byte(`{ + "Body": { + "List": [ + { + "Returns": [ + { + "Value": "10", + "nodeType": "Literal" + } + ], + "nodeKind": "Return", + "nodeType": "Statement" + } + ], + "nodeKind": "Block", + "nodeType": "Statement" + }, + "nodeType": "Function" + }`) + + var a map[string]json.RawMessage + if err := json.Unmarshal(original, &a); err != nil { + t.Fatal(err) + } + var b map[string]json.RawMessage + if err := json.Unmarshal(a["Body"], &b); err != nil { + t.Fatal(err) + } + var c []json.RawMessage + if err := json.Unmarshal(b["List"], &c); err != nil { + t.Fatal(err) + } + var d map[string]json.RawMessage + if err := json.Unmarshal(c[0], &d); err != nil { + t.Fatal(err) + } + var e []json.RawMessage + if err := json.Unmarshal(d["Returns"], &e); err != nil { + t.Fatal(err) + } + var f map[string]json.RawMessage + if err := json.Unmarshal(e[0], &f); err != nil { + t.Fatal(err) + } +} + type intUnmarshaler int func (u *intUnmarshaler) UnmarshalJSON(b []byte) error { From 9e882fb664f4f939787058a8d766f9c1772777ca Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 1 May 2021 14:56:49 +0900 Subject: [PATCH 7/7] Remove unused code --- decode_slice.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/decode_slice.go b/decode_slice.go index da8350c..d0be101 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -67,9 +67,6 @@ func copySlice(elemType *rtype, dst, src sliceHeader) int //go:linkname newArray reflect.unsafe_NewArray func newArray(*rtype, int) unsafe.Pointer -//go:linkname typedmemmovepartial reflect.typedmemmovepartial -func typedmemmovepartial(typ *rtype, dst, src unsafe.Pointer, off, size uintptr) - //go:linkname typedmemmove reflect.typedmemmove func typedmemmove(t *rtype, dst, src unsafe.Pointer)