From 74276c6af3855b5fd4e984a2def3b1f572923d53 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 20 Aug 2020 12:38:50 +0900 Subject: [PATCH] Fix decoder for string tag --- decode_compile.go | 53 ++++++++---------- decode_interface.go | 79 -------------------------- decode_string.go | 116 +++++++++++++++++++++++++++++++++++++++ decode_wrapped_string.go | 52 ++++++++++++++++++ encode_compile.go | 44 +-------------- json.go | 9 +++ struct_field.go | 45 +++++++++++++++ 7 files changed, 246 insertions(+), 152 deletions(-) create mode 100644 decode_wrapped_string.go create mode 100644 struct_field.go diff --git a/decode_compile.go b/decode_compile.go index 71eed28..bb6ab06 100644 --- a/decode_compile.go +++ b/decode_compile.go @@ -7,20 +7,31 @@ import ( ) func (d *Decoder) compileHead(typ *rtype) (decoder, error) { - if typ.Implements(unmarshalJSONType) { + switch { + case typ.Implements(unmarshalJSONType): return newUnmarshalJSONDecoder(typ), nil - } else if typ.Implements(unmarshalTextType) { + case rtype_ptrTo(typ).Implements(marshalJSONType): + return newUnmarshalJSONDecoder(rtype_ptrTo(typ)), nil + case typ.Implements(unmarshalTextType): return newUnmarshalTextDecoder(typ), nil + case rtype_ptrTo(typ).Implements(unmarshalTextType): + return newUnmarshalTextDecoder(rtype_ptrTo(typ)), nil } return d.compile(typ.Elem()) } func (d *Decoder) compile(typ *rtype) (decoder, error) { - if typ.Implements(unmarshalJSONType) { + switch { + case typ.Implements(unmarshalJSONType): return newUnmarshalJSONDecoder(typ), nil - } else if typ.Implements(unmarshalTextType) { + case rtype_ptrTo(typ).Implements(marshalJSONType): + return newUnmarshalJSONDecoder(rtype_ptrTo(typ)), nil + case typ.Implements(unmarshalTextType): return newUnmarshalTextDecoder(typ), nil + case rtype_ptrTo(typ).Implements(unmarshalTextType): + return newUnmarshalTextDecoder(rtype_ptrTo(typ)), nil } + switch typ.Kind() { case reflect.Ptr: return d.compilePtr(typ) @@ -190,46 +201,26 @@ func (d *Decoder) compileInterface(typ *rtype) (decoder, error) { return newInterfaceDecoder(typ), nil } -func (d *Decoder) getTag(field reflect.StructField) string { - return field.Tag.Get("json") -} - -func (d *Decoder) isIgnoredStructField(field reflect.StructField) bool { - if field.PkgPath != "" && !field.Anonymous { - // private field - return true - } - tag := d.getTag(field) - if tag == "-" { - return true - } - return false -} - func (d *Decoder) compileStruct(typ *rtype) (decoder, error) { fieldNum := typ.NumField() fieldMap := map[string]*structFieldSet{} for i := 0; i < fieldNum; i++ { field := typ.Field(i) - if d.isIgnoredStructField(field) { + if isIgnoredStructField(field) { continue } - keyName := field.Name - tag := d.getTag(field) - opts := strings.Split(tag, ",") - if len(opts) > 0 { - if opts[0] != "" { - keyName = opts[0] - } - } + tag := structTagFromField(field) dec, err := d.compile(type2rtype(field.Type)) if err != nil { return nil, err } + if tag.isString { + dec = newWrappedStringDecoder(dec) + } fieldSet := &structFieldSet{dec: dec, offset: field.Offset} fieldMap[field.Name] = fieldSet - fieldMap[keyName] = fieldSet - fieldMap[strings.ToLower(keyName)] = fieldSet + fieldMap[tag.key] = fieldSet + fieldMap[strings.ToLower(tag.key)] = fieldSet } return newStructDecoder(fieldMap), nil } diff --git a/decode_interface.go b/decode_interface.go index d9218c7..bc8d39a 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -33,85 +33,6 @@ var ( ) ) -var ( - hexToInt = [256]int{ - '0': 0, - '1': 1, - '2': 2, - '3': 3, - '4': 4, - '5': 5, - '6': 6, - '7': 7, - '8': 8, - '9': 9, - 'A': 10, - 'B': 11, - 'C': 12, - 'D': 13, - 'E': 14, - 'F': 15, - 'a': 10, - 'b': 11, - 'c': 12, - 'd': 13, - 'e': 14, - 'f': 15, - } -) - -func unicodeToRune(code []byte) rune { - sum := 0 - for i := 0; i < len(code); i++ { - sum += hexToInt[code[i]] << (uint(len(code)-i-1) * 4) - } - return rune(sum) -} - -func decodeEscapeString(s *stream) error { - s.cursor++ -RETRY: - switch s.buf[s.cursor] { - case '"': - s.buf[s.cursor] = '"' - case '\\': - s.buf[s.cursor] = '\\' - case '/': - s.buf[s.cursor] = '/' - case 'b': - s.buf[s.cursor] = '\b' - case 'f': - s.buf[s.cursor] = '\f' - case 'n': - s.buf[s.cursor] = '\n' - case 'r': - s.buf[s.cursor] = '\r' - case 't': - s.buf[s.cursor] = '\t' - case 'u': - if s.cursor+5 >= s.length { - if !s.read() { - return errInvalidCharacter(s.char(), "escaped string", s.totalOffset()) - } - } - code := unicodeToRune(s.buf[s.cursor+1 : s.cursor+5]) - unicode := []byte(string(code)) - s.buf = append(append(s.buf[:s.cursor-1], unicode...), s.buf[s.cursor+5:]...) - s.cursor-- - return nil - case nul: - if !s.read() { - return errInvalidCharacter(s.char(), "escaped string", s.totalOffset()) - } - goto RETRY - default: - return errUnexpectedEndOfJSON("string", s.totalOffset()) - } - s.buf = append(s.buf[:s.cursor-1], s.buf[s.cursor:]...) - s.cursor-- - return nil -} - func (d *interfaceDecoder) decodeStream(s *stream, p uintptr) error { s.skipWhiteSpace() for { diff --git a/decode_string.go b/decode_string.go index 2512271..ea3855d 100644 --- a/decode_string.go +++ b/decode_string.go @@ -30,6 +30,85 @@ func (d *stringDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, erro return cursor, nil } +var ( + hexToInt = [256]int{ + '0': 0, + '1': 1, + '2': 2, + '3': 3, + '4': 4, + '5': 5, + '6': 6, + '7': 7, + '8': 8, + '9': 9, + 'A': 10, + 'B': 11, + 'C': 12, + 'D': 13, + 'E': 14, + 'F': 15, + 'a': 10, + 'b': 11, + 'c': 12, + 'd': 13, + 'e': 14, + 'f': 15, + } +) + +func unicodeToRune(code []byte) rune { + sum := 0 + for i := 0; i < len(code); i++ { + sum += hexToInt[code[i]] << (uint(len(code)-i-1) * 4) + } + return rune(sum) +} + +func decodeEscapeString(s *stream) error { + s.cursor++ +RETRY: + switch s.buf[s.cursor] { + case '"': + s.buf[s.cursor] = '"' + case '\\': + s.buf[s.cursor] = '\\' + case '/': + s.buf[s.cursor] = '/' + case 'b': + s.buf[s.cursor] = '\b' + case 'f': + s.buf[s.cursor] = '\f' + case 'n': + s.buf[s.cursor] = '\n' + case 'r': + s.buf[s.cursor] = '\r' + case 't': + s.buf[s.cursor] = '\t' + case 'u': + if s.cursor+5 >= s.length { + if !s.read() { + return errInvalidCharacter(s.char(), "escaped string", s.totalOffset()) + } + } + code := unicodeToRune(s.buf[s.cursor+1 : s.cursor+5]) + unicode := []byte(string(code)) + s.buf = append(append(s.buf[:s.cursor-1], unicode...), s.buf[s.cursor+5:]...) + s.cursor-- + return nil + case nul: + if !s.read() { + return errInvalidCharacter(s.char(), "escaped string", s.totalOffset()) + } + goto RETRY + default: + return errUnexpectedEndOfJSON("string", s.totalOffset()) + } + s.buf = append(s.buf[:s.cursor-1], s.buf[s.cursor:]...) + s.cursor-- + return nil +} + func stringBytes(s *stream) ([]byte, error) { s.cursor++ start := s.cursor @@ -111,6 +190,43 @@ func (d *stringDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, err switch buf[cursor] { case '\\': cursor++ + switch buf[cursor] { + case '"': + buf[cursor] = '"' + buf = append(buf[:cursor-1], buf[cursor:]...) + case '\\': + buf[cursor] = '\\' + buf = append(buf[:cursor-1], buf[cursor:]...) + case '/': + buf[cursor] = '/' + buf = append(buf[:cursor-1], buf[cursor:]...) + case 'b': + buf[cursor] = '\b' + buf = append(buf[:cursor-1], buf[cursor:]...) + case 'f': + buf[cursor] = '\f' + buf = append(buf[:cursor-1], buf[cursor:]...) + case 'n': + buf[cursor] = '\n' + buf = append(buf[:cursor-1], buf[cursor:]...) + case 'r': + buf[cursor] = '\r' + buf = append(buf[:cursor-1], buf[cursor:]...) + case 't': + buf[cursor] = '\t' + buf = append(buf[:cursor-1], buf[cursor:]...) + case 'u': + buflen := int64(len(buf)) + if cursor+5 >= buflen { + return nil, 0, errUnexpectedEndOfJSON("escaped string", cursor) + } + code := unicodeToRune(buf[cursor+1 : cursor+5]) + unicode := []byte(string(code)) + buf = append(append(buf[:cursor-1], unicode...), buf[cursor+5:]...) + default: + return nil, 0, errUnexpectedEndOfJSON("escaped string", cursor) + } + continue case '"': literal := buf[start:cursor] cursor++ diff --git a/decode_wrapped_string.go b/decode_wrapped_string.go new file mode 100644 index 0000000..506f64b --- /dev/null +++ b/decode_wrapped_string.go @@ -0,0 +1,52 @@ +package json + +type wrappedStringDecoder struct { + dec decoder + stringDecoder *stringDecoder +} + +func newWrappedStringDecoder(dec decoder) *wrappedStringDecoder { + return &wrappedStringDecoder{ + dec: dec, + stringDecoder: newStringDecoder(), + } +} + +func (d *wrappedStringDecoder) decodeStream(s *stream, p uintptr) error { + bytes, err := d.stringDecoder.decodeStreamByte(s) + if err != nil { + return err + } + + // save current state + buf := s.buf + length := s.length + cursor := s.cursor + + // set content in string to stream + bytes = append(bytes, nul) + s.buf = bytes + s.cursor = 0 + s.length = int64(len(bytes)) + if err := d.dec.decodeStream(s, p); err != nil { + return nil + } + + // restore state + s.buf = buf + s.length = length + s.cursor = cursor + return nil +} + +func (d *wrappedStringDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, error) { + bytes, c, err := d.stringDecoder.decodeByte(buf, cursor) + if err != nil { + return 0, err + } + bytes = append(bytes, nul) + if _, err := d.dec.decode(bytes, 0, p); err != nil { + return 0, err + } + return c, nil +} diff --git a/encode_compile.go b/encode_compile.go index ce028bd..730d695 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -3,7 +3,6 @@ package json import ( "fmt" "reflect" - "strings" "unsafe" ) @@ -341,22 +340,6 @@ func (e *Encoder) compileMap(typ *rtype, withLoad, root, withIndent bool) (*opco return (*opcode)(unsafe.Pointer(header)), nil } -func (e *Encoder) getTag(field reflect.StructField) string { - return field.Tag.Get("json") -} - -func (e *Encoder) isIgnoredStructField(field reflect.StructField) bool { - if field.PkgPath != "" && !field.Anonymous { - // private field - return true - } - tag := e.getTag(field) - if tag == "-" { - return true - } - return false -} - func (e *Encoder) typeToHeaderType(op opType) opType { switch op { case opInt: @@ -487,29 +470,6 @@ func (e *Encoder) compiledCode(typ *rtype, withIndent bool) *opcode { return nil } -type structTag struct { - key string - isOmitEmpty bool - isString bool -} - -func (e *Encoder) structTagFromField(field reflect.StructField) *structTag { - keyName := field.Name - tag := e.getTag(field) - opts := strings.Split(tag, ",") - if len(opts) > 0 { - if opts[0] != "" { - keyName = opts[0] - } - } - st := &structTag{key: keyName} - if len(opts) > 1 { - st.isOmitEmpty = opts[1] == "omitempty" - st.isString = opts[1] == "string" - } - return st -} - func (e *Encoder) structHeader(fieldCode *structFieldCode, valueCode *opcode, tag *structTag, withIndent bool) *opcode { fieldCode.indent-- op := e.optimizeStructHeader(valueCode.op, tag, withIndent) @@ -561,10 +521,10 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco e.indent++ for i := 0; i < fieldNum; i++ { field := typ.Field(i) - if e.isIgnoredStructField(field) { + if isIgnoredStructField(field) { continue } - tag := e.structTagFromField(field) + tag := structTagFromField(field) fieldType := type2rtype(field.Type) if isPtr && i == 0 { // head field of pointer structure at top level diff --git a/json.go b/json.go index 673fb3b..0b49c31 100644 --- a/json.go +++ b/json.go @@ -300,6 +300,15 @@ func (n Number) MarshalJSON() ([]byte, error) { return []byte(n), nil } +func (n *Number) UnmarshalJSON(b []byte) error { + s := string(b) + if _, err := strconv.ParseFloat(s, 64); err != nil { + return err + } + *n = Number(s) + return nil +} + // RawMessage is a raw encoded JSON value. // It implements Marshaler and Unmarshaler and can // be used to delay JSON decoding or precompute a JSON encoding. diff --git a/struct_field.go b/struct_field.go new file mode 100644 index 0000000..3352da1 --- /dev/null +++ b/struct_field.go @@ -0,0 +1,45 @@ +package json + +import ( + "reflect" + "strings" +) + +func getTag(field reflect.StructField) string { + return field.Tag.Get("json") +} + +func isIgnoredStructField(field reflect.StructField) bool { + if field.PkgPath != "" && !field.Anonymous { + // private field + return true + } + tag := getTag(field) + if tag == "-" { + return true + } + return false +} + +type structTag struct { + key string + isOmitEmpty bool + isString bool +} + +func structTagFromField(field reflect.StructField) *structTag { + keyName := field.Name + tag := getTag(field) + opts := strings.Split(tag, ",") + if len(opts) > 0 { + if opts[0] != "" { + keyName = opts[0] + } + } + st := &structTag{key: keyName} + if len(opts) > 1 { + st.isOmitEmpty = opts[1] == "omitempty" + st.isString = opts[1] == "string" + } + return st +}