Add validation for null value

This commit is contained in:
Masaaki Goshima 2020-08-21 11:07:55 +09:00
parent 991c8c411b
commit 95bfc8c549
2 changed files with 23 additions and 3 deletions

View File

@ -148,6 +148,10 @@ func (e *Encoder) encodeForMarshal(v interface{}) ([]byte, error) {
} }
func (e *Encoder) encode(v interface{}) error { func (e *Encoder) encode(v interface{}) error {
if v == nil {
e.encodeNull()
return nil
}
header := (*interfaceHeader)(unsafe.Pointer(&v)) header := (*interfaceHeader)(unsafe.Pointer(&v))
typ := header.typ typ := header.typ

View File

@ -148,7 +148,9 @@ func (e *Encoder) run(code *opcode) error {
ptr := code.ptr ptr := code.ptr
isPtr := code.typ.Kind() == reflect.Ptr isPtr := code.typ.Kind() == reflect.Ptr
p := unsafe.Pointer(ptr) p := unsafe.Pointer(ptr)
if isPtr && *(*unsafe.Pointer)(p) == nil { if p == nil {
e.encodeNull()
} else if isPtr && *(*unsafe.Pointer)(p) == nil {
e.encodeBytes([]byte{'"', '"'}) e.encodeBytes([]byte{'"', '"'})
} else { } else {
if isPtr && code.typ.Elem().Implements(marshalTextType) { if isPtr && code.typ.Elem().Implements(marshalTextType) {
@ -1027,7 +1029,14 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
b, err := v.(Marshaler).MarshalJSON() marshaler, ok := v.(Marshaler)
if !ok {
// invalid marshaler
e.encodeNull()
code = field.end
break
}
b, err := marshaler.MarshalJSON()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
@ -1099,7 +1108,14 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
bytes, err := v.(encoding.TextMarshaler).MarshalText() marshaler, ok := v.(encoding.TextMarshaler)
if !ok {
// invalid marshaler
e.encodeNull()
code = field.end
break
}
bytes, err := marshaler.MarshalText()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),