From c23e5f43a77ca2b6b4fa67d196b5dd511168f123 Mon Sep 17 00:00:00 2001
From: Masaaki Goshima
Date: Fri, 8 May 2020 20:22:57 +0900
Subject: [PATCH] Support UnmarshalJSON

---
 decode.go                | 23 +++++++++++++--
 decode_context.go        | 62 ++++++++++++++++++++++++++++++++++++++++
 decode_struct.go         | 48 +------------------------------
 decode_test.go           | 21 ++++++++++++++
 decode_unmarshal_json.go | 31 ++++++++++++++++++++
 encode.go                |  7 ++---
 encode_vm.go             |  3 +-
 json.go                  | 12 ++++++++
 8 files changed, 151 insertions(+), 56 deletions(-)
 create mode 100644 decode_unmarshal_json.go

diff --git a/decode.go b/decode.go
index f8fbcd2..143294a 100644
--- a/decode.go
+++ b/decode.go
@@ -2,6 +2,7 @@ package json
 
 import (
 	"bytes"
+	"encoding"
 	"io"
 	"reflect"
 	"strings"
@@ -49,7 +50,9 @@ func (m *decoderMap) set(k uintptr, dec decoder) {
 }
 
 var (
-	cachedDecoder decoderMap
+	cachedDecoder     decoderMap
+	unmarshalJSONType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
+	unmarshalTextType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
 )
 
 func init() {
@@ -78,7 +81,7 @@ func (d *Decoder) decode(src []byte, header *interfaceHeader) error {
 	typeptr := uintptr(unsafe.Pointer(typ))
 	dec := cachedDecoder.get(typeptr)
 	if dec == nil {
-		compiledDec, err := d.compile(typ.Elem())
+		compiledDec, err := d.compileHead(typ)
 		if err != nil {
 			return err
 		}
@@ -117,7 +120,7 @@ func (d *Decoder) Decode(v interface{}) error {
 	typeptr := uintptr(unsafe.Pointer(typ))
 	dec := cachedDecoder.get(typeptr)
 	if dec == nil {
-		compiledDec, err := d.compile(typ.Elem())
+		compiledDec, err := d.compileHead(typ)
 		if err != nil {
 			return err
 		}
@@ -145,7 +148,21 @@ func (d *Decoder) Decode(v interface{}) error {
 	return nil
 }
 
+func (d *Decoder) compileHead(typ *rtype) (decoder, error) {
+	if typ.Implements(unmarshalJSONType) {
+		return newUnmarshalJSONDecoder(typ), nil
+	} else if typ.Implements(unmarshalTextType) {
+
+	}
+	return d.compile(typ.Elem())
+}
+
 func (d *Decoder) compile(typ *rtype) (decoder, error) {
+	if typ.Implements(unmarshalJSONType) {
+		return newUnmarshalJSONDecoder(typ), nil
+	} else if typ.Implements(unmarshalTextType) {
+
+	}
 	switch typ.Kind() {
 	case reflect.Ptr:
 		return d.compilePtr(typ)
diff --git a/decode_context.go b/decode_context.go
index 32ec0e9..1bb8f5d 100644
--- a/decode_context.go
+++ b/decode_context.go
@@ -1,5 +1,9 @@
 package json
 
+import (
+	"errors"
+)
+
 var (
 	isWhiteSpace = [256]bool{}
 )
@@ -19,3 +23,61 @@ LOOP:
 	}
 	return cursor
 }
+
+func skipValue(buf []byte, cursor int) (int, error) {
+	cursor = skipWhiteSpace(buf, cursor)
+	braceCount := 0
+	bracketCount := 0
+	buflen := len(buf)
+	for {
+		switch buf[cursor] {
+		case '\000':
+			return cursor, errors.New("unexpected error value")
+		case '{':
+			braceCount++
+		case '[':
+			bracketCount++
+		case '}':
+			braceCount--
+			if braceCount == -1 && bracketCount == 0 {
+				return cursor, nil
+			}
+		case ']':
+			bracketCount--
+		case ',':
+			if bracketCount == 0 && braceCount == 0 {
+				return cursor, nil
+			}
+		case '"':
+			cursor++
+
+			for ; cursor < buflen; cursor++ {
+				if buf[cursor] != '"' {
+					continue
+				}
+				if buf[cursor-1] == '\\' {
+					continue
+				}
+				if bracketCount == 0 && braceCount == 0 {
+					return cursor + 1, nil
+				}
+				break
+			}
+		case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+			cursor++
+			for ; cursor < buflen; cursor++ {
+				tk := int(buf[cursor])
+				if (int('0') <= tk && tk <= int('9')) || tk == '.' || tk == 'e' || tk == 'E' {
+					continue
+				}
+				break
+			}
+			if bracketCount == 0 && braceCount == 0 {
+				return cursor, nil
+			}
+			continue
+		}
+		cursor++
+	}
+	return cursor, errors.New("unexpected error value")
+}
diff --git a/decode_struct.go b/decode_struct.go
index 13d510c..3e08b23 100644
--- a/decode_struct.go
+++ b/decode_struct.go
@@ -22,52 +22,6 @@ func newStructDecoder(fieldMap map[string]*structFieldSet) *structDecoder {
 	}
 }
 
-func (d *structDecoder) skipValue(buf []byte, cursor int) (int, error) {
-	cursor = skipWhiteSpace(buf, cursor)
-	braceCount := 0
-	bracketCount := 0
-	buflen := len(buf)
-	for {
-		switch buf[cursor] {
-		case '\000':
-			return cursor, errors.New("unexpected error value")
-		case '{':
-			braceCount++
-		case '[':
-			bracketCount++
-		case '}':
-			braceCount--
-			if braceCount == -1 && bracketCount == 0 {
-				return cursor, nil
-			}
-		case ']':
-			bracketCount--
-		case ',':
-			if bracketCount == 0 && braceCount == 0 {
-				return cursor, nil
-			}
-		case '"':
-			cursor++
-
-			for ; cursor < buflen; cursor++ {
-				if buf[cursor] != '"' {
-					continue
-				}
-				if buf[cursor-1] == '\\' {
-					continue
-				}
-				if bracketCount == 0 && braceCount == 0 {
-					return cursor + 1, nil
-				}
-				break
-			}
-
-		}
-		cursor++
-	}
-	return cursor, errors.New("unexpected error value")
-}
-
 func (d *structDecoder) decode(buf []byte, cursor int, p uintptr) (int, error) {
 	buflen := len(buf)
 	cursor = skipWhiteSpace(buf, cursor)
@@ -101,7 +55,7 @@ func (d *structDecoder) decode(buf []byte, cursor int, p uintptr) (int, error) {
 			}
 			cursor = c
 		} else {
-			c, err := d.skipValue(buf, cursor)
+			c, err := skipValue(buf, cursor)
 			if err != nil {
 				return 0, err
 			}
diff --git a/decode_test.go b/decode_test.go
index 0b74f6a..cd9011d 100644
--- a/decode_test.go
+++ b/decode_test.go
@@ -176,3 +176,24 @@ func Test_Decoder(t *testing.T) {
 		})
 	})
 }
+
+type unmarshalJSON struct {
+	v int
+}
+
+func (u *unmarshalJSON) UnmarshalJSON(b []byte) error {
+	var v int
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+	u.v = v
+	return nil
+}
+
+func Test_UnmarshalJSON(t *testing.T) {
+	t.Run("*struct", func(t *testing.T) {
+		var v unmarshalJSON
+		assertErr(t, json.Unmarshal([]byte(`10`), &v))
+		assertEq(t, "unmarshal", v.v, 10)
+	})
+}
diff --git a/decode_unmarshal_json.go b/decode_unmarshal_json.go
new file mode 100644
index 0000000..b3fb1e2
--- /dev/null
+++ b/decode_unmarshal_json.go
@@ -0,0 +1,31 @@
+package json
+
+import (
+	"unsafe"
+)
+
+type unmarshalJSONDecoder struct {
+	typ *rtype
+}
+
+func newUnmarshalJSONDecoder(typ *rtype) *unmarshalJSONDecoder {
+	return &unmarshalJSONDecoder{typ: typ}
+}
+
+func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int, p uintptr) (int, error) {
+	cursor = skipWhiteSpace(buf, cursor)
+	start := cursor
+	end, err := skipValue(buf, cursor)
+	if err != nil {
+		return 0, err
+	}
+	src := buf[start:end]
+	v := *(*interface{})(unsafe.Pointer(&interfaceHeader{
+		typ: d.typ,
+		ptr: unsafe.Pointer(p),
+	}))
+	if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil {
+		return 0, err
+	}
+	return end, nil
+}
diff --git a/encode.go b/encode.go
index b7696c5..e64780a 100644
--- a/encode.go
+++ b/encode.go
@@ -2,6 +2,7 @@ package json
 
 import (
 	"bytes"
+	"encoding"
 	"io"
 	"reflect"
 	"strconv"
@@ -45,10 +46,6 @@ func (m *opcodeMap) set(k uintptr, op *opcodeSet) {
 	m.Store(k, op)
 }
 
-type marshalText interface {
-	MarshalText() ([]byte, error)
-}
-
 var (
 	encPool         sync.Pool
 	cachedOpcode    opcodeMap
@@ -67,7 +64,7 @@ func init() {
 	}
 	cachedOpcode = opcodeMap{}
 	marshalJSONType = reflect.TypeOf((*Marshaler)(nil)).Elem()
-	marshalTextType = reflect.TypeOf((*marshalText)(nil)).Elem()
+	marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
 }
 
 // NewEncoder returns a new encoder that writes to w.
diff --git a/encode_vm.go b/encode_vm.go
index 12a065f..1849359 100644
--- a/encode_vm.go
+++ b/encode_vm.go
@@ -1,6 +1,7 @@
 package json
 
 import (
+	"encoding"
 	"reflect"
 	"unsafe"
 )
@@ -93,7 +94,7 @@ func (e *Encoder) run(code *opcode) error {
 				typ: code.typ,
 				ptr: unsafe.Pointer(ptr),
 			}))
-			bytes, err := v.(marshalText).MarshalText()
+			bytes, err := v.(encoding.TextMarshaler).MarshalText()
 			if err != nil {
 				return err
 			}
diff --git a/json.go b/json.go
index e8f497e..3a81fb3 100644
--- a/json.go
+++ b/json.go
@@ -8,6 +8,18 @@ type Marshaler interface {
 	MarshalJSON() ([]byte, error)
 }
 
+// Unmarshaler is the interface implemented by types
+// that can unmarshal a JSON description of themselves.
+// The input can be assumed to be a valid encoding of
+// a JSON value. UnmarshalJSON must copy the JSON data
+// if it wishes to retain the data after returning.
+//
+// By convention, to approximate the behavior of Unmarshal itself,
+// Unmarshalers implement UnmarshalJSON([]byte("null")) as a no-op.
+type Unmarshaler interface {
+	UnmarshalJSON([]byte) error
+}
+
 // Marshal returns the JSON encoding of v.
 //
 // Marshal traverses the value v recursively.