diff --git a/encode.go b/encode.go index f41114e..58f019a 100644 --- a/encode.go +++ b/encode.go @@ -148,6 +148,10 @@ func (e *Encoder) encodeForMarshal(v interface{}) ([]byte, error) { } func (e *Encoder) encode(v interface{}) error { + if v == nil { + e.encodeNull() + return nil + } header := (*interfaceHeader)(unsafe.Pointer(&v)) typ := header.typ diff --git a/encode_compile.go b/encode_compile.go index 040e0be..8d81da9 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -566,6 +566,151 @@ func (e *Encoder) structField(fieldCode *structFieldCode, valueCode *opcode, tag } return code } + +func (e *Encoder) isNotExistsField(head *structFieldCode) bool { + if head == nil { + return false + } + if head.op != opStructFieldAnonymousHead { + return false + } + if head.next == nil { + return false + } + if head.nextField == nil { + return false + } + if head.nextField.op != opStructAnonymousEnd { + return false + } + if head.next.op == opStructAnonymousEnd { + return true + } + if head.next.op.codeType() != codeStructField { + return false + } + return e.isNotExistsField(head.next.toStructFieldCode()) +} + +func (e *Encoder) optimizeAnonymousFields(head *structFieldCode) { + code := head + var prev *structFieldCode + for { + if code.op == opStructEnd || code.op == opStructEndIndent { + break + } + if code.op == opStructField || code.op == opStructFieldIndent { + codeType := code.next.op.codeType() + if codeType == codeStructField { + if e.isNotExistsField(code.next.toStructFieldCode()) { + code.next = code.nextField + linkPrevToNextField(prev, code) + code = prev + } + } + } + prev = code + code = code.nextField.toStructFieldCode() + } +} + +type structFieldPair struct { + prevField *structFieldCode + curField *structFieldCode + isTaggedKey bool + linked bool +} + +func (e *Encoder) anonymousStructFieldPairMap(typ *rtype, tags structTags, valueCode *structFieldCode) map[string][]structFieldPair { + //fmt.Println("type = ", typ, "valueCode = ", valueCode.dump()) + anonymousFields := map[string][]structFieldPair{} + f := valueCode + var prevAnonymousField *structFieldCode + for { + existsKey := tags.existsKey(f.displayKey) + op := f.op.headToAnonymousHead() + if op != f.op { + if existsKey { + f.op = opStructFieldAnonymousHead + } else { + f.op = op + } + } else if f.op == opStructEnd { + f.op = opStructAnonymousEnd + } else if existsKey { + linkPrevToNextField(prevAnonymousField, f) + } + + if f.displayKey == "" { + if f.nextField == nil { + break + } + prevAnonymousField = f + f = f.nextField.toStructFieldCode() + continue + } + + anonymousFields[f.displayKey] = append(anonymousFields[f.displayKey], structFieldPair{ + prevField: prevAnonymousField, + curField: f, + isTaggedKey: f.isTaggedKey, + }) + if f.next != nil && f.nextField != f.next && f.next.op.codeType() == codeStructField { + for k, v := range e.anonymousStructFieldPairMap(typ, tags, f.next.toStructFieldCode()) { + anonymousFields[k] = append(anonymousFields[k], v...) + } + } + if f.nextField == nil { + break + } + prevAnonymousField = f + f = f.nextField.toStructFieldCode() + } + return anonymousFields +} + +func (e *Encoder) optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPair) { + for _, fieldPairs := range anonymousFields { + if len(fieldPairs) == 1 { + continue + } + // conflict anonymous fields + taggedPairs := []structFieldPair{} + for _, fieldPair := range fieldPairs { + if fieldPair.isTaggedKey { + taggedPairs = append(taggedPairs, fieldPair) + } else { + if !fieldPair.linked { + if fieldPair.prevField == nil { + // head operation + fieldPair.curField.op = opStructFieldAnonymousHead + } else { + linkPrevToNextField(fieldPair.prevField, fieldPair.curField) + } + fieldPair.linked = true + } + } + } + if len(taggedPairs) > 1 { + for _, fieldPair := range taggedPairs { + if !fieldPair.linked { + if fieldPair.prevField == nil { + // head operation + fieldPair.curField.op = opStructFieldAnonymousHead + } else { + linkPrevToNextField(fieldPair.prevField, fieldPair.curField) + } + fieldPair.linked = true + } + } + } else { + for _, fieldPair := range taggedPairs { + fieldPair.curField.isTaggedKey = false + } + } + } +} + func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opcode, error) { if code := e.compiledCode(typ, withIndent); code != nil { return code, nil @@ -588,12 +733,17 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco prevField *structFieldCode ) e.indent++ + tags := structTags{} + anonymousFields := map[string][]structFieldPair{} for i := 0; i < fieldNum; i++ { field := typ.Field(i) if isIgnoredStructField(field) { continue } - tag := structTagFromField(field) + tags = append(tags, structTagFromField(field)) + } + for i, tag := range tags { + field := tag.field fieldType := type2rtype(field.Type) if isPtr && i == 0 { // head field of pointer structure at top level @@ -609,16 +759,8 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco return nil, err } if field.Anonymous { - f := valueCode.toStructFieldCode() - for { - f.op = f.op.headToAnonymousHead() - if f.op == opStructEnd { - f.op = opStructAnonymousEnd - } - if f.nextField == nil { - break - } - f = f.nextField.toStructFieldCode() + for k, v := range e.anonymousStructFieldPairMap(typ, tags, valueCode.toStructFieldCode()) { + anonymousFields[k] = append(anonymousFields[k], v...) } } if fieldNum == 1 && valueCode.op == opPtr { @@ -640,6 +782,8 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco }, anonymousKey: field.Anonymous, key: []byte(key), + isTaggedKey: tag.isTaggedKey, + displayKey: tag.key, offset: field.Offset, } if fieldIdx == 0 { @@ -690,6 +834,10 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco } head.end = structEndCode code.next = structEndCode + + e.optimizeConflictAnonymousFields(anonymousFields) + e.optimizeAnonymousFields(head) + ret := (*opcode)(unsafe.Pointer(head)) compiled.code = ret diff --git a/encode_opcode.go b/encode_opcode.go index 9fef8bf..b43dc41 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -53,12 +53,12 @@ func (c *opcode) beforeLastCode() *opcode { code := c for { var nextCode *opcode - switch code.op { - case opArrayElem, opArrayElemIndent: + switch code.op.codeType() { + case codeArrayElem: nextCode = code.toArrayElemCode().end - case opSliceElem, opSliceElemIndent, opRootSliceElemIndent: + case codeSliceElem: nextCode = code.toSliceElemCode().end - case opMapKey, opMapKeyIndent, opRootMapKeyIndent: + case codeMapKey: nextCode = code.toMapKeyCode().end default: nextCode = code.next @@ -112,13 +112,13 @@ func (c *opcode) dump() string { codes := []string{} for code := c; code.op != opEnd; { indent := strings.Repeat(" ", code.indent) - codes = append(codes, fmt.Sprintf("%s%s", indent, code.op)) - switch code.op { - case opArrayElem, opArrayElemIndent: + codes = append(codes, fmt.Sprintf("%s%s ( %p )", indent, code.op, unsafe.Pointer(code))) + switch code.op.codeType() { + case codeArrayElem: code = code.toArrayElemCode().end - case opSliceElem, opSliceElemIndent, opRootSliceElemIndent: + case codeSliceElem: code = code.toSliceElemCode().end - case opMapKey, opMapKeyIndent, opRootMapKeyIndent: + case codeMapKey: code = code.toMapKeyCode().end default: code = code.next @@ -305,12 +305,44 @@ func (c *arrayElemCode) copy(codeMap map[uintptr]*opcode) *opcode { type structFieldCode struct { *opcodeHeader key []byte + displayKey string + isTaggedKey bool offset uintptr anonymousKey bool nextField *opcode end *opcode } +func linkPrevToNextField(prev, cur *structFieldCode) { + prev.nextField = cur.nextField + code := prev.toOpcode() + fcode := cur.toOpcode() + for { + var nextCode *opcode + switch code.op.codeType() { + case codeArrayElem: + nextCode = code.toArrayElemCode().end + case codeSliceElem: + nextCode = code.toSliceElemCode().end + case codeMapKey: + nextCode = code.toMapKeyCode().end + default: + nextCode = code.next + } + if nextCode == fcode { + code.next = fcode.next + break + } else if nextCode.op == opEnd { + break + } + code = nextCode + } +} + +func (c *structFieldCode) toOpcode() *opcode { + return (*opcode)(unsafe.Pointer(c)) +} + func (c *structFieldCode) copy(codeMap map[uintptr]*opcode) *opcode { if c == nil { return nil @@ -321,6 +353,8 @@ func (c *structFieldCode) copy(codeMap map[uintptr]*opcode) *opcode { } field := &structFieldCode{ key: c.key, + isTaggedKey: c.isTaggedKey, + displayKey: c.displayKey, anonymousKey: c.anonymousKey, offset: c.offset, } diff --git a/encode_test.go b/encode_test.go index 4ef90f9..10bf27e 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2,6 +2,7 @@ package json_test import ( "bytes" + "encoding" "errors" "fmt" "log" @@ -1233,3 +1234,161 @@ func TestHTMLEscape(t *testing.T) { t.Errorf("HTMLEscape(&b, []byte(m)) = %s; want %s", b.Bytes(), want.Bytes()) } } + +type BugA struct { + S string +} + +type BugB struct { + BugA + S string +} + +type BugC struct { + S string +} + +// Legal Go: We never use the repeated embedded field (S). +type BugX struct { + A int + BugA + BugB +} + +// golang.org/issue/16042. +// Even if a nil interface value is passed in, as long as +// it implements Marshaler, it should be marshaled. +type nilJSONMarshaler string + +func (nm *nilJSONMarshaler) MarshalJSON() ([]byte, error) { + if nm == nil { + return json.Marshal("0zenil0") + } + return json.Marshal("zenil:" + string(*nm)) +} + +// golang.org/issue/34235. +// Even if a nil interface value is passed in, as long as +// it implements encoding.TextMarshaler, it should be marshaled. +type nilTextMarshaler string + +func (nm *nilTextMarshaler) MarshalText() ([]byte, error) { + if nm == nil { + return []byte("0zenil0"), nil + } + return []byte("zenil:" + string(*nm)), nil +} + +// See golang.org/issue/16042 and golang.org/issue/34235. +func TestNilMarshal(t *testing.T) { + testCases := []struct { + v interface{} + want string + }{ + {v: nil, want: `null`}, + {v: new(float64), want: `0`}, + {v: []interface{}(nil), want: `null`}, + {v: []string(nil), want: `null`}, + {v: map[string]string(nil), want: `null`}, + {v: []byte(nil), want: `null`}, + {v: struct{ M string }{"gopher"}, want: `{"M":"gopher"}`}, + {v: struct{ M json.Marshaler }{}, want: `{"M":null}`}, + {v: struct{ M json.Marshaler }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, + {v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, // doesn't compatible with encoding/json + {v: struct{ M encoding.TextMarshaler }{}, want: `{"M":null}`}, + {v: struct{ M encoding.TextMarshaler }{(*nilTextMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, + {v: struct{ M interface{} }{(*nilTextMarshaler)(nil)}, want: `{"M":null}`}, + } + + for i, tt := range testCases { + out, err := json.Marshal(tt.v) + if err != nil || string(out) != tt.want { + t.Errorf("%d: Marshal(%#v) = %#q, %#v, want %#q, nil", i, tt.v, out, err, tt.want) + continue + } + } +} + +// Issue 5245. +func TestEmbeddedBug(t *testing.T) { + v := BugB{ + BugA{"A"}, + "B", + } + b, err := json.Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{"S":"B"}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } + // Now check that the duplicate field, S, does not appear. + x := BugX{ + A: 23, + } + b, err = json.Marshal(x) + if err != nil { + t.Fatal("Marshal:", err) + } + want = `{"A":23}` + got = string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +type BugD struct { // Same as BugA after tagging. + XXX string `json:"S"` +} + +// BugD's tagged S field should dominate BugA's. +type BugY struct { + BugA + BugD +} + +// Test that a field with a tag dominates untagged fields. +func TestTaggedFieldDominates(t *testing.T) { + v := BugY{ + BugA{"BugA"}, + BugD{"BugD"}, + } + b, err := json.Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{"S":"BugD"}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +// There are no tags here, so S should not appear. +type BugZ struct { + BugA + BugC + BugY // Contains a tagged S field through BugD; should not dominate. +} + +func TestDuplicatedFieldDisappears(t *testing.T) { + v := BugZ{ + BugA{"BugA"}, + BugC{"BugC"}, + BugY{ + BugA{"nested BugA"}, + BugD{"nested BugD"}, + }, + } + b, err := json.Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} diff --git a/encode_vm.go b/encode_vm.go index e2d0971..8265cde 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -69,10 +69,16 @@ func (e *Encoder) run(code *opcode) error { e.encodeBool(e.ptrToBool(code.ptr)) code = code.next case opBytes: - s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr)) - e.encodeByte('"') - e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s))) - e.encodeByte('"') + ptr := code.ptr + header := (*reflect.SliceHeader)(unsafe.Pointer(ptr)) + if ptr == 0 || header.Data == 0 { + e.encodeNull() + } else { + s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr)) + e.encodeByte('"') + e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s))) + e.encodeByte('"') + } code = code.next case opInterface: ifaceCode := code.toInterfaceCode() @@ -148,7 +154,9 @@ func (e *Encoder) run(code *opcode) error { ptr := code.ptr isPtr := code.typ.Kind() == reflect.Ptr p := unsafe.Pointer(ptr) - if isPtr && *(*unsafe.Pointer)(p) == nil { + if p == nil { + e.encodeNull() + } else if isPtr && *(*unsafe.Pointer)(p) == nil { e.encodeBytes([]byte{'"', '"'}) } else { if isPtr && code.typ.Elem().Implements(marshalTextType) { @@ -172,12 +180,12 @@ func (e *Encoder) run(code *opcode) error { case opSliceHead: p := code.ptr headerCode := code.toSliceHeaderCode() - if p == 0 { + header := (*reflect.SliceHeader)(unsafe.Pointer(p)) + if p == 0 || header.Data == 0 { e.encodeNull() code = headerCode.end.next } else { e.encodeByte('[') - header := (*reflect.SliceHeader)(unsafe.Pointer(p)) headerCode.elem.set(header) if header.Len > 0 { code = code.next @@ -541,6 +549,16 @@ func (e *Encoder) run(code *opcode) error { code.ptr = ptr field.nextField.ptr = ptr } + case opStructFieldAnonymousHead: + field := code.toStructFieldCode() + ptr := field.ptr + if ptr == 0 { + code = field.end.next + } else { + code = field.next + code.ptr = ptr + field.nextField.ptr = ptr + } case opStructFieldPtrHeadInt: code.ptr = e.ptrToPtr(code.ptr) fallthrough @@ -1027,7 +1045,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - b, err := v.(Marshaler).MarshalJSON() + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { + e.encodeNull() + code = field.end + break + } + b, err := rv.Interface().(Marshaler).MarshalJSON() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1062,7 +1086,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - b, err := v.(Marshaler).MarshalJSON() + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { + e.encodeNull() + code = field.end + break + } + b, err := rv.Interface().(Marshaler).MarshalJSON() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1099,7 +1129,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - bytes, err := v.(encoding.TextMarshaler).MarshalText() + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { + e.encodeNull() + code = field.end + break + } + bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1124,7 +1160,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - bytes, err := v.(encoding.TextMarshaler).MarshalText() + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { + e.encodeNull() + code = field.end + break + } + bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -3840,7 +3882,9 @@ func (e *Encoder) run(code *opcode) error { e.encodeByte(',') } c := code.toStructFieldCode() - e.encodeBytes(c.key) + if !c.anonymousKey { + e.encodeBytes(c.key) + } code = code.next code.ptr = c.ptr + c.offset c.nextField.ptr = c.ptr diff --git a/struct_field.go b/struct_field.go index 3352da1..b2a6f50 100644 --- a/struct_field.go +++ b/struct_field.go @@ -23,20 +23,35 @@ func isIgnoredStructField(field reflect.StructField) bool { type structTag struct { key string + isTaggedKey bool isOmitEmpty bool isString bool + field reflect.StructField +} + +type structTags []*structTag + +func (t structTags) existsKey(key string) bool { + for _, tt := range t { + if tt.key == key { + return true + } + } + return false } func structTagFromField(field reflect.StructField) *structTag { keyName := field.Name tag := getTag(field) + st := &structTag{field: field} opts := strings.Split(tag, ",") if len(opts) > 0 { if opts[0] != "" { keyName = opts[0] + st.isTaggedKey = true } } - st := &structTag{key: keyName} + st.key = keyName if len(opts) > 1 { st.isOmitEmpty = opts[1] == "omitempty" st.isString = opts[1] == "string"