diff --git a/encode_compile.go b/encode_compile.go index 899aae3..a95d696 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -16,6 +16,9 @@ func (e *Encoder) compileHead(typ *rtype, withIndent bool) (*opcode, error) { if typ.Kind() == reflect.Ptr { typ = typ.Elem() } + if typ.Kind() == reflect.Map { + return e.compileMap(typ, false, withIndent) + } return e.compile(typ, withIndent) } @@ -33,7 +36,7 @@ func (e *Encoder) compile(typ *rtype, withIndent bool) (*opcode, error) { case reflect.Array: return e.compileArray(typ, withIndent) case reflect.Map: - return e.compileMap(typ, withIndent) + return e.compileMap(typ, true, withIndent) case reflect.Struct: return e.compileStruct(typ, withIndent) case reflect.Int: @@ -366,7 +369,7 @@ func mapiternext(it unsafe.Pointer) //go:noescape func maplen(m unsafe.Pointer) int -func (e *Encoder) compileMap(typ *rtype, withIndent bool) (*opcode, error) { +func (e *Encoder) compileMap(typ *rtype, withLoad, withIndent bool) (*opcode, error) { // header => code => value => code => key => code => value => code => end // ^ | // |_______________________| @@ -387,13 +390,17 @@ func (e *Encoder) compileMap(typ *rtype, withIndent bool) (*opcode, error) { e.indent-- - header := newMapHeaderCode(typ, e.indent) + header := newMapHeaderCode(typ, withLoad, e.indent) header.key = key header.value = value end := newOpCode(opMapEnd, nil, e.indent, newEndOp(e.indent)) if withIndent { - header.op = opMapHeadIndent + if header.op == opMapHead { + header.op = opMapHeadIndent + } else { + header.op = opMapHeadLoadIndent + } key.op = opMapKeyIndent value.op = opMapValueIndent end.op = opMapEndIndent diff --git a/encode_opcode.go b/encode_opcode.go index 341a100..aedbbd8 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -47,11 +47,13 @@ const ( opArrayEndIndent opMapHead + opMapHeadLoad opMapKey opMapValue opMapEnd opMapHeadIndent + opMapHeadLoadIndent opMapKeyIndent opMapValueIndent opMapEndIndent @@ -327,8 +329,9 @@ func (t opType) String() string { case opArrayEndIndent: return "ARRAY_END_INDENT" case opMapHead: - return "MAP_HEAD" + case opMapHeadLoad: + return "MAP_HEAD_LOAD" case opMapKey: return "MAP_KEY" case opMapValue: @@ -338,6 +341,8 @@ func (t opType) String() string { case opMapHeadIndent: return "MAP_HEAD_INDENT" + case opMapHeadLoadIndent: + return "MAP_HEAD_LOAD_INDENT" case opMapKeyIndent: return "MAP_KEY_INDENT" case opMapValueIndent: @@ -919,10 +924,16 @@ func (c *mapValueCode) set(iter unsafe.Pointer) { c.iter = iter } -func newMapHeaderCode(typ *rtype, indent int) *mapHeaderCode { +func newMapHeaderCode(typ *rtype, withLoad bool, indent int) *mapHeaderCode { + var op opType + if withLoad { + op = opMapHeadLoad + } else { + op = opMapHead + } return &mapHeaderCode{ opcodeHeader: &opcodeHeader{ - op: opMapHead, + op: op, typ: typ, indent: indent, }, diff --git a/encode_test.go b/encode_test.go index 8e53205..a7e518a 100644 --- a/encode_test.go +++ b/encode_test.go @@ -129,6 +129,116 @@ func Test_Marshal(t *testing.T) { bytes, err := json.Marshal(&v) assertErr(t, err) assertEq(t, "struct", `{"t":1}`, string(bytes)) + t.Run("int", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B int `json:"b"` + } + v.B = 1 + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "int", `{"b":1}`, string(bytes)) + }) + t.Run("int8", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B int8 `json:"b"` + } + v.B = 1 + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "int8", `{"b":1}`, string(bytes)) + }) + t.Run("int16", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B int16 `json:"b"` + } + v.B = 1 + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "int16", `{"b":1}`, string(bytes)) + }) + t.Run("int32", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B int32 `json:"b"` + } + v.B = 1 + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "int32", `{"b":1}`, string(bytes)) + }) + t.Run("int64", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B int64 `json:"b"` + } + v.B = 1 + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "int64", `{"b":1}`, string(bytes)) + }) + t.Run("string", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B string `json:"b"` + } + v.B = "b" + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "string", `{"b":"b"}`, string(bytes)) + }) + t.Run("float32", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B float32 `json:"b"` + } + v.B = 1.1 + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "float32", `{"b":1.1}`, string(bytes)) + }) + t.Run("float64", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B float64 `json:"b"` + } + v.B = 3.14 + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "float64", `{"b":3.14}`, string(bytes)) + }) + t.Run("slice", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B []int `json:"b"` + } + v.B = []int{1, 2, 3} + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "slice", `{"b":[1,2,3]}`, string(bytes)) + }) + t.Run("array", func(t *testing.T) { + var v struct { + A int `json:"a,omitempty"` + B [2]int `json:"b"` + } + v.B = [2]int{1, 2} + bytes, err := json.Marshal(&v) + assertErr(t, err) + assertEq(t, "array", `{"b":[1,2]}`, string(bytes)) + }) + t.Run("map", func(t *testing.T) { + v := new(struct { + A int `json:"a,omitempty"` + B map[string]interface{} `json:"b"` + }) + v.B = map[string]interface{}{"c": 1} + bytes, err := json.Marshal(v) + assertErr(t, err) + assertEq(t, "array", `{"b":{"c":1}}`, string(bytes)) + }) }) t.Run("head_omitempty", func(t *testing.T) { type T struct { diff --git a/encode_vm.go b/encode_vm.go index 7ee3802..804d49a 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -264,6 +264,30 @@ func (e *Encoder) run(code *opcode) error { code = mapHeadCode.end.next } } + case opMapHeadLoad: + ptr := code.ptr + mapHeadCode := code.toMapHeadCode() + if ptr == 0 { + e.encodeNull() + code = mapHeadCode.end.next + } else { + // load pointer + ptr = uintptr(*(*unsafe.Pointer)(unsafe.Pointer(ptr))) + e.encodeByte('{') + mlen := maplen(unsafe.Pointer(ptr)) + if mlen > 0 { + iter := mapiterinit(code.typ, unsafe.Pointer(ptr)) + mapHeadCode.key.set(mlen, iter) + mapHeadCode.value.set(iter) + key := mapiterkey(iter) + code.next.ptr = uintptr(key) + code = code.next + } else { + e.encodeByte('}') + code = mapHeadCode.end.next + } + } + case opMapKey: c := code.toMapKeyCode() c.idx++ @@ -1681,91 +1705,117 @@ func (e *Encoder) run(code *opcode) error { field.nextField.ptr = field.ptr } case opStructField: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() e.encodeBytes(c.key) code = code.next code.ptr = c.ptr + c.offset c.nextField.ptr = c.ptr case opStructFieldInt: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeInt(e.ptrToInt(c.ptr + c.offset)) code = code.next case opStructFieldInt8: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeInt8(e.ptrToInt8(c.ptr + c.offset)) code = code.next case opStructFieldInt16: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeInt16(e.ptrToInt16(c.ptr + c.offset)) code = code.next case opStructFieldInt32: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeInt32(e.ptrToInt32(c.ptr + c.offset)) code = code.next case opStructFieldInt64: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeInt64(e.ptrToInt64(c.ptr + c.offset)) code = code.next case opStructFieldUint: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeUint(e.ptrToUint(c.ptr + c.offset)) code = code.next case opStructFieldUint8: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeUint8(e.ptrToUint8(c.ptr + c.offset)) code = code.next case opStructFieldUint16: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeUint16(e.ptrToUint16(c.ptr + c.offset)) code = code.next case opStructFieldUint32: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeUint32(e.ptrToUint32(c.ptr + c.offset)) code = code.next case opStructFieldUint64: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeUint64(e.ptrToUint64(c.ptr + c.offset)) code = code.next case opStructFieldFloat32: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeFloat32(e.ptrToFloat32(c.ptr + c.offset)) code = code.next case opStructFieldFloat64: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) @@ -1779,14 +1829,18 @@ func (e *Encoder) run(code *opcode) error { e.encodeFloat64(v) code = code.next case opStructFieldString: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) e.encodeString(e.ptrToString(c.ptr + c.offset)) code = code.next case opStructFieldBool: - e.encodeByte(',') + if e.buf[len(e.buf)-1] != '{' { + e.encodeByte(',') + } c := code.toStructFieldCode() c.nextField.ptr = c.ptr e.encodeBytes(c.key) @@ -1795,7 +1849,9 @@ func (e *Encoder) run(code *opcode) error { case opStructFieldIndent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1804,7 +1860,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldIntIndent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1813,7 +1871,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldInt8Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1822,7 +1882,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldInt16Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1831,7 +1893,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldInt32Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1840,7 +1904,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldInt64Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1849,7 +1915,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldUintIndent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1858,7 +1926,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldUint8Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1867,7 +1937,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldUint16Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1876,7 +1948,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldUint32Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1885,7 +1959,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldUint64Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1894,7 +1970,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldFloat32Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1903,7 +1981,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldFloat64Indent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1919,7 +1999,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldStringIndent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ') @@ -1928,7 +2010,9 @@ func (e *Encoder) run(code *opcode) error { c.nextField.ptr = c.ptr case opStructFieldBoolIndent: c := code.toStructFieldCode() - e.encodeBytes([]byte{',', '\n'}) + if e.buf[len(e.buf)-2] != '{' { + e.encodeBytes([]byte{',', '\n'}) + } e.encodeIndent(c.indent) e.encodeBytes(c.key) e.encodeByte(' ')