diff --git a/decode_slice.go b/decode_slice.go index aa67ef2..98f24d5 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -7,14 +7,14 @@ import ( ) type sliceDecoder struct { - elemType *rtype - isElemPointerType bool - isElemSliceType bool - valueDecoder decoder - size uintptr - arrayPool sync.Pool - structName string - fieldName string + elemType *rtype + isElemPointerType bool + isElemUnmarshalJSONType bool + valueDecoder decoder + size uintptr + arrayPool sync.Pool + structName string + fieldName string } // If use reflect.SliceHeader, data type is uintptr. @@ -32,11 +32,11 @@ const ( func newSliceDecoder(dec decoder, elemType *rtype, size uintptr, structName, fieldName string) *sliceDecoder { return &sliceDecoder{ - valueDecoder: dec, - elemType: elemType, - isElemPointerType: elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Map, - isElemSliceType: elemType.Kind() == reflect.Slice, - size: size, + valueDecoder: dec, + elemType: elemType, + isElemPointerType: elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Map, + isElemUnmarshalJSONType: rtype_ptrTo(elemType).Implements(unmarshalJSONType), + size: size, arrayPool: sync.Pool{ New: func() interface{} { return &sliceHeader{ @@ -119,17 +119,19 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er copySlice(d.elemType, dst, src) } ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) - if d.isElemPointerType { - *(*unsafe.Pointer)(ep) = nil // initialize elem pointer - } else if d.isElemSliceType { - *(*sliceHeader)(ep) = sliceHeader{ - data: newArray(d.elemType, 0), - len: 0, - cap: 0, + if d.isElemUnmarshalJSONType { + receiver := unsafe_New(d.elemType) + if err := d.valueDecoder.decodeStream(s, depth, receiver); err != nil { + return err + } + *(*unsafe.Pointer)(ep) = **(**unsafe.Pointer)(unsafe.Pointer(&receiver)) + } else { + if d.isElemPointerType { + *(*unsafe.Pointer)(ep) = nil // initialize elem pointer + } + if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { + return err } - } - if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { - return err } s.skipWhiteSpace() RETRY: @@ -239,20 +241,28 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) copySlice(d.elemType, dst, src) } ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) - if d.isElemPointerType { - *(*unsafe.Pointer)(ep) = nil // initialize elem pointer - } else if d.isElemSliceType { - *(*sliceHeader)(ep) = sliceHeader{ - data: newArray(d.elemType, 0), - len: 0, - cap: 0, + if d.isElemUnmarshalJSONType { + receiver := unsafe_New(d.elemType) + if d.elemType.Kind() == reflect.Slice { + *(*unsafe.Pointer)(ep) = receiver + } else { + *(*unsafe.Pointer)(ep) = *(*unsafe.Pointer)(receiver) } + c, err := d.valueDecoder.decode(buf, cursor, depth, ep) + if err != nil { + return 0, err + } + cursor = c + } else { + if d.isElemPointerType { + *(*unsafe.Pointer)(ep) = nil // initialize elem pointer + } + c, err := d.valueDecoder.decode(buf, cursor, depth, ep) + if err != nil { + return 0, err + } + cursor = c } - c, err := d.valueDecoder.decode(buf, cursor, depth, ep) - if err != nil { - return 0, err - } - cursor = c cursor = skipWhiteSpace(buf, cursor) switch buf[cursor] { case ']': diff --git a/decode_test.go b/decode_test.go index a7a7bec..3073553 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2996,6 +2996,156 @@ func TestDecodeMultipleUnmarshal(t *testing.T) { }) } +type intUnmarshaler int + +func (u *intUnmarshaler) UnmarshalJSON(b []byte) error { + if *u != 0 { + return fmt.Errorf("failed to decode of slice with int unmarshaler") + } + *u = 10 + return nil +} + +type arrayUnmarshaler [5]int + +func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error { + if (*u)[0] != 0 { + return fmt.Errorf("failed to decode of slice with array unmarshaler") + } + (*u)[0] = 10 + return nil +} + +type mapUnmarshaler map[string]int + +func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error { + if len(*u) != 0 { + return fmt.Errorf("failed to decode of slice with map unmarshaler") + } + *u = map[string]int{"a": 10} + return nil +} + +type structUnmarshaler struct { + A int +} + +func (u *structUnmarshaler) UnmarshalJSON(b []byte) error { + if u.A != 0 { + return fmt.Errorf("failed to decode of slice with struct unmarshaler") + } + u.A = 10 + return nil +} + +func TestSliceElemUnmarshaler(t *testing.T) { + t.Run("int", func(t *testing.T) { + var v []intUnmarshaler + if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with int unmarshaler: %v", v) + } + if v[0] != 10 { + t.Fatalf("failed to decode of slice with int unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with int unmarshaler: %v", v) + } + if v[0] != 10 { + t.Fatalf("failed to decode of slice with int unmarshaler: %v", v) + } + }) + t.Run("slice", func(t *testing.T) { + var v []json.RawMessage + if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v) + } + if len(v[0]) != 1 { + t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v) + } + if len(v[0]) != 1 { + t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v) + } + }) + t.Run("array", func(t *testing.T) { + var v []arrayUnmarshaler + if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with array unmarshaler: %v", v) + } + if v[0][0] != 10 { + t.Fatalf("failed to decode of slice with array unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with array unmarshaler: %v", v) + } + if v[0][0] != 10 { + t.Fatalf("failed to decode of slice with array unmarshaler: %v", v) + } + }) + t.Run("map", func(t *testing.T) { + var v []mapUnmarshaler + if err := json.Unmarshal([]byte(`[{"a":1},{"b":2},{"c":3},{"d":4},{"e":5}]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with map unmarshaler: %v", v) + } + if v[0]["a"] != 10 { + t.Fatalf("failed to decode of slice with map unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with map unmarshaler: %v", v) + } + if v[0]["a"] != 10 { + t.Fatalf("failed to decode of slice with map unmarshaler: %v", v) + } + }) + t.Run("struct", func(t *testing.T) { + var v []structUnmarshaler + if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 5 { + t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v) + } + if v[0].A != 10 { + t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v) + } + if err := json.Unmarshal([]byte(`[6]`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 1 { + t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v) + } + if v[0].A != 10 { + t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v) + } + }) +} + func TestInvalidTopLevelValue(t *testing.T) { t.Run("invalid end of buffer", func(t *testing.T) { var v struct{}