Merge pull request #39 from goccy/feature/fix-nil-marshaler

Fix nested embedded structure that has same name fields
This commit is contained in:
Masaaki Goshima 2020-08-22 15:48:05 +09:00 committed by GitHub
commit 97787fd7b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 437 additions and 33 deletions

View File

@ -148,6 +148,10 @@ func (e *Encoder) encodeForMarshal(v interface{}) ([]byte, error) {
} }
func (e *Encoder) encode(v interface{}) error { func (e *Encoder) encode(v interface{}) error {
if v == nil {
e.encodeNull()
return nil
}
header := (*interfaceHeader)(unsafe.Pointer(&v)) header := (*interfaceHeader)(unsafe.Pointer(&v))
typ := header.typ typ := header.typ

View File

@ -566,6 +566,151 @@ func (e *Encoder) structField(fieldCode *structFieldCode, valueCode *opcode, tag
} }
return code return code
} }
func (e *Encoder) isNotExistsField(head *structFieldCode) bool {
if head == nil {
return false
}
if head.op != opStructFieldAnonymousHead {
return false
}
if head.next == nil {
return false
}
if head.nextField == nil {
return false
}
if head.nextField.op != opStructAnonymousEnd {
return false
}
if head.next.op == opStructAnonymousEnd {
return true
}
if head.next.op.codeType() != codeStructField {
return false
}
return e.isNotExistsField(head.next.toStructFieldCode())
}
func (e *Encoder) optimizeAnonymousFields(head *structFieldCode) {
code := head
var prev *structFieldCode
for {
if code.op == opStructEnd || code.op == opStructEndIndent {
break
}
if code.op == opStructField || code.op == opStructFieldIndent {
codeType := code.next.op.codeType()
if codeType == codeStructField {
if e.isNotExistsField(code.next.toStructFieldCode()) {
code.next = code.nextField
linkPrevToNextField(prev, code)
code = prev
}
}
}
prev = code
code = code.nextField.toStructFieldCode()
}
}
type structFieldPair struct {
prevField *structFieldCode
curField *structFieldCode
isTaggedKey bool
linked bool
}
func (e *Encoder) anonymousStructFieldPairMap(typ *rtype, tags structTags, valueCode *structFieldCode) map[string][]structFieldPair {
//fmt.Println("type = ", typ, "valueCode = ", valueCode.dump())
anonymousFields := map[string][]structFieldPair{}
f := valueCode
var prevAnonymousField *structFieldCode
for {
existsKey := tags.existsKey(f.displayKey)
op := f.op.headToAnonymousHead()
if op != f.op {
if existsKey {
f.op = opStructFieldAnonymousHead
} else {
f.op = op
}
} else if f.op == opStructEnd {
f.op = opStructAnonymousEnd
} else if existsKey {
linkPrevToNextField(prevAnonymousField, f)
}
if f.displayKey == "" {
if f.nextField == nil {
break
}
prevAnonymousField = f
f = f.nextField.toStructFieldCode()
continue
}
anonymousFields[f.displayKey] = append(anonymousFields[f.displayKey], structFieldPair{
prevField: prevAnonymousField,
curField: f,
isTaggedKey: f.isTaggedKey,
})
if f.next != nil && f.nextField != f.next && f.next.op.codeType() == codeStructField {
for k, v := range e.anonymousStructFieldPairMap(typ, tags, f.next.toStructFieldCode()) {
anonymousFields[k] = append(anonymousFields[k], v...)
}
}
if f.nextField == nil {
break
}
prevAnonymousField = f
f = f.nextField.toStructFieldCode()
}
return anonymousFields
}
func (e *Encoder) optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPair) {
for _, fieldPairs := range anonymousFields {
if len(fieldPairs) == 1 {
continue
}
// conflict anonymous fields
taggedPairs := []structFieldPair{}
for _, fieldPair := range fieldPairs {
if fieldPair.isTaggedKey {
taggedPairs = append(taggedPairs, fieldPair)
} else {
if !fieldPair.linked {
if fieldPair.prevField == nil {
// head operation
fieldPair.curField.op = opStructFieldAnonymousHead
} else {
linkPrevToNextField(fieldPair.prevField, fieldPair.curField)
}
fieldPair.linked = true
}
}
}
if len(taggedPairs) > 1 {
for _, fieldPair := range taggedPairs {
if !fieldPair.linked {
if fieldPair.prevField == nil {
// head operation
fieldPair.curField.op = opStructFieldAnonymousHead
} else {
linkPrevToNextField(fieldPair.prevField, fieldPair.curField)
}
fieldPair.linked = true
}
}
} else {
for _, fieldPair := range taggedPairs {
fieldPair.curField.isTaggedKey = false
}
}
}
}
func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opcode, error) { func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opcode, error) {
if code := e.compiledCode(typ, withIndent); code != nil { if code := e.compiledCode(typ, withIndent); code != nil {
return code, nil return code, nil
@ -588,12 +733,17 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco
prevField *structFieldCode prevField *structFieldCode
) )
e.indent++ e.indent++
tags := structTags{}
anonymousFields := map[string][]structFieldPair{}
for i := 0; i < fieldNum; i++ { for i := 0; i < fieldNum; i++ {
field := typ.Field(i) field := typ.Field(i)
if isIgnoredStructField(field) { if isIgnoredStructField(field) {
continue continue
} }
tag := structTagFromField(field) tags = append(tags, structTagFromField(field))
}
for i, tag := range tags {
field := tag.field
fieldType := type2rtype(field.Type) fieldType := type2rtype(field.Type)
if isPtr && i == 0 { if isPtr && i == 0 {
// head field of pointer structure at top level // head field of pointer structure at top level
@ -609,16 +759,8 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco
return nil, err return nil, err
} }
if field.Anonymous { if field.Anonymous {
f := valueCode.toStructFieldCode() for k, v := range e.anonymousStructFieldPairMap(typ, tags, valueCode.toStructFieldCode()) {
for { anonymousFields[k] = append(anonymousFields[k], v...)
f.op = f.op.headToAnonymousHead()
if f.op == opStructEnd {
f.op = opStructAnonymousEnd
}
if f.nextField == nil {
break
}
f = f.nextField.toStructFieldCode()
} }
} }
if fieldNum == 1 && valueCode.op == opPtr { if fieldNum == 1 && valueCode.op == opPtr {
@ -640,6 +782,8 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco
}, },
anonymousKey: field.Anonymous, anonymousKey: field.Anonymous,
key: []byte(key), key: []byte(key),
isTaggedKey: tag.isTaggedKey,
displayKey: tag.key,
offset: field.Offset, offset: field.Offset,
} }
if fieldIdx == 0 { if fieldIdx == 0 {
@ -690,6 +834,10 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco
} }
head.end = structEndCode head.end = structEndCode
code.next = structEndCode code.next = structEndCode
e.optimizeConflictAnonymousFields(anonymousFields)
e.optimizeAnonymousFields(head)
ret := (*opcode)(unsafe.Pointer(head)) ret := (*opcode)(unsafe.Pointer(head))
compiled.code = ret compiled.code = ret

View File

@ -53,12 +53,12 @@ func (c *opcode) beforeLastCode() *opcode {
code := c code := c
for { for {
var nextCode *opcode var nextCode *opcode
switch code.op { switch code.op.codeType() {
case opArrayElem, opArrayElemIndent: case codeArrayElem:
nextCode = code.toArrayElemCode().end nextCode = code.toArrayElemCode().end
case opSliceElem, opSliceElemIndent, opRootSliceElemIndent: case codeSliceElem:
nextCode = code.toSliceElemCode().end nextCode = code.toSliceElemCode().end
case opMapKey, opMapKeyIndent, opRootMapKeyIndent: case codeMapKey:
nextCode = code.toMapKeyCode().end nextCode = code.toMapKeyCode().end
default: default:
nextCode = code.next nextCode = code.next
@ -112,13 +112,13 @@ func (c *opcode) dump() string {
codes := []string{} codes := []string{}
for code := c; code.op != opEnd; { for code := c; code.op != opEnd; {
indent := strings.Repeat(" ", code.indent) indent := strings.Repeat(" ", code.indent)
codes = append(codes, fmt.Sprintf("%s%s", indent, code.op)) codes = append(codes, fmt.Sprintf("%s%s ( %p )", indent, code.op, unsafe.Pointer(code)))
switch code.op { switch code.op.codeType() {
case opArrayElem, opArrayElemIndent: case codeArrayElem:
code = code.toArrayElemCode().end code = code.toArrayElemCode().end
case opSliceElem, opSliceElemIndent, opRootSliceElemIndent: case codeSliceElem:
code = code.toSliceElemCode().end code = code.toSliceElemCode().end
case opMapKey, opMapKeyIndent, opRootMapKeyIndent: case codeMapKey:
code = code.toMapKeyCode().end code = code.toMapKeyCode().end
default: default:
code = code.next code = code.next
@ -305,12 +305,44 @@ func (c *arrayElemCode) copy(codeMap map[uintptr]*opcode) *opcode {
type structFieldCode struct { type structFieldCode struct {
*opcodeHeader *opcodeHeader
key []byte key []byte
displayKey string
isTaggedKey bool
offset uintptr offset uintptr
anonymousKey bool anonymousKey bool
nextField *opcode nextField *opcode
end *opcode end *opcode
} }
func linkPrevToNextField(prev, cur *structFieldCode) {
prev.nextField = cur.nextField
code := prev.toOpcode()
fcode := cur.toOpcode()
for {
var nextCode *opcode
switch code.op.codeType() {
case codeArrayElem:
nextCode = code.toArrayElemCode().end
case codeSliceElem:
nextCode = code.toSliceElemCode().end
case codeMapKey:
nextCode = code.toMapKeyCode().end
default:
nextCode = code.next
}
if nextCode == fcode {
code.next = fcode.next
break
} else if nextCode.op == opEnd {
break
}
code = nextCode
}
}
func (c *structFieldCode) toOpcode() *opcode {
return (*opcode)(unsafe.Pointer(c))
}
func (c *structFieldCode) copy(codeMap map[uintptr]*opcode) *opcode { func (c *structFieldCode) copy(codeMap map[uintptr]*opcode) *opcode {
if c == nil { if c == nil {
return nil return nil
@ -321,6 +353,8 @@ func (c *structFieldCode) copy(codeMap map[uintptr]*opcode) *opcode {
} }
field := &structFieldCode{ field := &structFieldCode{
key: c.key, key: c.key,
isTaggedKey: c.isTaggedKey,
displayKey: c.displayKey,
anonymousKey: c.anonymousKey, anonymousKey: c.anonymousKey,
offset: c.offset, offset: c.offset,
} }

View File

@ -2,6 +2,7 @@ package json_test
import ( import (
"bytes" "bytes"
"encoding"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@ -1233,3 +1234,161 @@ func TestHTMLEscape(t *testing.T) {
t.Errorf("HTMLEscape(&b, []byte(m)) = %s; want %s", b.Bytes(), want.Bytes()) t.Errorf("HTMLEscape(&b, []byte(m)) = %s; want %s", b.Bytes(), want.Bytes())
} }
} }
type BugA struct {
S string
}
type BugB struct {
BugA
S string
}
type BugC struct {
S string
}
// Legal Go: We never use the repeated embedded field (S).
type BugX struct {
A int
BugA
BugB
}
// golang.org/issue/16042.
// Even if a nil interface value is passed in, as long as
// it implements Marshaler, it should be marshaled.
type nilJSONMarshaler string
func (nm *nilJSONMarshaler) MarshalJSON() ([]byte, error) {
if nm == nil {
return json.Marshal("0zenil0")
}
return json.Marshal("zenil:" + string(*nm))
}
// golang.org/issue/34235.
// Even if a nil interface value is passed in, as long as
// it implements encoding.TextMarshaler, it should be marshaled.
type nilTextMarshaler string
func (nm *nilTextMarshaler) MarshalText() ([]byte, error) {
if nm == nil {
return []byte("0zenil0"), nil
}
return []byte("zenil:" + string(*nm)), nil
}
// See golang.org/issue/16042 and golang.org/issue/34235.
func TestNilMarshal(t *testing.T) {
testCases := []struct {
v interface{}
want string
}{
{v: nil, want: `null`},
{v: new(float64), want: `0`},
{v: []interface{}(nil), want: `null`},
{v: []string(nil), want: `null`},
{v: map[string]string(nil), want: `null`},
{v: []byte(nil), want: `null`},
{v: struct{ M string }{"gopher"}, want: `{"M":"gopher"}`},
{v: struct{ M json.Marshaler }{}, want: `{"M":null}`},
{v: struct{ M json.Marshaler }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`},
{v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, // doesn't compatible with encoding/json
{v: struct{ M encoding.TextMarshaler }{}, want: `{"M":null}`},
{v: struct{ M encoding.TextMarshaler }{(*nilTextMarshaler)(nil)}, want: `{"M":"0zenil0"}`},
{v: struct{ M interface{} }{(*nilTextMarshaler)(nil)}, want: `{"M":null}`},
}
for i, tt := range testCases {
out, err := json.Marshal(tt.v)
if err != nil || string(out) != tt.want {
t.Errorf("%d: Marshal(%#v) = %#q, %#v, want %#q, nil", i, tt.v, out, err, tt.want)
continue
}
}
}
// Issue 5245.
func TestEmbeddedBug(t *testing.T) {
v := BugB{
BugA{"A"},
"B",
}
b, err := json.Marshal(v)
if err != nil {
t.Fatal("Marshal:", err)
}
want := `{"S":"B"}`
got := string(b)
if got != want {
t.Fatalf("Marshal: got %s want %s", got, want)
}
// Now check that the duplicate field, S, does not appear.
x := BugX{
A: 23,
}
b, err = json.Marshal(x)
if err != nil {
t.Fatal("Marshal:", err)
}
want = `{"A":23}`
got = string(b)
if got != want {
t.Fatalf("Marshal: got %s want %s", got, want)
}
}
type BugD struct { // Same as BugA after tagging.
XXX string `json:"S"`
}
// BugD's tagged S field should dominate BugA's.
type BugY struct {
BugA
BugD
}
// Test that a field with a tag dominates untagged fields.
func TestTaggedFieldDominates(t *testing.T) {
v := BugY{
BugA{"BugA"},
BugD{"BugD"},
}
b, err := json.Marshal(v)
if err != nil {
t.Fatal("Marshal:", err)
}
want := `{"S":"BugD"}`
got := string(b)
if got != want {
t.Fatalf("Marshal: got %s want %s", got, want)
}
}
// There are no tags here, so S should not appear.
type BugZ struct {
BugA
BugC
BugY // Contains a tagged S field through BugD; should not dominate.
}
func TestDuplicatedFieldDisappears(t *testing.T) {
v := BugZ{
BugA{"BugA"},
BugC{"BugC"},
BugY{
BugA{"nested BugA"},
BugD{"nested BugD"},
},
}
b, err := json.Marshal(v)
if err != nil {
t.Fatal("Marshal:", err)
}
want := `{}`
got := string(b)
if got != want {
t.Fatalf("Marshal: got %s want %s", got, want)
}
}

View File

@ -69,10 +69,16 @@ func (e *Encoder) run(code *opcode) error {
e.encodeBool(e.ptrToBool(code.ptr)) e.encodeBool(e.ptrToBool(code.ptr))
code = code.next code = code.next
case opBytes: case opBytes:
s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr)) ptr := code.ptr
e.encodeByte('"') header := (*reflect.SliceHeader)(unsafe.Pointer(ptr))
e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s))) if ptr == 0 || header.Data == 0 {
e.encodeByte('"') e.encodeNull()
} else {
s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr))
e.encodeByte('"')
e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s)))
e.encodeByte('"')
}
code = code.next code = code.next
case opInterface: case opInterface:
ifaceCode := code.toInterfaceCode() ifaceCode := code.toInterfaceCode()
@ -148,7 +154,9 @@ func (e *Encoder) run(code *opcode) error {
ptr := code.ptr ptr := code.ptr
isPtr := code.typ.Kind() == reflect.Ptr isPtr := code.typ.Kind() == reflect.Ptr
p := unsafe.Pointer(ptr) p := unsafe.Pointer(ptr)
if isPtr && *(*unsafe.Pointer)(p) == nil { if p == nil {
e.encodeNull()
} else if isPtr && *(*unsafe.Pointer)(p) == nil {
e.encodeBytes([]byte{'"', '"'}) e.encodeBytes([]byte{'"', '"'})
} else { } else {
if isPtr && code.typ.Elem().Implements(marshalTextType) { if isPtr && code.typ.Elem().Implements(marshalTextType) {
@ -172,12 +180,12 @@ func (e *Encoder) run(code *opcode) error {
case opSliceHead: case opSliceHead:
p := code.ptr p := code.ptr
headerCode := code.toSliceHeaderCode() headerCode := code.toSliceHeaderCode()
if p == 0 { header := (*reflect.SliceHeader)(unsafe.Pointer(p))
if p == 0 || header.Data == 0 {
e.encodeNull() e.encodeNull()
code = headerCode.end.next code = headerCode.end.next
} else { } else {
e.encodeByte('[') e.encodeByte('[')
header := (*reflect.SliceHeader)(unsafe.Pointer(p))
headerCode.elem.set(header) headerCode.elem.set(header)
if header.Len > 0 { if header.Len > 0 {
code = code.next code = code.next
@ -541,6 +549,16 @@ func (e *Encoder) run(code *opcode) error {
code.ptr = ptr code.ptr = ptr
field.nextField.ptr = ptr field.nextField.ptr = ptr
} }
case opStructFieldAnonymousHead:
field := code.toStructFieldCode()
ptr := field.ptr
if ptr == 0 {
code = field.end.next
} else {
code = field.next
code.ptr = ptr
field.nextField.ptr = ptr
}
case opStructFieldPtrHeadInt: case opStructFieldPtrHeadInt:
code.ptr = e.ptrToPtr(code.ptr) code.ptr = e.ptrToPtr(code.ptr)
fallthrough fallthrough
@ -1027,7 +1045,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
b, err := v.(Marshaler).MarshalJSON() rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
b, err := rv.Interface().(Marshaler).MarshalJSON()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
@ -1062,7 +1086,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
b, err := v.(Marshaler).MarshalJSON() rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
b, err := rv.Interface().(Marshaler).MarshalJSON()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
@ -1099,7 +1129,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
bytes, err := v.(encoding.TextMarshaler).MarshalText() rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
@ -1124,7 +1160,13 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
bytes, err := v.(encoding.TextMarshaler).MarshalText() rv := reflect.ValueOf(v)
if rv.Type().Kind() == reflect.Interface && rv.IsNil() {
e.encodeNull()
code = field.end
break
}
bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
@ -3840,7 +3882,9 @@ func (e *Encoder) run(code *opcode) error {
e.encodeByte(',') e.encodeByte(',')
} }
c := code.toStructFieldCode() c := code.toStructFieldCode()
e.encodeBytes(c.key) if !c.anonymousKey {
e.encodeBytes(c.key)
}
code = code.next code = code.next
code.ptr = c.ptr + c.offset code.ptr = c.ptr + c.offset
c.nextField.ptr = c.ptr c.nextField.ptr = c.ptr

View File

@ -23,20 +23,35 @@ func isIgnoredStructField(field reflect.StructField) bool {
type structTag struct { type structTag struct {
key string key string
isTaggedKey bool
isOmitEmpty bool isOmitEmpty bool
isString bool isString bool
field reflect.StructField
}
type structTags []*structTag
func (t structTags) existsKey(key string) bool {
for _, tt := range t {
if tt.key == key {
return true
}
}
return false
} }
func structTagFromField(field reflect.StructField) *structTag { func structTagFromField(field reflect.StructField) *structTag {
keyName := field.Name keyName := field.Name
tag := getTag(field) tag := getTag(field)
st := &structTag{field: field}
opts := strings.Split(tag, ",") opts := strings.Split(tag, ",")
if len(opts) > 0 { if len(opts) > 0 {
if opts[0] != "" { if opts[0] != "" {
keyName = opts[0] keyName = opts[0]
st.isTaggedKey = true
} }
} }
st := &structTag{key: keyName} st.key = keyName
if len(opts) > 1 { if len(opts) > 1 {
st.isOmitEmpty = opts[1] == "omitempty" st.isOmitEmpty = opts[1] == "omitempty"
st.isString = opts[1] == "string" st.isString = opts[1] == "string"