Fix MarshalJSON/MarshalText

This commit is contained in:
Masaaki Goshima 2020-08-18 13:36:36 +09:00
parent a45bb76d99
commit 7ffe1ddb35
6 changed files with 142 additions and 11 deletions

View File

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
) )
func compact(dst *bytes.Buffer, src []byte) error { func compact(dst *bytes.Buffer, src []byte, escape bool) error {
length := len(src) length := len(src)
for cursor := 0; cursor < length; cursor++ { for cursor := 0; cursor < length; cursor++ {
c := src[cursor] c := src[cursor]
@ -17,10 +17,18 @@ func compact(dst *bytes.Buffer, src []byte) error {
} }
for { for {
cursor++ cursor++
if err := dst.WriteByte(src[cursor]); err != nil { c := src[cursor]
if escape && (c == '<' || c == '>' || c == '&') {
if _, err := dst.WriteString(`\u00`); err != nil {
return err
}
if _, err := dst.Write([]byte{hex[c>>4], hex[c&0xF]}); err != nil {
return err
}
} else if err := dst.WriteByte(c); err != nil {
return err return err
} }
switch src[cursor] { switch c {
case '\\': case '\\':
cursor++ cursor++
if err := dst.WriteByte(src[cursor]); err != nil { if err := dst.WriteByte(src[cursor]); err != nil {

View File

@ -8,10 +8,15 @@ import (
) )
func (e *Encoder) compileHead(typ *rtype, withIndent bool) (*opcode, error) { func (e *Encoder) compileHead(typ *rtype, withIndent bool) (*opcode, error) {
if typ.Implements(marshalJSONType) { switch {
case typ.Implements(marshalJSONType):
return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil
} else if typ.Implements(marshalTextType) { case rtype_ptrTo(typ).Implements(marshalJSONType):
return newOpCode(opMarshalJSON, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil
case typ.Implements(marshalTextType):
return newOpCode(opMarshalText, typ, e.indent, newEndOp(e.indent)), nil return newOpCode(opMarshalText, typ, e.indent, newEndOp(e.indent)), nil
case rtype_ptrTo(typ).Implements(marshalTextType):
return newOpCode(opMarshalText, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil
} }
if typ.Kind() == reflect.Ptr { if typ.Kind() == reflect.Ptr {
typ = typ.Elem() typ = typ.Elem()
@ -24,10 +29,15 @@ func (e *Encoder) compileHead(typ *rtype, withIndent bool) (*opcode, error) {
} }
func (e *Encoder) compile(typ *rtype, root, withIndent bool) (*opcode, error) { func (e *Encoder) compile(typ *rtype, root, withIndent bool) (*opcode, error) {
if typ.Implements(marshalJSONType) { switch {
case typ.Implements(marshalJSONType):
return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil
} else if typ.Implements(marshalTextType) { case rtype_ptrTo(typ).Implements(marshalJSONType):
return newOpCode(opMarshalJSON, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil
case typ.Implements(marshalTextType):
return newOpCode(opMarshalText, typ, e.indent, newEndOp(e.indent)), nil return newOpCode(opMarshalText, typ, e.indent, newEndOp(e.indent)), nil
case rtype_ptrTo(typ).Implements(marshalTextType):
return newOpCode(opMarshalText, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil
} }
switch typ.Kind() { switch typ.Kind() {
case reflect.Ptr: case reflect.Ptr:

View File

@ -503,3 +503,107 @@ func Test_MarshalerError(t *testing.T) {
expect := `json: error calling MarshalJSON for type *json_test.marshalerError: unexpected error` expect := `json: error calling MarshalJSON for type *json_test.marshalerError: unexpected error`
assertEq(t, "marshaler error", expect, fmt.Sprint(err)) assertEq(t, "marshaler error", expect, fmt.Sprint(err))
} }
// Ref has Marshaler and Unmarshaler methods with pointer receiver.
type Ref int
func (*Ref) MarshalJSON() ([]byte, error) {
return []byte(`"ref"`), nil
}
func (r *Ref) UnmarshalJSON([]byte) error {
*r = 12
return nil
}
// Val has Marshaler methods with value receiver.
type Val int
func (Val) MarshalJSON() ([]byte, error) {
return []byte(`"val"`), nil
}
// RefText has Marshaler and Unmarshaler methods with pointer receiver.
type RefText int
func (*RefText) MarshalText() ([]byte, error) {
return []byte(`"ref"`), nil
}
func (r *RefText) UnmarshalText([]byte) error {
*r = 13
return nil
}
// ValText has Marshaler methods with value receiver.
type ValText int
func (ValText) MarshalText() ([]byte, error) {
return []byte(`"val"`), nil
}
func TestRefValMarshal(t *testing.T) {
var s = struct {
R0 Ref
R1 *Ref
R2 RefText
R3 *RefText
V0 Val
V1 *Val
V2 ValText
V3 *ValText
}{
R0: 12,
R1: new(Ref),
R2: 14,
R3: new(RefText),
V0: 13,
V1: new(Val),
V2: 15,
V3: new(ValText),
}
const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}`
b, err := json.Marshal(&s)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if got := string(b); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
// C implements Marshaler and returns unescaped JSON.
type C int
func (C) MarshalJSON() ([]byte, error) {
return []byte(`"<&>"`), nil
}
// CText implements Marshaler and returns unescaped text.
type CText int
func (CText) MarshalText() ([]byte, error) {
return []byte(`"<&>"`), nil
}
func TestMarshalerEscaping(t *testing.T) {
var c C
want := `"\u003c\u0026\u003e"`
b, err := json.Marshal(c)
if err != nil {
t.Fatalf("Marshal(c): %v", err)
}
if got := string(b); got != want {
t.Errorf("Marshal(c) = %#q, want %#q", got, want)
}
var ct CText
want = `"\"\u003c\u0026\u003e\""`
b, err = json.Marshal(ct)
if err != nil {
t.Fatalf("Marshal(ct): %v", err)
}
if got := string(b); got != want {
t.Errorf("Marshal(ct) = %#q, want %#q", got, want)
}
}

View File

@ -1,6 +1,7 @@
package json package json
import ( import (
"bytes"
"encoding" "encoding"
"math" "math"
"reflect" "reflect"
@ -107,14 +108,18 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
bytes, err := v.(Marshaler).MarshalJSON() b, err := v.(Marshaler).MarshalJSON()
if err != nil { if err != nil {
return &MarshalerError{ return &MarshalerError{
Type: rtype2type(code.typ), Type: rtype2type(code.typ),
Err: err, Err: err,
} }
} }
e.encodeBytes(bytes) var buf bytes.Buffer
if err := compact(&buf, b, true); err != nil {
return err
}
e.encodeBytes(buf.Bytes())
code = code.next code = code.next
code.ptr = ptr code.ptr = ptr
case opMarshalText: case opMarshalText:
@ -130,7 +135,7 @@ func (e *Encoder) run(code *opcode) error {
Err: err, Err: err,
} }
} }
e.encodeBytes(bytes) e.encodeString(*(*string)(unsafe.Pointer(&bytes)))
code = code.next code = code.next
code.ptr = ptr code.ptr = ptr
case opSliceHead: case opSliceHead:

View File

@ -325,7 +325,7 @@ func (m *RawMessage) UnmarshalJSON(data []byte) error {
// Compact appends to dst the JSON-encoded src with // Compact appends to dst the JSON-encoded src with
// insignificant space characters elided. // insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error { func Compact(dst *bytes.Buffer, src []byte) error {
return compact(dst, src) return compact(dst, src, false)
} }
// Indent appends to dst an indented form of the JSON-encoded src. // Indent appends to dst an indented form of the JSON-encoded src.

View File

@ -236,6 +236,10 @@ func (t *rtype) NumOut() int {
//go:noescape //go:noescape
func rtype_Out(*rtype, int) reflect.Type func rtype_Out(*rtype, int) reflect.Type
//go:linkname rtype_ptrTo reflect.(*rtype).ptrTo
//go:noescape
func rtype_ptrTo(*rtype) *rtype
func (t *rtype) Out(i int) reflect.Type { func (t *rtype) Out(i int) reflect.Type {
return rtype_Out(t, i) return rtype_Out(t, i)
} }