Fix null validation

This commit is contained in:
Masaaki Goshima 2020-08-21 11:51:33 +09:00
parent 78fe23fc64
commit 3e03bdc53f
2 changed files with 33 additions and 17 deletions

View File

@ -1294,7 +1294,7 @@ func TestNilMarshal(t *testing.T) {
{v: struct{ M string }{"gopher"}, want: `{"M":"gopher"}`},
{v: struct{ M json.Marshaler }{}, want: `{"M":null}`},
{v: struct{ M json.Marshaler }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`},
{v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":null}`},
{v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, // doesn't compatible with encoding/json
{v: struct{ M encoding.TextMarshaler }{}, want: `{"M":null}`},
{v: struct{ M encoding.TextMarshaler }{(*nilTextMarshaler)(nil)}, want: `{"M":"0zenil0"}`},
{v: struct{ M interface{} }{(*nilTextMarshaler)(nil)}, want: `{"M":null}`},

View File

@ -69,10 +69,16 @@ func (e *Encoder) run(code *opcode) error {
e.encodeBool(e.ptrToBool(code.ptr))
code = code.next
case opBytes:
ptr := code.ptr
header := (*reflect.SliceHeader)(unsafe.Pointer(ptr))
if ptr == 0 || header.Data == 0 {
e.encodeNull()
} else {
s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr))
e.encodeByte('"')
e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s)))
e.encodeByte('"')
}
code = code.next
case opInterface:
ifaceCode := code.toInterfaceCode()
@ -174,12 +180,12 @@ func (e *Encoder) run(code *opcode) error {
case opSliceHead:
p := code.ptr
headerCode := code.toSliceHeaderCode()
if p == 0 {
header := (*reflect.SliceHeader)(unsafe.Pointer(p))
if p == 0 || header.Data == 0 {
e.encodeNull()
code = headerCode.end.next
} else {
e.encodeByte('[')
header := (*reflect.SliceHeader)(unsafe.Pointer(p))
headerCode.elem.set(header)
if header.Len > 0 {
code = code.next
@ -1029,14 +1035,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ,
ptr: unsafe.Pointer(ptr),
}))
marshaler, ok := v.(Marshaler)
if !ok {
// invalid marshaler
rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
b, err := marshaler.MarshalJSON()
b, err := rv.Interface().(Marshaler).MarshalJSON()
if err != nil {
return &MarshalerError{
Type: rtype2type(code.typ),
@ -1071,7 +1076,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ,
ptr: unsafe.Pointer(ptr),
}))
b, err := v.(Marshaler).MarshalJSON()
rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
b, err := rv.Interface().(Marshaler).MarshalJSON()
if err != nil {
return &MarshalerError{
Type: rtype2type(code.typ),
@ -1108,14 +1119,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ,
ptr: unsafe.Pointer(ptr),
}))
marshaler, ok := v.(encoding.TextMarshaler)
if !ok {
// invalid marshaler
rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
bytes, err := marshaler.MarshalText()
bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return &MarshalerError{
Type: rtype2type(code.typ),
@ -1140,7 +1150,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ,
ptr: unsafe.Pointer(ptr),
}))
bytes, err := v.(encoding.TextMarshaler).MarshalText()
rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return &MarshalerError{
Type: rtype2type(code.typ),