Fix null validation

This commit is contained in:
Masaaki Goshima 2020-08-21 11:51:33 +09:00
parent 78fe23fc64
commit 3e03bdc53f
2 changed files with 33 additions and 17 deletions

View File

@ -1294,7 +1294,7 @@ func TestNilMarshal(t *testing.T) {
{v: struct{ M string }{"gopher"}, want: `{"M":"gopher"}`}, {v: struct{ M string }{"gopher"}, want: `{"M":"gopher"}`},
{v: struct{ M json.Marshaler }{}, want: `{"M":null}`}, {v: struct{ M json.Marshaler }{}, want: `{"M":null}`},
{v: struct{ M json.Marshaler }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, {v: struct{ M json.Marshaler }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`},
{v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":null}`}, {v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, // doesn't compatible with encoding/json
{v: struct{ M encoding.TextMarshaler }{}, want: `{"M":null}`}, {v: struct{ M encoding.TextMarshaler }{}, want: `{"M":null}`},
{v: struct{ M encoding.TextMarshaler }{(*nilTextMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, {v: struct{ M encoding.TextMarshaler }{(*nilTextMarshaler)(nil)}, want: `{"M":"0zenil0"}`},
{v: struct{ M interface{} }{(*nilTextMarshaler)(nil)}, want: `{"M":null}`}, {v: struct{ M interface{} }{(*nilTextMarshaler)(nil)}, want: `{"M":null}`},

View File

@ -69,10 +69,16 @@ func (e *Encoder) run(code *opcode) error {
e.encodeBool(e.ptrToBool(code.ptr)) e.encodeBool(e.ptrToBool(code.ptr))
code = code.next code = code.next
case opBytes: case opBytes:
s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr)) ptr := code.ptr
e.encodeByte('"') header := (*reflect.SliceHeader)(unsafe.Pointer(ptr))
e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s))) if ptr == 0 || header.Data == 0 {
e.encodeByte('"') e.encodeNull()
} else {
s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr))
e.encodeByte('"')
e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s)))
e.encodeByte('"')
}
code = code.next code = code.next
case opInterface: case opInterface:
ifaceCode := code.toInterfaceCode() ifaceCode := code.toInterfaceCode()
@ -174,12 +180,12 @@ func (e *Encoder) run(code *opcode) error {
case opSliceHead: case opSliceHead:
p := code.ptr p := code.ptr
headerCode := code.toSliceHeaderCode() headerCode := code.toSliceHeaderCode()
if p == 0 { header := (*reflect.SliceHeader)(unsafe.Pointer(p))
if p == 0 || header.Data == 0 {
e.encodeNull() e.encodeNull()
code = headerCode.end.next code = headerCode.end.next
} else { } else {
e.encodeByte('[') e.encodeByte('[')
header := (*reflect.SliceHeader)(unsafe.Pointer(p))
headerCode.elem.set(header) headerCode.elem.set(header)
if header.Len > 0 { if header.Len > 0 {
code = code.next code = code.next
@ -1029,14 +1035,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
marshaler, ok := v.(Marshaler) rv := reflect.ValueOf(v)
if !ok { if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
// invalid marshaler
e.encodeNull() e.encodeNull()
code = field.end code = field.end
break break
} }
b, err := marshaler.MarshalJSON() b, err := rv.Interface().(Marshaler).MarshalJSON()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
@ -1071,7 +1076,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
b, err := v.(Marshaler).MarshalJSON() rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
b, err := rv.Interface().(Marshaler).MarshalJSON()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
@ -1108,14 +1119,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
marshaler, ok := v.(encoding.TextMarshaler) rv := reflect.ValueOf(v)
if !ok { if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
// invalid marshaler
e.encodeNull() e.encodeNull()
code = field.end code = field.end
break break
} }
bytes, err := marshaler.MarshalText() bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
@ -1140,7 +1150,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
bytes, err := v.(encoding.TextMarshaler).MarshalText() rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),