From 4332d1353e551ca02e26e9622e6d80a2ca52646c Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 5 Dec 2020 22:27:33 +0900 Subject: [PATCH] Fix stream decoding --- decode.go | 2 +- decode_bytes.go | 8 +++-- decode_int.go | 7 ++-- decode_interface.go | 8 ++--- decode_map.go | 7 ++-- decode_number.go | 4 +-- decode_slice.go | 4 +-- decode_stream.go | 71 +++++++++++++++------------------------- decode_string.go | 4 +-- decode_struct.go | 7 ++-- decode_unmarshal_json.go | 6 ++-- decode_unmarshal_text.go | 9 +++-- decode_wrapped_string.go | 22 +++---------- 13 files changed, 67 insertions(+), 92 deletions(-) diff --git a/decode.go b/decode.go index 109a643..20e2e18 100644 --- a/decode.go +++ b/decode.go @@ -60,7 +60,7 @@ const ( // The decoder introduces its own buffering and may // read data from r beyond the JSON values requested. func NewDecoder(r io.Reader) *Decoder { - s := &stream{r: r} + s := newStream(r) s.read() return &Decoder{ s: s, diff --git a/decode_bytes.go b/decode_bytes.go index 4a9e612..be2423b 100644 --- a/decode_bytes.go +++ b/decode_bytes.go @@ -19,12 +19,17 @@ func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } + if bytes == nil { + s.reset() + return nil + } decodedLen := base64.StdEncoding.DecodedLen(len(bytes)) buf := make([]byte, decodedLen) if _, err := base64.StdEncoding.Decode(buf, bytes); err != nil { return err } *(*[]byte)(p) = buf + s.reset() return nil } @@ -51,7 +56,6 @@ func binaryBytes(s *stream) ([]byte, error) { case '"': literal := s.buf[start:s.cursor] s.cursor++ - s.reset() return literal, nil case nul: if s.read() { @@ -77,7 +81,7 @@ func (d *bytesDecoder) decodeStreamBinary(s *stream) ([]byte, error) { if err := nullBytes(s); err != nil { return nil, err } - return []byte{}, nil + return nil, nil case nul: if s.read() { continue diff --git a/decode_int.go b/decode_int.go index e75befd..ee6886c 100644 --- a/decode_int.go +++ b/decode_int.go @@ -1,6 +1,8 @@ package json -import "unsafe" +import ( + "unsafe" +) type intDecoder struct { op func(unsafe.Pointer, int64) @@ -74,7 +76,6 @@ func (d *intDecoder) decodeStreamByte(s *stream) ([]byte, error) { break } num := s.buf[start:s.cursor] - s.reset() if len(num) < 2 { goto ERROR } @@ -94,7 +95,6 @@ func (d *intDecoder) decodeStreamByte(s *stream) ([]byte, error) { break } num := s.buf[start:s.cursor] - s.reset() return num, nil case nul: if s.read() { @@ -138,6 +138,7 @@ func (d *intDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return err } d.op(p, d.parseInt(bytes)) + s.reset() return nil } diff --git a/decode_interface.go b/decode_interface.go index b60c7c8..89aa559 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -52,7 +52,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { ).decodeStream(s, ptr); err != nil { return err } - **(**interface{})(unsafe.Pointer(&p)) = v + *(*interface{})(p) = v return nil case '[': var v []interface{} @@ -66,7 +66,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { ).decodeStream(s, ptr); err != nil { return err } - **(**interface{})(unsafe.Pointer(&p)) = v + *(*interface{})(p) = v return nil case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return d.numDecoder(s).decodeStream(s, p) @@ -82,7 +82,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { case '"': literal := s.buf[start:s.cursor] s.cursor++ - **(**interface{})(unsafe.Pointer(&p)) = *(*string)(unsafe.Pointer(&literal)) + *(*interface{})(p) = string(literal) return nil case nul: if s.read() { @@ -109,7 +109,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err := nullBytes(s); err != nil { return err } - **(**interface{})(unsafe.Pointer(&p)) = nil + *(*interface{})(p) = nil return nil case nul: if s.read() { diff --git a/decode_map.go b/decode_map.go index d921ab4..4aa8e5e 100644 --- a/decode_map.go +++ b/decode_map.go @@ -64,8 +64,8 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { s.skipWhiteSpace() mapValue := makemap(d.mapType, 0) if s.buf[s.cursor+1] == '}' { - **(**unsafe.Pointer)(unsafe.Pointer(&p)) = mapValue - s.cursor++ + *(*unsafe.Pointer)(p) = mapValue + s.cursor += 2 return nil } for { @@ -82,9 +82,6 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return errExpected("colon after object key", s.totalOffset()) } s.cursor++ - if s.end() { - return errUnexpectedEndOfJSON("map", s.totalOffset()) - } var value interface{} if err := d.setValueStream(s, &value); err != nil { return err diff --git a/decode_number.go b/decode_number.go index 6f4f166..cf36979 100644 --- a/decode_number.go +++ b/decode_number.go @@ -25,8 +25,8 @@ func (d *numberDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } - str := *(*string)(unsafe.Pointer(&bytes)) - d.op(p, Number(str)) + d.op(p, Number(string(bytes))) + s.reset() return nil } diff --git a/decode_slice.go b/decode_slice.go index 0e59f0f..21489be 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -74,7 +74,7 @@ func (d *sliceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { s.cursor++ s.skipWhiteSpace() if s.char() == ']' { - **(**sliceHeader)(unsafe.Pointer(&p)) = sliceHeader{ + *(*sliceHeader)(p) = sliceHeader{ data: newArray(d.elemType, 0), len: 0, cap: 0, @@ -116,7 +116,7 @@ func (d *sliceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { len: slice.len, cap: slice.cap, }) - **(**sliceHeader)(unsafe.Pointer(&p)) = dst + *(*sliceHeader)(p) = dst d.releaseSlice(slice) s.cursor++ return nil diff --git a/decode_stream.go b/decode_stream.go index 063dcd7..c5d203b 100644 --- a/decode_stream.go +++ b/decode_stream.go @@ -6,20 +6,29 @@ import ( ) const ( - readChunkSize = 512 + initBufSize = 512 ) type stream struct { buf []byte + bufSize int64 length int64 r io.Reader offset int64 cursor int64 + readPos int64 allRead bool useNumber bool disallowUnknownFields bool } +func newStream(r io.Reader) *stream { + return &stream{ + r: r, + bufSize: initBufSize, + } +} + func (s *stream) buffered() io.Reader { return bytes.NewReader(s.buf[s.cursor:]) } @@ -36,59 +45,33 @@ func (s *stream) char() byte { return s.buf[s.cursor] } -func (s *stream) end() bool { - return s.allRead && s.length <= s.cursor -} - -func (s *stream) progressN(n int64) bool { - if s.cursor+n < s.length-1 || s.read() { - s.cursor += n - return true - } - s.cursor = s.length - return false -} - func (s *stream) reset() { + s.offset += s.cursor s.buf = s.buf[s.cursor:] - s.length -= s.cursor s.cursor = 0 + s.length = int64(len(s.buf)) +} + +func (s *stream) readBuf() []byte { + s.bufSize *= 2 + remainBuf := s.buf + s.buf = make([]byte, s.bufSize) + copy(s.buf, remainBuf) + return s.buf[s.cursor:] } func (s *stream) read() bool { if s.allRead { return false } - buf := make([]byte, readChunkSize) - n, err := s.r.Read(buf) - if err != nil && err != io.EOF { - return false - } - if n < readChunkSize || err == io.EOF { + buf := s.readBuf() + last := len(buf) - 1 + buf[last] = nul + n, err := s.r.Read(buf[:last]) + s.length = s.cursor + int64(n) + if n < last || err == io.EOF { s.allRead = true - } - // extend buffer (2) is protect ( s.cursor++ x2 ) - // e.g.) decodeEscapeString - const extendBufLength = int64(2) - - totalSize := s.length + int64(n) + extendBufLength - if totalSize > readChunkSize { - newBuf := make([]byte, totalSize) - copy(newBuf, s.buf) - copy(newBuf[s.length:], buf) - s.buf = newBuf - s.length = totalSize - extendBufLength - } else if s.length > 0 { - copy(buf[s.length:], buf) - copy(buf, s.buf[:s.length]) - s.buf = buf - s.length = totalSize - extendBufLength - } else { - s.buf = buf - s.length = totalSize - extendBufLength - } - s.offset += s.cursor - if n == 0 { + } else if err != nil { return false } return true diff --git a/decode_string.go b/decode_string.go index 82b301a..3cb469d 100644 --- a/decode_string.go +++ b/decode_string.go @@ -34,7 +34,8 @@ func (d *stringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } - **(**string)(unsafe.Pointer(&p)) = *(*string)(unsafe.Pointer(&bytes)) + *(*string)(p) = string(bytes) + s.reset() return nil } @@ -160,7 +161,6 @@ func stringBytes(s *stream) ([]byte, error) { case '"': literal := s.buf[start:s.cursor] s.cursor++ - s.reset() return literal, nil case nul: if s.read() { diff --git a/decode_struct.go b/decode_struct.go index 10f0330..8eeca73 100644 --- a/decode_struct.go +++ b/decode_struct.go @@ -53,10 +53,9 @@ func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } s.cursor++ if s.char() == nul { - s.read() - } - if s.end() { - return errExpected("object value after colon", s.totalOffset()) + if !s.read() { + return errExpected("object value after colon", s.totalOffset()) + } } k := *(*string)(unsafe.Pointer(&key)) field, exists := d.fieldMap[k] diff --git a/decode_unmarshal_json.go b/decode_unmarshal_json.go index 4e7af4c..d951e75 100644 --- a/decode_unmarshal_json.go +++ b/decode_unmarshal_json.go @@ -36,13 +36,15 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return err } src := s.buf[start:s.cursor] + dst := make([]byte, len(src)) + copy(dst, src) if d.isDoublePointer { newptr := unsafe_New(d.typ.Elem()) v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: d.typ, ptr: newptr, })) - if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + if err := v.(Unmarshaler).UnmarshalJSON(dst); err != nil { d.annotateError(s.cursor, err) return err } @@ -52,7 +54,7 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { typ: d.typ, ptr: p, })) - if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + if err := v.(Unmarshaler).UnmarshalJSON(dst); err != nil { d.annotateError(s.cursor, err) return err } diff --git a/decode_unmarshal_text.go b/decode_unmarshal_text.go index 3656a8a..170f95a 100644 --- a/decode_unmarshal_text.go +++ b/decode_unmarshal_text.go @@ -39,15 +39,18 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return err } src := s.buf[start:s.cursor] - if s, ok := unquoteBytes(src); ok { - src = s + dst := make([]byte, len(src)) + copy(dst, src) + + if b, ok := unquoteBytes(dst); ok { + dst = b } newptr := unsafe_New(d.typ.Elem()) v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: d.typ, ptr: newptr, })) - if err := v.(encoding.TextUnmarshaler).UnmarshalText(src); err != nil { + if err := v.(encoding.TextUnmarshaler).UnmarshalText(dst); err != nil { d.annotateError(s.cursor, err) return err } diff --git a/decode_wrapped_string.go b/decode_wrapped_string.go index 94b5956..2a09870 100644 --- a/decode_wrapped_string.go +++ b/decode_wrapped_string.go @@ -23,25 +23,11 @@ func (d *wrappedStringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } - - // save current state - buf := s.buf - length := s.length - cursor := s.cursor - - // set content in string to stream - bytes = append(bytes, nul) - s.buf = bytes - s.cursor = 0 - s.length = int64(len(bytes)) - if err := d.dec.decodeStream(s, p); err != nil { - return nil + b := make([]byte, len(bytes)+1) + copy(b, bytes) + if _, err := d.dec.decode(b, 0, p); err != nil { + return err } - - // restore state - s.buf = buf - s.length = length - s.cursor = cursor return nil }