From 3e03bdc53f1f7060a8081a4711879a4b5d0d3e58 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 21 Aug 2020 11:51:33 +0900 Subject: [PATCH] Fix null validation --- encode_test.go | 2 +- encode_vm.go | 48 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/encode_test.go b/encode_test.go index fd85bb2..10bf27e 100644 --- a/encode_test.go +++ b/encode_test.go @@ -1294,7 +1294,7 @@ func TestNilMarshal(t *testing.T) { {v: struct{ M string }{"gopher"}, want: `{"M":"gopher"}`}, {v: struct{ M json.Marshaler }{}, want: `{"M":null}`}, {v: struct{ M json.Marshaler }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, - {v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":null}`}, + {v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, // doesn't compatible with encoding/json {v: struct{ M encoding.TextMarshaler }{}, want: `{"M":null}`}, {v: struct{ M encoding.TextMarshaler }{(*nilTextMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, {v: struct{ M interface{} }{(*nilTextMarshaler)(nil)}, want: `{"M":null}`}, diff --git a/encode_vm.go b/encode_vm.go index a18cf5d..edd3edc 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -69,10 +69,16 @@ func (e *Encoder) run(code *opcode) error { e.encodeBool(e.ptrToBool(code.ptr)) code = code.next case opBytes: - s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr)) - e.encodeByte('"') - e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s))) - e.encodeByte('"') + ptr := code.ptr + header := (*reflect.SliceHeader)(unsafe.Pointer(ptr)) + if ptr == 0 || header.Data == 0 { + e.encodeNull() + } else { + s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr)) + e.encodeByte('"') + e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s))) + e.encodeByte('"') + } code = code.next case opInterface: ifaceCode := code.toInterfaceCode() @@ -174,12 +180,12 @@ func (e *Encoder) run(code *opcode) error { case opSliceHead: p := code.ptr headerCode := code.toSliceHeaderCode() - if p == 0 { + header := (*reflect.SliceHeader)(unsafe.Pointer(p)) + if p == 0 || header.Data == 0 { e.encodeNull() code = headerCode.end.next } else { e.encodeByte('[') - header := (*reflect.SliceHeader)(unsafe.Pointer(p)) headerCode.elem.set(header) if header.Len > 0 { code = code.next @@ -1029,14 +1035,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - marshaler, ok := v.(Marshaler) - if !ok { - // invalid marshaler + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { e.encodeNull() code = field.end break } - b, err := marshaler.MarshalJSON() + b, err := rv.Interface().(Marshaler).MarshalJSON() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1071,7 +1076,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - b, err := v.(Marshaler).MarshalJSON() + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { + e.encodeNull() + code = field.end + break + } + b, err := rv.Interface().(Marshaler).MarshalJSON() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1108,14 +1119,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - marshaler, ok := v.(encoding.TextMarshaler) - if !ok { - // invalid marshaler + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { e.encodeNull() code = field.end break } - bytes, err := marshaler.MarshalText() + bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1140,7 +1150,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - bytes, err := v.(encoding.TextMarshaler).MarshalText() + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { + e.encodeNull() + code = field.end + break + } + bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ),