Support UnmarshalJSON

This commit is contained in:
Masaaki Goshima 2020-05-08 20:22:57 +09:00
parent ad8e97ff51
commit c23e5f43a7
8 changed files with 151 additions and 56 deletions

View File

@ -2,6 +2,7 @@ package json
import ( import (
"bytes" "bytes"
"encoding"
"io" "io"
"reflect" "reflect"
"strings" "strings"
@ -50,6 +51,8 @@ func (m *decoderMap) set(k uintptr, dec decoder) {
var ( var (
cachedDecoder decoderMap cachedDecoder decoderMap
unmarshalJSONType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
unmarshalTextType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
) )
func init() { func init() {
@ -78,7 +81,7 @@ func (d *Decoder) decode(src []byte, header *interfaceHeader) error {
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
dec := cachedDecoder.get(typeptr) dec := cachedDecoder.get(typeptr)
if dec == nil { if dec == nil {
compiledDec, err := d.compile(typ.Elem()) compiledDec, err := d.compileHead(typ)
if err != nil { if err != nil {
return err return err
} }
@ -117,7 +120,7 @@ func (d *Decoder) Decode(v interface{}) error {
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
dec := cachedDecoder.get(typeptr) dec := cachedDecoder.get(typeptr)
if dec == nil { if dec == nil {
compiledDec, err := d.compile(typ.Elem()) compiledDec, err := d.compileHead(typ)
if err != nil { if err != nil {
return err return err
} }
@ -145,7 +148,21 @@ func (d *Decoder) Decode(v interface{}) error {
return nil return nil
} }
func (d *Decoder) compileHead(typ *rtype) (decoder, error) {
if typ.Implements(unmarshalJSONType) {
return newUnmarshalJSONDecoder(typ), nil
} else if typ.Implements(unmarshalTextType) {
}
return d.compile(typ.Elem())
}
func (d *Decoder) compile(typ *rtype) (decoder, error) { func (d *Decoder) compile(typ *rtype) (decoder, error) {
if typ.Implements(unmarshalJSONType) {
return newUnmarshalJSONDecoder(typ), nil
} else if typ.Implements(unmarshalTextType) {
}
switch typ.Kind() { switch typ.Kind() {
case reflect.Ptr: case reflect.Ptr:
return d.compilePtr(typ) return d.compilePtr(typ)

View File

@ -1,5 +1,9 @@
package json package json
import (
"errors"
)
var ( var (
isWhiteSpace = [256]bool{} isWhiteSpace = [256]bool{}
) )
@ -19,3 +23,61 @@ LOOP:
} }
return cursor return cursor
} }
func skipValue(buf []byte, cursor int) (int, error) {
cursor = skipWhiteSpace(buf, cursor)
braceCount := 0
bracketCount := 0
buflen := len(buf)
for {
switch buf[cursor] {
case '\000':
return cursor, errors.New("unexpected error value")
case '{':
braceCount++
case '[':
bracketCount++
case '}':
braceCount--
if braceCount == -1 && bracketCount == 0 {
return cursor, nil
}
case ']':
bracketCount--
case ',':
if bracketCount == 0 && braceCount == 0 {
return cursor, nil
}
case '"':
cursor++
for ; cursor < buflen; cursor++ {
if buf[cursor] != '"' {
continue
}
if buf[cursor-1] == '\\' {
continue
}
if bracketCount == 0 && braceCount == 0 {
return cursor + 1, nil
}
break
}
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
cursor++
for ; cursor < buflen; cursor++ {
tk := int(buf[cursor])
if (int('0') <= tk && tk <= int('9')) || tk == '.' || tk == 'e' || tk == 'E' {
continue
}
break
}
if bracketCount == 0 && braceCount == 0 {
return cursor, nil
}
continue
}
cursor++
}
return cursor, errors.New("unexpected error value")
}

View File

@ -22,52 +22,6 @@ func newStructDecoder(fieldMap map[string]*structFieldSet) *structDecoder {
} }
} }
func (d *structDecoder) skipValue(buf []byte, cursor int) (int, error) {
cursor = skipWhiteSpace(buf, cursor)
braceCount := 0
bracketCount := 0
buflen := len(buf)
for {
switch buf[cursor] {
case '\000':
return cursor, errors.New("unexpected error value")
case '{':
braceCount++
case '[':
bracketCount++
case '}':
braceCount--
if braceCount == -1 && bracketCount == 0 {
return cursor, nil
}
case ']':
bracketCount--
case ',':
if bracketCount == 0 && braceCount == 0 {
return cursor, nil
}
case '"':
cursor++
for ; cursor < buflen; cursor++ {
if buf[cursor] != '"' {
continue
}
if buf[cursor-1] == '\\' {
continue
}
if bracketCount == 0 && braceCount == 0 {
return cursor + 1, nil
}
break
}
}
cursor++
}
return cursor, errors.New("unexpected error value")
}
func (d *structDecoder) decode(buf []byte, cursor int, p uintptr) (int, error) { func (d *structDecoder) decode(buf []byte, cursor int, p uintptr) (int, error) {
buflen := len(buf) buflen := len(buf)
cursor = skipWhiteSpace(buf, cursor) cursor = skipWhiteSpace(buf, cursor)
@ -101,7 +55,7 @@ func (d *structDecoder) decode(buf []byte, cursor int, p uintptr) (int, error) {
} }
cursor = c cursor = c
} else { } else {
c, err := d.skipValue(buf, cursor) c, err := skipValue(buf, cursor)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -176,3 +176,24 @@ func Test_Decoder(t *testing.T) {
}) })
}) })
} }
type unmarshalJSON struct {
v int
}
func (u *unmarshalJSON) UnmarshalJSON(b []byte) error {
var v int
if err := json.Unmarshal(b, &v); err != nil {
return err
}
u.v = v
return nil
}
func Test_UnmarshalJSON(t *testing.T) {
t.Run("*struct", func(t *testing.T) {
var v unmarshalJSON
assertErr(t, json.Unmarshal([]byte(`10`), &v))
assertEq(t, "unmarshal", v.v, 10)
})
}

