Fix Marshaler

This commit is contained in:
Masaaki Goshima 2021-03-11 13:09:49 +09:00
parent 55cfca3b0a
commit 4f372bebd0
5 changed files with 92 additions and 90 deletions

View File

@ -5,6 +5,9 @@ import (
)
func compact(dst *bytes.Buffer, src []byte, escape bool) error {
if len(src) == 0 {
return errUnexpectedEndOfJSON("", 0)
}
length := len(src)
for cursor := 0; cursor < length; cursor++ {
c := src[cursor]

View File

@ -4,7 +4,6 @@ import (
"bytes"
"encoding"
"encoding/base64"
"fmt"
"io"
"math"
"reflect"
@ -367,6 +366,17 @@ func appendIndent(ctx *encodeRuntimeContext, b []byte, indent int) []byte {
return append(b, bytes.Repeat(ctx.indentStr, ctx.baseIndent+indent)...)
}
func encodeIsNilForMarshaler(v interface{}) bool {
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Interface, reflect.Map, reflect.Ptr:
return rv.IsNil()
case reflect.Slice:
return rv.IsNil() || rv.Len() == 0
}
return false
}
func encodeMarshalJSON(code *opcode, b []byte, v interface{}, escape bool) ([]byte, error) {
rv := reflect.ValueOf(v) // convert by dynamic interface type
if code.addrForMarshaler {
@ -387,16 +397,10 @@ func encodeMarshalJSON(code *opcode, b []byte, v interface{}, escape bool) ([]by
if err != nil {
return nil, &MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
if len(bb) == 0 {
return nil, errUnexpectedEndOfJSON(
fmt.Sprintf("error calling MarshalJSON for type %s", reflect.TypeOf(v)),
0,
)
}
buf := bytes.NewBuffer(b)
//TODO: we should validate buffer with `compact`
if err := compact(buf, bb, escape); err != nil {
return nil, err
return nil, &MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
return buf.Bytes(), nil
}
@ -421,24 +425,18 @@ func encodeMarshalJSONIndent(ctx *encodeRuntimeContext, code *opcode, b []byte,
if err != nil {
return nil, &MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
if len(bb) == 0 {
return nil, errUnexpectedEndOfJSON(
fmt.Sprintf("error calling MarshalJSON for type %s", reflect.TypeOf(v)),
0,
)
}
var compactBuf bytes.Buffer
if err := compact(&compactBuf, bb, escape); err != nil {
return nil, err
return nil, &MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
var indentBuf bytes.Buffer
if err := encodeWithIndent(
&indentBuf,
compactBuf.Bytes(),
string(ctx.prefix)+strings.Repeat(string(ctx.indentStr), ctx.baseIndent+indent),
string(ctx.prefix)+strings.Repeat(string(ctx.indentStr), ctx.baseIndent+indent+1),
string(ctx.indentStr),
); err != nil {
return nil, err
return nil, &MarshalerError{Type: reflect.TypeOf(v), Err: err}
}
return append(b, indentBuf.Bytes()...), nil
}

View File

@ -56,21 +56,24 @@ func encodeCompileToGetCodeSetSlowPath(typeptr uintptr) (*opcodeSet, error) {
func encodeCompileHead(ctx *encodeCompileContext) (*opcode, error) {
typ := ctx.typ
switch {
case typ.Implements(marshalJSONType):
if typ.Kind() != reflect.Ptr || !typ.Elem().Implements(marshalJSONType) {
return encodeCompileMarshalJSON(ctx)
}
case typ.Implements(marshalTextType):
if typ.Kind() != reflect.Ptr || !typ.Elem().Implements(marshalTextType) {
return encodeCompileMarshalText(ctx)
}
case encodeImplementsMarshalJSON(typ):
return encodeCompileMarshalJSON(ctx)
case encodeImplementsMarshalText(typ):
return encodeCompileMarshalText(ctx)
}
isPtr := false
orgType := typ
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
isPtr = true
}
switch {
case encodeImplementsMarshalJSON(typ):
return encodeCompileMarshalJSON(ctx)
case encodeImplementsMarshalText(typ):
return encodeCompileMarshalText(ctx)
}
if typ.Kind() == reflect.Map {
if isPtr {
return encodeCompilePtr(ctx.withType(rtype_ptrTo(typ)))
@ -185,22 +188,42 @@ func encodeOptimizeStructEnd(c *opcode) {
}
}
func encodeImplementsMarshaler(typ *rtype) bool {
switch {
case typ.Implements(marshalJSONType):
return true
case typ.Implements(marshalTextType):
func encodeImplementsMarshalJSON(typ *rtype) bool {
if !typ.Implements(marshalJSONType) {
return false
}
if typ.Kind() != reflect.Ptr {
return true
}
// type kind is reflect.Ptr
if !typ.Elem().Implements(marshalJSONType) {
return true
}
// needs to dereference
return false
}
func encodeImplementsMarshalText(typ *rtype) bool {
if !typ.Implements(marshalTextType) {
return false
}
if typ.Kind() != reflect.Ptr {
return true
}
// type kind is reflect.Ptr
if !typ.Elem().Implements(marshalTextType) {
return true
}
// needs to dereference
return false
}
func encodeCompile(ctx *encodeCompileContext, isPtr bool) (*opcode, error) {
typ := ctx.typ
switch {
case typ.Implements(marshalJSONType) && (typ.Kind() != reflect.Ptr || !typ.Elem().Implements(marshalJSONType)):
case encodeImplementsMarshalJSON(typ):
return encodeCompileMarshalJSON(ctx)
case typ.Implements(marshalTextType) && (typ.Kind() != reflect.Ptr || !typ.Elem().Implements(marshalTextType)):
case encodeImplementsMarshalText(typ):
return encodeCompileMarshalText(ctx)
}
switch typ.Kind() {
@ -208,8 +231,11 @@ func encodeCompile(ctx *encodeCompileContext, isPtr bool) (*opcode, error) {
return encodeCompilePtr(ctx)
case reflect.Slice:
elem := typ.Elem()
if !encodeImplementsMarshaler(elem) && elem.Kind() == reflect.Uint8 {
return encodeCompileBytes(ctx)
if elem.Kind() == reflect.Uint8 {
p := rtype_ptrTo(elem)
if !p.Implements(marshalJSONType) && !p.Implements(marshalTextType) {
return encodeCompileBytes(ctx)
}
}
return encodeCompileSlice(ctx)
case reflect.Array:
@ -293,10 +319,10 @@ func encodeConvertPtrOp(code *opcode) opType {
func encodeCompileKey(ctx *encodeCompileContext) (*opcode, error) {
typ := ctx.typ
switch {
case rtype_ptrTo(typ).Implements(marshalJSONType):
return encodeCompileMarshalJSONPtr(ctx)
case rtype_ptrTo(typ).Implements(marshalTextType):
return encodeCompileMarshalTextPtr(ctx)
case encodeImplementsMarshalJSON(typ):
return encodeCompileMarshalJSON(ctx)
case encodeImplementsMarshalText(typ):
return encodeCompileMarshalText(ctx)
}
switch typ.Kind() {
case reflect.Ptr:
@ -343,24 +369,20 @@ func encodeCompilePtr(ctx *encodeCompileContext) (*opcode, error) {
func encodeCompileMarshalJSON(ctx *encodeCompileContext) (*opcode, error) {
code := newOpCode(ctx, opMarshalJSON)
ctx.incIndex()
return code, nil
}
func encodeCompileMarshalJSONPtr(ctx *encodeCompileContext) (*opcode, error) {
code := newOpCode(ctx.withType(rtype_ptrTo(ctx.typ)), opMarshalJSONPtr)
typ := ctx.typ
if !typ.Implements(marshalJSONType) && rtype_ptrTo(typ).Implements(marshalJSONType) {
code.addrForMarshaler = true
}
ctx.incIndex()
return code, nil
}
func encodeCompileMarshalText(ctx *encodeCompileContext) (*opcode, error) {
code := newOpCode(ctx, opMarshalText)
ctx.incIndex()
return code, nil
}
func encodeCompileMarshalTextPtr(ctx *encodeCompileContext) (*opcode, error) {
code := newOpCode(ctx.withType(rtype_ptrTo(ctx.typ)), opMarshalText)
typ := ctx.typ
if !typ.Implements(marshalTextType) && rtype_ptrTo(typ).Implements(marshalTextType) {
code.addrForMarshaler = true
}
ctx.incIndex()
return code, nil
}
@ -595,7 +617,7 @@ func encodeCompileSlice(ctx *encodeCompileContext) (*opcode, error) {
header := newSliceHeaderCode(ctx)
ctx.incIndex()
code, err := encodeCompile(ctx.withType(ctx.typ.Elem()).incIndent(), false)
code, err := encodeCompileSliceElem(ctx.withType(elem).incIndent())
if err != nil {
return nil, err
}
@ -619,6 +641,18 @@ func encodeCompileSlice(ctx *encodeCompileContext) (*opcode, error) {
return (*opcode)(unsafe.Pointer(header)), nil
}
func encodeCompileSliceElem(ctx *encodeCompileContext) (*opcode, error) {
typ := ctx.typ
switch {
case !typ.Implements(marshalJSONType) && rtype_ptrTo(typ).Implements(marshalJSONType):
return encodeCompileMarshalJSON(ctx)
case !typ.Implements(marshalTextType) && rtype_ptrTo(typ).Implements(marshalTextType):
return encodeCompileMarshalText(ctx)
default:
return encodeCompile(ctx, false)
}
}
func encodeCompileArray(ctx *encodeCompileContext) (*opcode, error) {
ctx.root = false
typ := ctx.typ
@ -1190,7 +1224,7 @@ func encodeCompileStruct(ctx *encodeCompileContext, isPtr bool) (*opcode, error)
ctx.incIndex()
nilcheck := true
var valueCode *opcode
isNilValue := fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Interface
isNilValue := fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Interface || fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Map
addrForMarshaler := false
if i == 0 && fieldNum == 1 && isPtr && !isNilValue && rtype_ptrTo(fieldType).Implements(marshalJSONType) && !fieldType.Implements(marshalJSONType) {
// *struct{ field T } => struct { field *T }
@ -1216,7 +1250,7 @@ func encodeCompileStruct(ctx *encodeCompileContext, isPtr bool) (*opcode, error)
nilcheck = false
indirect = false
disableIndirectConversion = true
} else if isPtr && !isNilValue && !fieldType.Implements(marshalJSONType) && rtype_ptrTo(fieldType).Implements(marshalJSONType) {
} else if isPtr && !fieldType.Implements(marshalJSONType) && rtype_ptrTo(fieldType).Implements(marshalJSONType) {
// *struct{ field T }
// func (*T) MarshalJSON() ([]byte, error)
code, err := encodeCompileMarshalJSON(ctx.withType(fieldType))
@ -1226,7 +1260,7 @@ func encodeCompileStruct(ctx *encodeCompileContext, isPtr bool) (*opcode, error)
addrForMarshaler = true
nilcheck = false
valueCode = code
} else if isPtr && !isNilValue && !fieldType.Implements(marshalTextType) && rtype_ptrTo(fieldType).Implements(marshalTextType) {
} else if isPtr && !fieldType.Implements(marshalTextType) && rtype_ptrTo(fieldType).Implements(marshalTextType) {
// *struct{ field T }
// func (*T) MarshalText() ([]byte, error)
code, err := encodeCompileMarshalText(ctx.withType(fieldType))
@ -1236,30 +1270,6 @@ func encodeCompileStruct(ctx *encodeCompileContext, isPtr bool) (*opcode, error)
addrForMarshaler = true
nilcheck = false
valueCode = code
} else if fieldType.Implements(marshalJSONType) && fieldType.Kind() != reflect.Ptr {
code, err := encodeCompileMarshalJSON(ctx.withType(fieldType))
if err != nil {
return nil, err
}
valueCode = code
} else if fieldType.Implements(marshalTextType) && fieldType.Kind() != reflect.Ptr {
code, err := encodeCompileMarshalText(ctx.withType(fieldType))
if err != nil {
return nil, err
}
valueCode = code
} else if fieldType.Implements(marshalJSONType) && fieldType.Kind() == reflect.Ptr && !fieldType.Elem().Implements(marshalJSONType) {
code, err := encodeCompileMarshalJSON(ctx.withType(fieldType))
if err != nil {
return nil, err
}
valueCode = code
} else if fieldType.Implements(marshalTextType) && fieldType.Kind() == reflect.Ptr && !fieldType.Elem().Implements(marshalTextType) {
code, err := encodeCompileMarshalText(ctx.withType(fieldType))
if err != nil {
return nil, err
}
valueCode = code
} else {
code, err := encodeCompile(ctx.withType(fieldType), isPtr)
if err != nil {

View File

@ -2241,11 +2241,12 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o
p = ptrToPtr(p + code.offset)
}
}
if p == 0 && code.nilcheck {
iface := ptrToInterface(code, p)
if code.nilcheck && encodeIsNilForMarshaler(iface) {
code = code.nextField
} else {
b = append(b, code.escapedKey...)
bb, err := encodeMarshalJSON(code, b, ptrToInterface(code, p), true)
bb, err := encodeMarshalJSON(code, b, iface, true)
if err != nil {
return nil, err
}

10
json.go
View File

@ -296,16 +296,6 @@ func (n Number) Int64() (int64, error) {
return strconv.ParseInt(string(n), 10, 64)
}
func (n Number) MarshalJSON() ([]byte, error) {
if n == "" {
return []byte("0"), nil
}
if _, err := n.Float64(); err != nil {
return nil, err
}
return []byte(n), nil
}
func (n *Number) UnmarshalJSON(b []byte) error {
s := string(b)
if _, err := strconv.ParseFloat(s, 64); err != nil {