Fix decoding of slice with unmarshal json type

This commit is contained in:
Masaaki Goshima 2021-04-30 22:55:08 +09:00
parent 17392ab716
commit e647dafb41
2 changed files with 195 additions and 35 deletions

View File

@ -9,7 +9,7 @@ import (
type sliceDecoder struct { type sliceDecoder struct {
elemType *rtype elemType *rtype
isElemPointerType bool isElemPointerType bool
isElemSliceType bool isElemUnmarshalJSONType bool
valueDecoder decoder valueDecoder decoder
size uintptr size uintptr
arrayPool sync.Pool arrayPool sync.Pool
@ -35,7 +35,7 @@ func newSliceDecoder(dec decoder, elemType *rtype, size uintptr, structName, fie
valueDecoder: dec, valueDecoder: dec,
elemType: elemType, elemType: elemType,
isElemPointerType: elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Map, isElemPointerType: elemType.Kind() == reflect.Ptr || elemType.Kind() == reflect.Map,
isElemSliceType: elemType.Kind() == reflect.Slice, isElemUnmarshalJSONType: rtype_ptrTo(elemType).Implements(unmarshalJSONType),
size: size, size: size,
arrayPool: sync.Pool{ arrayPool: sync.Pool{
New: func() interface{} { New: func() interface{} {
@ -119,18 +119,20 @@ func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) er
copySlice(d.elemType, dst, src) copySlice(d.elemType, dst, src)
} }
ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size)
if d.isElemUnmarshalJSONType {
receiver := unsafe_New(d.elemType)
if err := d.valueDecoder.decodeStream(s, depth, receiver); err != nil {
return err
}
*(*unsafe.Pointer)(ep) = **(**unsafe.Pointer)(unsafe.Pointer(&receiver))
} else {
if d.isElemPointerType { if d.isElemPointerType {
*(*unsafe.Pointer)(ep) = nil // initialize elem pointer *(*unsafe.Pointer)(ep) = nil // initialize elem pointer
} else if d.isElemSliceType {
*(*sliceHeader)(ep) = sliceHeader{
data: newArray(d.elemType, 0),
len: 0,
cap: 0,
}
} }
if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil { if err := d.valueDecoder.decodeStream(s, depth, ep); err != nil {
return err return err
} }
}
s.skipWhiteSpace() s.skipWhiteSpace()
RETRY: RETRY:
switch s.char() { switch s.char() {
@ -239,20 +241,28 @@ func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
copySlice(d.elemType, dst, src) copySlice(d.elemType, dst, src)
} }
ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size) ep := unsafe.Pointer(uintptr(data) + uintptr(idx)*d.size)
if d.isElemPointerType { if d.isElemUnmarshalJSONType {
*(*unsafe.Pointer)(ep) = nil // initialize elem pointer receiver := unsafe_New(d.elemType)
} else if d.isElemSliceType { if d.elemType.Kind() == reflect.Slice {
*(*sliceHeader)(ep) = sliceHeader{ *(*unsafe.Pointer)(ep) = receiver
data: newArray(d.elemType, 0), } else {
len: 0, *(*unsafe.Pointer)(ep) = *(*unsafe.Pointer)(receiver)
cap: 0,
}
} }
c, err := d.valueDecoder.decode(buf, cursor, depth, ep) c, err := d.valueDecoder.decode(buf, cursor, depth, ep)
if err != nil { if err != nil {
return 0, err return 0, err
} }
cursor = c cursor = c
} else {
if d.isElemPointerType {
*(*unsafe.Pointer)(ep) = nil // initialize elem pointer
}
c, err := d.valueDecoder.decode(buf, cursor, depth, ep)
if err != nil {
return 0, err
}
cursor = c
}
cursor = skipWhiteSpace(buf, cursor) cursor = skipWhiteSpace(buf, cursor)
switch buf[cursor] { switch buf[cursor] {
case ']': case ']':

View File

@ -2996,6 +2996,156 @@ func TestDecodeMultipleUnmarshal(t *testing.T) {
}) })
} }
type intUnmarshaler int
func (u *intUnmarshaler) UnmarshalJSON(b []byte) error {
if *u != 0 {
return fmt.Errorf("failed to decode of slice with int unmarshaler")
}
*u = 10
return nil
}
type arrayUnmarshaler [5]int
func (u *arrayUnmarshaler) UnmarshalJSON(b []byte) error {
if (*u)[0] != 0 {
return fmt.Errorf("failed to decode of slice with array unmarshaler")
}
(*u)[0] = 10
return nil
}
type mapUnmarshaler map[string]int
func (u *mapUnmarshaler) UnmarshalJSON(b []byte) error {
if len(*u) != 0 {
return fmt.Errorf("failed to decode of slice with map unmarshaler")
}
*u = map[string]int{"a": 10}
return nil
}
type structUnmarshaler struct {
A int
}
func (u *structUnmarshaler) UnmarshalJSON(b []byte) error {
if u.A != 0 {
return fmt.Errorf("failed to decode of slice with struct unmarshaler")
}
u.A = 10
return nil
}
func TestSliceElemUnmarshaler(t *testing.T) {
t.Run("int", func(t *testing.T) {
var v []intUnmarshaler
if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 5 {
t.Fatalf("failed to decode of slice with int unmarshaler: %v", v)
}
if v[0] != 10 {
t.Fatalf("failed to decode of slice with int unmarshaler: %v", v)
}
if err := json.Unmarshal([]byte(`[6]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 1 {
t.Fatalf("failed to decode of slice with int unmarshaler: %v", v)
}
if v[0] != 10 {
t.Fatalf("failed to decode of slice with int unmarshaler: %v", v)
}
})
t.Run("slice", func(t *testing.T) {
var v []json.RawMessage
if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 5 {
t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v)
}
if len(v[0]) != 1 {
t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v)
}
if err := json.Unmarshal([]byte(`[6]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 1 {
t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v)
}
if len(v[0]) != 1 {
t.Fatalf("failed to decode of slice with slice unmarshaler: %v", v)
}
})
t.Run("array", func(t *testing.T) {
var v []arrayUnmarshaler
if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 5 {
t.Fatalf("failed to decode of slice with array unmarshaler: %v", v)
}
if v[0][0] != 10 {
t.Fatalf("failed to decode of slice with array unmarshaler: %v", v)
}
if err := json.Unmarshal([]byte(`[6]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 1 {
t.Fatalf("failed to decode of slice with array unmarshaler: %v", v)
}
if v[0][0] != 10 {
t.Fatalf("failed to decode of slice with array unmarshaler: %v", v)
}
})
t.Run("map", func(t *testing.T) {
var v []mapUnmarshaler
if err := json.Unmarshal([]byte(`[{"a":1},{"b":2},{"c":3},{"d":4},{"e":5}]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 5 {
t.Fatalf("failed to decode of slice with map unmarshaler: %v", v)
}
if v[0]["a"] != 10 {
t.Fatalf("failed to decode of slice with map unmarshaler: %v", v)
}
if err := json.Unmarshal([]byte(`[6]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 1 {
t.Fatalf("failed to decode of slice with map unmarshaler: %v", v)
}
if v[0]["a"] != 10 {
t.Fatalf("failed to decode of slice with map unmarshaler: %v", v)
}
})
t.Run("struct", func(t *testing.T) {
var v []structUnmarshaler
if err := json.Unmarshal([]byte(`[1,2,3,4,5]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 5 {
t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v)
}
if v[0].A != 10 {
t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v)
}
if err := json.Unmarshal([]byte(`[6]`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 1 {
t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v)
}
if v[0].A != 10 {
t.Fatalf("failed to decode of slice with struct unmarshaler: %v", v)
}
})
}
func TestInvalidTopLevelValue(t *testing.T) { func TestInvalidTopLevelValue(t *testing.T) {
t.Run("invalid end of buffer", func(t *testing.T) { t.Run("invalid end of buffer", func(t *testing.T) {
var v struct{} var v struct{}