diff --git a/compact.go b/compact.go index 4166bb1..fa48328 100644 --- a/compact.go +++ b/compact.go @@ -4,7 +4,7 @@ import ( "bytes" ) -func compact(dst *bytes.Buffer, src []byte) error { +func compact(dst *bytes.Buffer, src []byte, escape bool) error { length := len(src) for cursor := 0; cursor < length; cursor++ { c := src[cursor] @@ -17,10 +17,18 @@ func compact(dst *bytes.Buffer, src []byte) error { } for { cursor++ - if err := dst.WriteByte(src[cursor]); err != nil { + c := src[cursor] + if escape && (c == '<' || c == '>' || c == '&') { + if _, err := dst.WriteString(`\u00`); err != nil { + return err + } + if _, err := dst.Write([]byte{hex[c>>4], hex[c&0xF]}); err != nil { + return err + } + } else if err := dst.WriteByte(c); err != nil { return err } - switch src[cursor] { + switch c { case '\\': cursor++ if err := dst.WriteByte(src[cursor]); err != nil { diff --git a/encode_compile.go b/encode_compile.go index d6c8b93..7f4c649 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -8,10 +8,15 @@ import ( ) func (e *Encoder) compileHead(typ *rtype, withIndent bool) (*opcode, error) { - if typ.Implements(marshalJSONType) { + switch { + case typ.Implements(marshalJSONType): return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil - } else if typ.Implements(marshalTextType) { + case rtype_ptrTo(typ).Implements(marshalJSONType): + return newOpCode(opMarshalJSON, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil + case typ.Implements(marshalTextType): return newOpCode(opMarshalText, typ, e.indent, newEndOp(e.indent)), nil + case rtype_ptrTo(typ).Implements(marshalTextType): + return newOpCode(opMarshalText, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil } if typ.Kind() == reflect.Ptr { typ = typ.Elem() @@ -24,10 +29,15 @@ func (e *Encoder) compileHead(typ *rtype, withIndent bool) (*opcode, error) { } func (e *Encoder) compile(typ *rtype, root, withIndent bool) (*opcode, error) { - if typ.Implements(marshalJSONType) { + switch { + case typ.Implements(marshalJSONType): return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil - } else if typ.Implements(marshalTextType) { + case rtype_ptrTo(typ).Implements(marshalJSONType): + return newOpCode(opMarshalJSON, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil + case typ.Implements(marshalTextType): return newOpCode(opMarshalText, typ, e.indent, newEndOp(e.indent)), nil + case rtype_ptrTo(typ).Implements(marshalTextType): + return newOpCode(opMarshalText, rtype_ptrTo(typ), e.indent, newEndOp(e.indent)), nil } switch typ.Kind() { case reflect.Ptr: diff --git a/encode_test.go b/encode_test.go index 94d9604..f92ecf6 100644 --- a/encode_test.go +++ b/encode_test.go @@ -503,3 +503,107 @@ func Test_MarshalerError(t *testing.T) { expect := `json: error calling MarshalJSON for type *json_test.marshalerError: unexpected error` assertEq(t, "marshaler error", expect, fmt.Sprint(err)) } + +// Ref has Marshaler and Unmarshaler methods with pointer receiver. +type Ref int + +func (*Ref) MarshalJSON() ([]byte, error) { + return []byte(`"ref"`), nil +} + +func (r *Ref) UnmarshalJSON([]byte) error { + *r = 12 + return nil +} + +// Val has Marshaler methods with value receiver. +type Val int + +func (Val) MarshalJSON() ([]byte, error) { + return []byte(`"val"`), nil +} + +// RefText has Marshaler and Unmarshaler methods with pointer receiver. +type RefText int + +func (*RefText) MarshalText() ([]byte, error) { + return []byte(`"ref"`), nil +} + +func (r *RefText) UnmarshalText([]byte) error { + *r = 13 + return nil +} + +// ValText has Marshaler methods with value receiver. +type ValText int + +func (ValText) MarshalText() ([]byte, error) { + return []byte(`"val"`), nil +} + +func TestRefValMarshal(t *testing.T) { + var s = struct { + R0 Ref + R1 *Ref + R2 RefText + R3 *RefText + V0 Val + V1 *Val + V2 ValText + V3 *ValText + }{ + R0: 12, + R1: new(Ref), + R2: 14, + R3: new(RefText), + V0: 13, + V1: new(Val), + V2: 15, + V3: new(ValText), + } + const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}` + b, err := json.Marshal(&s) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if got := string(b); got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +// C implements Marshaler and returns unescaped JSON. +type C int + +func (C) MarshalJSON() ([]byte, error) { + return []byte(`"<&>"`), nil +} + +// CText implements Marshaler and returns unescaped text. +type CText int + +func (CText) MarshalText() ([]byte, error) { + return []byte(`"<&>"`), nil +} + +func TestMarshalerEscaping(t *testing.T) { + var c C + want := `"\u003c\u0026\u003e"` + b, err := json.Marshal(c) + if err != nil { + t.Fatalf("Marshal(c): %v", err) + } + if got := string(b); got != want { + t.Errorf("Marshal(c) = %#q, want %#q", got, want) + } + + var ct CText + want = `"\"\u003c\u0026\u003e\""` + b, err = json.Marshal(ct) + if err != nil { + t.Fatalf("Marshal(ct): %v", err) + } + if got := string(b); got != want { + t.Errorf("Marshal(ct) = %#q, want %#q", got, want) + } +} diff --git a/encode_vm.go b/encode_vm.go index 8ed7cec..16940e4 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -1,6 +1,7 @@ package json import ( + "bytes" "encoding" "math" "reflect" @@ -107,14 +108,18 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - bytes, err := v.(Marshaler).MarshalJSON() + b, err := v.(Marshaler).MarshalJSON() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), Err: err, } } - e.encodeBytes(bytes) + var buf bytes.Buffer + if err := compact(&buf, b, true); err != nil { + return err + } + e.encodeBytes(buf.Bytes()) code = code.next code.ptr = ptr case opMarshalText: @@ -130,7 +135,7 @@ func (e *Encoder) run(code *opcode) error { Err: err, } } - e.encodeBytes(bytes) + e.encodeString(*(*string)(unsafe.Pointer(&bytes))) code = code.next code.ptr = ptr case opSliceHead: diff --git a/json.go b/json.go index d740ed3..673fb3b 100644 --- a/json.go +++ b/json.go @@ -325,7 +325,7 @@ func (m *RawMessage) UnmarshalJSON(data []byte) error { // Compact appends to dst the JSON-encoded src with // insignificant space characters elided. func Compact(dst *bytes.Buffer, src []byte) error { - return compact(dst, src) + return compact(dst, src, false) } // Indent appends to dst an indented form of the JSON-encoded src. diff --git a/rtype.go b/rtype.go index ca27668..fd2a895 100644 --- a/rtype.go +++ b/rtype.go @@ -236,6 +236,10 @@ func (t *rtype) NumOut() int { //go:noescape func rtype_Out(*rtype, int) reflect.Type +//go:linkname rtype_ptrTo reflect.(*rtype).ptrTo +//go:noescape +func rtype_ptrTo(*rtype) *rtype + func (t *rtype) Out(i int) reflect.Type { return rtype_Out(t, i) }