31
decode_unmarshal_json.go Normal file
View File

@ -0,0 +1,31 @@
package json
import (
"unsafe"
)
type unmarshalJSONDecoder struct {
typ *rtype
}
func newUnmarshalJSONDecoder(typ *rtype) *unmarshalJSONDecoder {
return &unmarshalJSONDecoder{typ: typ}
}
func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int, p uintptr) (int, error) {
cursor = skipWhiteSpace(buf, cursor)
start := cursor
end, err := skipValue(buf, cursor)
if err != nil {
return 0, err
}
src := buf[start:end]
v := *(*interface{})(unsafe.Pointer(&interfaceHeader{
typ: d.typ,
ptr: unsafe.Pointer(p),
}))
if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil {
return 0, err
}
return end, nil
}

View File

@ -2,6 +2,7 @@ package json
import ( import (
"bytes" "bytes"
"encoding"
"io" "io"
"reflect" "reflect"
"strconv" "strconv"
@ -45,10 +46,6 @@ func (m *opcodeMap) set(k uintptr, op *opcodeSet) {
m.Store(k, op) m.Store(k, op)
} }
type marshalText interface {
MarshalText() ([]byte, error)
}
var ( var (
encPool sync.Pool encPool sync.Pool
cachedOpcode opcodeMap cachedOpcode opcodeMap
@ -67,7 +64,7 @@ func init() {
} }
cachedOpcode = opcodeMap{} cachedOpcode = opcodeMap{}
marshalJSONType = reflect.TypeOf((*Marshaler)(nil)).Elem() marshalJSONType = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalTextType = reflect.TypeOf((*marshalText)(nil)).Elem() marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
} }
// NewEncoder returns a new encoder that writes to w. // NewEncoder returns a new encoder that writes to w.

View File

@ -1,6 +1,7 @@
package json package json
import ( import (
"encoding"
"reflect" "reflect"
"unsafe" "unsafe"
) )
@ -93,7 +94,7 @@ func (e *Encoder) run(code *opcode) error {
typ: code.typ, typ: code.typ,
ptr: unsafe.Pointer(ptr), ptr: unsafe.Pointer(ptr),
})) }))
bytes, err := v.(marshalText).MarshalText() bytes, err := v.(encoding.TextMarshaler).MarshalText()
if err != nil { if err != nil {
return err return err
} }

12
json.go
View File

@ -8,6 +8,18 @@ type Marshaler interface {
MarshalJSON() ([]byte, error) MarshalJSON() ([]byte, error)
} }
// Unmarshaler is the interface implemented by types
// that can unmarshal a JSON description of themselves.
// The input can be assumed to be a valid encoding of
// a JSON value. UnmarshalJSON must copy the JSON data
// if it wishes to retain the data after returning.
//
// By convention, to approximate the behavior of Unmarshal itself,
// Unmarshalers implement UnmarshalJSON([]byte("null")) as a no-op.
type Unmarshaler interface {
UnmarshalJSON([]byte) error
}
// Marshal returns the JSON encoding of v. // Marshal returns the JSON encoding of v.
// //
// Marshal traverses the value v recursively. // Marshal traverses the value v recursively.