diff --git a/encode_compile.go b/encode_compile.go index 730d695..b7efaf6 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -92,6 +92,56 @@ func (e *Encoder) compile(typ *rtype, root, withIndent bool) (*opcode, error) { return nil, &UnsupportedTypeError{Type: rtype2type(typ)} } +func (e *Encoder) compileKey(typ *rtype, root, withIndent bool) (*opcode, error) { + switch { + case typ.Implements(marshalJSONType): + return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil + case rtype_ptrTo(typ).Implements(marshalJSONType): + return newOpCode(opMarshalJSON, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil + case typ.Implements(marshalTextType): + return newOpCode(opMarshalText, typ, e.indent, newEndOp(e.indent)), nil + case rtype_ptrTo(typ).Implements(marshalTextType): + return newOpCode(opMarshalText, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil + } + switch typ.Kind() { + case reflect.Ptr: + return e.compilePtr(typ, root, withIndent) + case reflect.Interface: + return e.compileInterface(typ, root) + case reflect.Int: + return e.compileInt(typ) + case reflect.Int8: + return e.compileInt8(typ) + case reflect.Int16: + return e.compileInt16(typ) + case reflect.Int32: + return e.compileInt32(typ) + case reflect.Int64: + return e.compileInt64(typ) + case reflect.Uint: + return e.compileUint(typ) + case reflect.Uint8: + return e.compileUint8(typ) + case reflect.Uint16: + return e.compileUint16(typ) + case reflect.Uint32: + return e.compileUint32(typ) + case reflect.Uint64: + return e.compileUint64(typ) + case reflect.Uintptr: + return e.compileUint(typ) + case reflect.Float32: + return e.compileFloat32(typ) + case reflect.Float64: + return e.compileFloat64(typ) + case reflect.String: + return e.compileString(typ) + case reflect.Bool: + return e.compileBool(typ) + } + return nil, &UnsupportedTypeError{Type: rtype2type(typ)} +} + func (e *Encoder) optimizeStructFieldPtrHead(typ *rtype, code *opcode) *opcode { ptrHeadOp := code.op.headToPtrHead() if code.op != ptrHeadOp { @@ -289,7 +339,7 @@ func (e *Encoder) compileMap(typ *rtype, withLoad, root, withIndent bool) (*opco // |_______________________| e.indent++ keyType := typ.Key() - keyCode, err := e.compile(keyType, false, withIndent) + keyCode, err := e.compileKey(keyType, false, withIndent) if err != nil { return nil, err } diff --git a/encode_test.go b/encode_test.go index 448c3b2..40823b7 100644 --- a/encode_test.go +++ b/encode_test.go @@ -1,6 +1,7 @@ package json_test import ( + "bytes" "errors" "fmt" "log" @@ -867,6 +868,56 @@ func TestMarshalerError(t *testing.T) { } } +type unmarshalerText struct { + A, B string +} + +// needed for re-marshaling tests +func (u unmarshalerText) MarshalText() ([]byte, error) { + return []byte(u.A + ":" + u.B), nil +} + +func (u *unmarshalerText) UnmarshalText(b []byte) error { + pos := bytes.IndexByte(b, ':') + if pos == -1 { + return errors.New("missing separator") + } + u.A, u.B = string(b[:pos]), string(b[pos+1:]) + return nil +} + +func TestTextMarshalerMapKeysAreSorted(t *testing.T) { + b, err := json.Marshal(map[unmarshalerText]int{ + {"x", "y"}: 1, + {"y", "x"}: 2, + {"a", "z"}: 3, + {"z", "a"}: 4, + }) + if err != nil { + t.Fatalf("Failed to Marshal text.Marshaler: %v", err) + } + const want = `{"a:z":3,"x:y":1,"y:x":2,"z:a":4}` + if len(string(b)) != len(want) { + t.Errorf("Marshal map with text.Marshaler keys: got %#q, want %#q", b, want) + } +} + +// https://golang.org/issue/33675 +func TestNilMarshalerTextMapKey(t *testing.T) { + v := map[*unmarshalerText]int{ + (*unmarshalerText)(nil): 1, + {"A", "B"}: 2, + } + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("Failed to Marshal *text.Marshaler: %v", err) + } + const want = `{"":1,"A:B":2}` + if string(b) != want { + t.Errorf("Marshal map with *text.Marshaler keys: got %#q, want %#q", b, want) + } +} + var re = regexp.MustCompile // syntactic checks on form of marshaled floating point numbers. diff --git a/encode_vm.go b/encode_vm.go index cfe5dde..71e42ff 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -138,18 +138,27 @@ func (e *Encoder) run(code *opcode) error { code.ptr = ptr case opMarshalText: ptr := code.ptr - v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ - typ: code.typ, - ptr: unsafe.Pointer(ptr), - })) - bytes, err := v.(encoding.TextMarshaler).MarshalText() - if err != nil { - return &MarshalerError{ - Type: rtype2type(code.typ), - Err: err, + isPtr := code.typ.Kind() == reflect.Ptr + p := unsafe.Pointer(ptr) + if isPtr && *(*unsafe.Pointer)(p) == nil { + e.encodeBytes([]byte{'"', '"'}) + } else { + if isPtr && code.typ.Elem().Implements(marshalTextType) { + p = *(*unsafe.Pointer)(p) } + v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ + typ: code.typ, + ptr: p, + })) + bytes, err := v.(encoding.TextMarshaler).MarshalText() + if err != nil { + return &MarshalerError{ + Type: rtype2type(code.typ), + Err: err, + } + } + e.encodeString(*(*string)(unsafe.Pointer(&bytes))) } - e.encodeString(*(*string)(unsafe.Pointer(&bytes))) code = code.next code.ptr = ptr case opSliceHead: