Fix map key by UnmarshalText

This commit is contained in:
Masaaki Goshima 2020-08-20 17:47:38 +09:00
parent 652e7a9369
commit 8d029cddbe
3 changed files with 121 additions and 11 deletions

View File

@ -92,6 +92,56 @@ func (e *Encoder) compile(typ *rtype, root, withIndent bool) (*opcode, error) {
return nil, &UnsupportedTypeError{Type: rtype2type(typ)} return nil, &UnsupportedTypeError{Type: rtype2type(typ)}
} }
func (e *Encoder) compileKey(typ *rtype, root, withIndent bool) (*opcode, error) {
switch {
case typ.Implements(marshalJSONType):
return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil
case rtype_ptrTo(typ).Implements(marshalJSONType):
return newOpCode(opMarshalJSON, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil
case typ.Implements(marshalTextType):
return newOpCode(opMarshalText, typ, e.indent, newEndOp(e.indent)), nil
case rtype_ptrTo(typ).Implements(marshalTextType):
return newOpCode(opMarshalText, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil
}
switch typ.Kind() {
case reflect.Ptr:
return e.compilePtr(typ, root, withIndent)
case reflect.Interface:
return e.compileInterface(typ, root)
case reflect.Int:
return e.compileInt(typ)
case reflect.Int8:
return e.compileInt8(typ)
case reflect.Int16:
return e.compileInt16(typ)
case reflect.Int32:
return e.compileInt32(typ)
case reflect.Int64:
return e.compileInt64(typ)
case reflect.Uint:
return e.compileUint(typ)
case reflect.Uint8:
return e.compileUint8(typ)
case reflect.Uint16:
return e.compileUint16(typ)
case reflect.Uint32:
return e.compileUint32(typ)
case reflect.Uint64:
return e.compileUint64(typ)
case reflect.Uintptr:
return e.compileUint(typ)
case reflect.Float32:
return e.compileFloat32(typ)
case reflect.Float64:
return e.compileFloat64(typ)
case reflect.String:
return e.compileString(typ)
case reflect.Bool:
return e.compileBool(typ)
}
return nil, &UnsupportedTypeError{Type: rtype2type(typ)}
}
func (e *Encoder) optimizeStructFieldPtrHead(typ *rtype, code *opcode) *opcode { func (e *Encoder) optimizeStructFieldPtrHead(typ *rtype, code *opcode) *opcode {
ptrHeadOp := code.op.headToPtrHead() ptrHeadOp := code.op.headToPtrHead()
if code.op != ptrHeadOp { if code.op != ptrHeadOp {
@ -289,7 +339,7 @@ func (e *Encoder) compileMap(typ *rtype, withLoad, root, withIndent bool) (*opco
// |_______________________| // |_______________________|
e.indent++ e.indent++
keyType := typ.Key() keyType := typ.Key()
keyCode, err := e.compile(keyType, false, withIndent) keyCode, err := e.compileKey(keyType, false, withIndent)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package json_test package json_test
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@ -867,6 +868,56 @@ func TestMarshalerError(t *testing.T) {
} }
} }
type unmarshalerText struct {
A, B string
}
// needed for re-marshaling tests
func (u unmarshalerText) MarshalText() ([]byte, error) {
return []byte(u.A + ":" + u.B), nil
}
func (u *unmarshalerText) UnmarshalText(b []byte) error {
pos := bytes.IndexByte(b, ':')
if pos == -1 {
return errors.New("missing separator")
}
u.A, u.B = string(b[:pos]), string(b[pos+1:])
return nil
}
func TestTextMarshalerMapKeysAreSorted(t *testing.T) {
b, err := json.Marshal(map[unmarshalerText]int{
{"x", "y"}: 1,
{"y", "x"}: 2,
{"a", "z"}: 3,
{"z", "a"}: 4,
})
if err != nil {
t.Fatalf("Failed to Marshal text.Marshaler: %v", err)
}
const want = `{"a:z":3,"x:y":1,"y:x":2,"z:a":4}`
if len(string(b)) != len(want) {
t.Errorf("Marshal map with text.Marshaler keys: got %#q, want %#q", b, want)
}
}
// https://golang.org/issue/33675
func TestNilMarshalerTextMapKey(t *testing.T) {
v := map[*unmarshalerText]int{
(*unmarshalerText)(nil): 1,
{"A", "B"}: 2,
}
b, err := json.Marshal(v)
if err != nil {
t.Fatalf("Failed to Marshal *text.Marshaler: %v", err)
}
const want = `{"":1,"A:B":2}`
if string(b) != want {
t.Errorf("Marshal map with *text.Marshaler keys: got %#q, want %#q", b, want)
}
}
var re = regexp.MustCompile var re = regexp.MustCompile
// syntactic checks on form of marshaled floating point numbers. // syntactic checks on form of marshaled floating point numbers.

View File

@ -138,18 +138,27 @@ func (e *Encoder) run(code *opcode) error {
code.ptr = ptr code.ptr = ptr
case opMarshalText: case opMarshalText:
ptr := code.ptr ptr := code.ptr
v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ isPtr := code.typ.Kind() == reflect.Ptr
typ: code.typ, p := unsafe.Pointer(ptr)
ptr: unsafe.Pointer(ptr), if isPtr && *(*unsafe.Pointer)(p) == nil {
})) e.encodeBytes([]byte{'"', '"'})
bytes, err := v.(encoding.TextMarshaler).MarshalText() } else {
if err != nil { if isPtr && code.typ.Elem().Implements(marshalTextType) {
return &MarshalerError{ p = *(*unsafe.Pointer)(p)
Type: rtype2type(code.typ),
Err: err,
} }
v := *(*interface{})(unsafe.Pointer(&interfaceHeader{
typ: code.typ,
ptr: p,
}))
bytes, err := v.(encoding.TextMarshaler).MarshalText()
if err != nil {
return &MarshalerError{
Type: rtype2type(code.typ),
Err: err,
}
}
e.encodeString(*(*string)(unsafe.Pointer(&bytes)))
} }
e.encodeString(*(*string)(unsafe.Pointer(&bytes)))
code = code.next code = code.next
code.ptr = ptr code.ptr = ptr
case opSliceHead: case opSliceHead: