From f8eb061538dedfada70141b83bc937dd3f0788dd Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sun, 22 Nov 2020 02:47:18 +0900 Subject: [PATCH] Fix decoder --- decode_bytes.go | 130 +++++++++++++++++++++++++++++++++++++++ decode_compile.go | 12 +++- decode_string.go | 4 +- decode_test.go | 2 +- decode_unmarshal_json.go | 51 +++++++++++---- decode_unmarshal_text.go | 4 +- 6 files changed, 185 insertions(+), 18 deletions(-) create mode 100644 decode_bytes.go diff --git a/decode_bytes.go b/decode_bytes.go new file mode 100644 index 0000000..7a2f89d --- /dev/null +++ b/decode_bytes.go @@ -0,0 +1,130 @@ +package json + +import ( + "encoding/base64" + "unsafe" +) + +type bytesDecoder struct{} + +func newBytesDecoder() *bytesDecoder { + return &bytesDecoder{} +} + +func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error { + bytes, err := d.decodeStreamBinary(s) + if err != nil { + return err + } + decodedLen := base64.StdEncoding.DecodedLen(len(bytes)) + buf := make([]byte, decodedLen) + if _, err := base64.StdEncoding.Decode(buf, bytes); err != nil { + return err + } + *(*[]byte)(p) = buf + return nil +} + +func (d *bytesDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { + bytes, c, err := d.decodeBinary(buf, cursor) + if err != nil { + return 0, err + } + cursor = c + decodedLen := base64.StdEncoding.DecodedLen(len(bytes)) + b := make([]byte, decodedLen) + if _, err := base64.StdEncoding.Decode(b, bytes); err != nil { + return 0, err + } + *(*[]byte)(p) = b + return cursor, nil +} + +func binaryBytes(s *stream) ([]byte, error) { + s.cursor++ + start := s.cursor + for { + switch s.char() { + case '"': + literal := s.buf[start:s.cursor] + s.cursor++ + s.reset() + return literal, nil + case nul: + if s.read() { + continue + } + goto ERROR + } + s.cursor++ + } +ERROR: + return nil, errUnexpectedEndOfJSON("[]byte", s.totalOffset()) +} + +func (d *bytesDecoder) decodeStreamBinary(s *stream) ([]byte, error) { + for { + switch s.char() { + case ' ', '\n', '\t', '\r': + s.cursor++ + continue + case '"': + return binaryBytes(s) + case 'n': + if err := nullBytes(s); err != nil { + return nil, err + } + return []byte{}, nil + case nul: + if s.read() { + continue + } + } + break + } + return nil, errNotAtBeginningOfValue(s.totalOffset()) +} + +func (d *bytesDecoder) decodeBinary(buf []byte, cursor int64) ([]byte, int64, error) { + for { + switch buf[cursor] { + case ' ', '\n', '\t', '\r': + cursor++ + case '"': + cursor++ + start := cursor + for { + switch buf[cursor] { + case '"': + literal := buf[start:cursor] + cursor++ + return literal, cursor, nil + case nul: + return nil, 0, errUnexpectedEndOfJSON("[]byte", cursor) + } + cursor++ + } + return nil, 0, errUnexpectedEndOfJSON("[]byte", cursor) + case 'n': + buflen := int64(len(buf)) + if cursor+3 >= buflen { + return nil, 0, errUnexpectedEndOfJSON("null", cursor) + } + if buf[cursor+1] != 'u' { + return nil, 0, errInvalidCharacter(buf[cursor+1], "null", cursor) + } + if buf[cursor+2] != 'l' { + return nil, 0, errInvalidCharacter(buf[cursor+2], "null", cursor) + } + if buf[cursor+3] != 'l' { + return nil, 0, errInvalidCharacter(buf[cursor+3], "null", cursor) + } + cursor += 4 + return []byte{}, cursor, nil + default: + goto ERROR + } + } +ERROR: + return nil, 0, errNotAtBeginningOfValue(cursor) +} diff --git a/decode_compile.go b/decode_compile.go index c898eb1..15f0bee 100644 --- a/decode_compile.go +++ b/decode_compile.go @@ -10,7 +10,7 @@ func (d *Decoder) compileHead(typ *rtype) (decoder, error) { switch { case typ.Implements(unmarshalJSONType): return newUnmarshalJSONDecoder(typ), nil - case rtype_ptrTo(typ).Implements(marshalJSONType): + case rtype_ptrTo(typ).Implements(unmarshalJSONType): return newUnmarshalJSONDecoder(rtype_ptrTo(typ)), nil case typ.Implements(unmarshalTextType): return newUnmarshalTextDecoder(typ), nil @@ -24,7 +24,7 @@ func (d *Decoder) compile(typ *rtype) (decoder, error) { switch { case typ.Implements(unmarshalJSONType): return newUnmarshalJSONDecoder(typ), nil - case rtype_ptrTo(typ).Implements(marshalJSONType): + case rtype_ptrTo(typ).Implements(unmarshalJSONType): return newUnmarshalJSONDecoder(rtype_ptrTo(typ)), nil case typ.Implements(unmarshalTextType): return newUnmarshalTextDecoder(typ), nil @@ -38,6 +38,10 @@ func (d *Decoder) compile(typ *rtype) (decoder, error) { case reflect.Struct: return d.compileStruct(typ) case reflect.Slice: + elem := typ.Elem() + if elem.Kind() == reflect.Uint8 { + return d.compileBytes() + } return d.compileSlice(typ) case reflect.Array: return d.compileArray(typ) @@ -167,6 +171,10 @@ func (d *Decoder) compileBool() (decoder, error) { return newBoolDecoder(), nil } +func (d *Decoder) compileBytes() (decoder, error) { + return newBytesDecoder(), nil +} + func (d *Decoder) compileSlice(typ *rtype) (decoder, error) { elem := typ.Elem() decoder, err := d.compile(elem) diff --git a/decode_string.go b/decode_string.go index e3ba9db..f017209 100644 --- a/decode_string.go +++ b/decode_string.go @@ -115,7 +115,9 @@ func stringBytes(s *stream) ([]byte, error) { for { switch s.char() { case '\\': - s.cursor++ + if err := decodeEscapeString(s); err != nil { + return nil, err + } case '"': literal := s.buf[start:s.cursor] s.cursor++ diff --git a/decode_test.go b/decode_test.go index 1c109e4..8681320 100644 --- a/decode_test.go +++ b/decode_test.go @@ -265,7 +265,7 @@ func Test_UnmarshalJSON(t *testing.T) { t.Run("*struct", func(t *testing.T) { var v unmarshalJSON assertErr(t, json.Unmarshal([]byte(`10`), &v)) - assertEq(t, "unmarshal", v.v, 10) + assertEq(t, "unmarshal", 10, v.v) }) } diff --git a/decode_unmarshal_json.go b/decode_unmarshal_json.go index 0c77110..c3274e6 100644 --- a/decode_unmarshal_json.go +++ b/decode_unmarshal_json.go @@ -5,7 +5,8 @@ import ( ) type unmarshalJSONDecoder struct { - typ *rtype + typ *rtype + isDoublePointer bool } func newUnmarshalJSONDecoder(typ *rtype) *unmarshalJSONDecoder { @@ -19,12 +20,24 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return err } src := s.buf[start:s.cursor] - v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ - typ: d.typ, - ptr: *(*unsafe.Pointer)(unsafe.Pointer(&p)), - })) - if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { - return err + if d.isDoublePointer { + newptr := unsafe_New(d.typ.Elem()) + v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ + typ: d.typ, + ptr: newptr, + })) + if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + return err + } + *(*unsafe.Pointer)(p) = newptr + } else { + v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ + typ: d.typ, + ptr: p, + })) + if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + return err + } } return nil } @@ -37,12 +50,24 @@ func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer return 0, err } src := buf[start:end] - v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ - typ: d.typ, - ptr: *(*unsafe.Pointer)(unsafe.Pointer(&p)), - })) - if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { - return 0, err + if d.isDoublePointer { + newptr := unsafe_New(d.typ.Elem()) + v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ + typ: d.typ, + ptr: newptr, + })) + if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + return 0, err + } + *(*unsafe.Pointer)(p) = newptr + } else { + v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ + typ: d.typ, + ptr: p, + })) + if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + return 0, err + } } return end, nil } diff --git a/decode_unmarshal_text.go b/decode_unmarshal_text.go index 4960028..5b25834 100644 --- a/decode_unmarshal_text.go +++ b/decode_unmarshal_text.go @@ -26,13 +26,15 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if s, ok := unquoteBytes(src); ok { src = s } + newptr := unsafe_New(d.typ.Elem()) v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: d.typ, - ptr: *(*unsafe.Pointer)(unsafe.Pointer(&p)), + ptr: newptr, })) if err := v.(encoding.TextUnmarshaler).UnmarshalText(src); err != nil { return err } + *(*unsafe.Pointer)(p) = newptr return nil }