From d7b9036e885c1efadab1532bc012855a1a17a0ca Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Mon, 21 Dec 2020 15:48:57 +0900 Subject: [PATCH] Fix overflow handling for int/uint decoder --- decode_compile.go | 32 ++++++++++---------- decode_int.go | 47 +++++++++++++++++++++-------- decode_uint.go | 64 ++++++++++++++++++++++++++++++++++++---- decode_wrapped_string.go | 4 ++- 4 files changed, 112 insertions(+), 35 deletions(-) diff --git a/decode_compile.go b/decode_compile.go index 3fd8396..c23f7b4 100644 --- a/decode_compile.go +++ b/decode_compile.go @@ -42,7 +42,7 @@ func (d *Decoder) compile(typ *rtype, structName, fieldName string) (decoder, er case reflect.Interface: return d.compileInterface(typ, structName, fieldName) case reflect.Uintptr: - return d.compileUint(structName, fieldName) + return d.compileUint(typ, structName, fieldName) case reflect.Int: return d.compileInt(typ, structName, fieldName) case reflect.Int8: @@ -54,15 +54,15 @@ func (d *Decoder) compile(typ *rtype, structName, fieldName string) (decoder, er case reflect.Int64: return d.compileInt64(typ, structName, fieldName) case reflect.Uint: - return d.compileUint(structName, fieldName) + return d.compileUint(typ, structName, fieldName) case reflect.Uint8: - return d.compileUint8(structName, fieldName) + return d.compileUint8(typ, structName, fieldName) case reflect.Uint16: - return d.compileUint16(structName, fieldName) + return d.compileUint16(typ, structName, fieldName) case reflect.Uint32: - return d.compileUint32(structName, fieldName) + return d.compileUint32(typ, structName, fieldName) case reflect.Uint64: - return d.compileUint64(structName, fieldName) + return d.compileUint64(typ, structName, fieldName) case reflect.String: return d.compileString(structName, fieldName) case reflect.Bool: @@ -134,32 +134,32 @@ func (d *Decoder) compileInt64(typ *rtype, structName, fieldName string) (decode }), nil } -func (d *Decoder) compileUint(structName, fieldName string) (decoder, error) { - return newUintDecoder(structName, fieldName, func(p unsafe.Pointer, v uint64) { +func (d *Decoder) compileUint(typ *rtype, structName, fieldName string) (decoder, error) { + return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) { *(*uint)(p) = uint(v) }), nil } -func (d *Decoder) compileUint8(structName, fieldName string) (decoder, error) { - return newUintDecoder(structName, fieldName, func(p unsafe.Pointer, v uint64) { +func (d *Decoder) compileUint8(typ *rtype, structName, fieldName string) (decoder, error) { + return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) { *(*uint8)(p) = uint8(v) }), nil } -func (d *Decoder) compileUint16(structName, fieldName string) (decoder, error) { - return newUintDecoder(structName, fieldName, func(p unsafe.Pointer, v uint64) { +func (d *Decoder) compileUint16(typ *rtype, structName, fieldName string) (decoder, error) { + return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) { *(*uint16)(p) = uint16(v) }), nil } -func (d *Decoder) compileUint32(structName, fieldName string) (decoder, error) { - return newUintDecoder(structName, fieldName, func(p unsafe.Pointer, v uint64) { +func (d *Decoder) compileUint32(typ *rtype, structName, fieldName string) (decoder, error) { + return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) { *(*uint32)(p) = uint32(v) }), nil } -func (d *Decoder) compileUint64(structName, fieldName string) (decoder, error) { - return newUintDecoder(structName, fieldName, func(p unsafe.Pointer, v uint64) { +func (d *Decoder) compileUint64(typ *rtype, structName, fieldName string) (decoder, error) { + return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) { *(*uint64)(p) = v }), nil } diff --git a/decode_int.go b/decode_int.go index 36874ce..acdf657 100644 --- a/decode_int.go +++ b/decode_int.go @@ -2,11 +2,13 @@ package json import ( "fmt" + "reflect" "unsafe" ) type intDecoder struct { typ *rtype + kind reflect.Kind op func(unsafe.Pointer, int64) structName string fieldName string @@ -15,6 +17,7 @@ type intDecoder struct { func newIntDecoder(typ *rtype, structName, fieldName string, op func(unsafe.Pointer, int64)) *intDecoder { return &intDecoder{ typ: typ, + kind: typ.Kind(), op: op, structName: structName, fieldName: fieldName, @@ -29,16 +32,6 @@ func (d *intDecoder) typeError(buf []byte, offset int64) *UnmarshalTypeError { } } -func (d *intDecoder) annotateError(cursor int64, err error) { - switch e := err.(type) { - case *UnmarshalTypeError: - e.Struct = d.structName - e.Field = d.fieldName - case *SyntaxError: - e.Offset = cursor - } -} - var ( pow10i64 = [...]int64{ 1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09, @@ -162,7 +155,22 @@ func (d *intDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } - d.op(p, d.parseInt(bytes)) + i64 := d.parseInt(bytes) + switch d.kind { + case reflect.Int8: + if i64 <= -1*(1<<7) || (1<<7) <= i64 { + return d.typeError(bytes, s.totalOffset()) + } + case reflect.Int16: + if i64 <= -1*(1<<15) || (1<<15) <= i64 { + return d.typeError(bytes, s.totalOffset()) + } + case reflect.Int32: + if i64 <= -1*(1<<31) || (1<<31) <= i64 { + return d.typeError(bytes, s.totalOffset()) + } + } + d.op(p, i64) s.reset() return nil } @@ -173,6 +181,21 @@ func (d *intDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, return 0, err } cursor = c - d.op(p, d.parseInt(bytes)) + i64 := d.parseInt(bytes) + switch d.kind { + case reflect.Int8: + if i64 <= -1*(1<<7) || (1<<7) <= i64 { + return 0, d.typeError(bytes, cursor) + } + case reflect.Int16: + if i64 <= -1*(1<<15) || (1<<15) <= i64 { + return 0, d.typeError(bytes, cursor) + } + case reflect.Int32: + if i64 <= -1*(1<<31) || (1<<31) <= i64 { + return 0, d.typeError(bytes, cursor) + } + } + d.op(p, i64) return cursor, nil } diff --git a/decode_uint.go b/decode_uint.go index 1fd851b..be242ee 100644 --- a/decode_uint.go +++ b/decode_uint.go @@ -1,15 +1,35 @@ package json -import "unsafe" +import ( + "fmt" + "reflect" + "unsafe" +) type uintDecoder struct { + typ *rtype + kind reflect.Kind op func(unsafe.Pointer, uint64) structName string fieldName string } -func newUintDecoder(structName, fieldName string, op func(unsafe.Pointer, uint64)) *uintDecoder { - return &uintDecoder{op: op, structName: structName, fieldName: fieldName} +func newUintDecoder(typ *rtype, structName, fieldName string, op func(unsafe.Pointer, uint64)) *uintDecoder { + return &uintDecoder{ + typ: typ, + kind: typ.Kind(), + op: op, + structName: structName, + fieldName: fieldName, + } +} + +func (d *uintDecoder) typeError(buf []byte, offset int64) *UnmarshalTypeError { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("number %s", string(buf)), + Type: rtype2type(d.typ), + Offset: offset, + } } var pow10u64 = [...]uint64{ @@ -50,6 +70,8 @@ func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) { } num := s.buf[start:s.cursor] return num, nil + default: + return nil, d.typeError([]byte{s.char()}, s.totalOffset()) case nul: if s.read() { continue @@ -79,7 +101,7 @@ func (d *uintDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error num := buf[start:cursor] return num, cursor, nil default: - return nil, 0, errInvalidCharacter(buf[cursor], "number(unsigned integer)", cursor) + return nil, 0, d.typeError([]byte{buf[cursor]}, cursor) } } return nil, 0, errUnexpectedEndOfJSON("number(unsigned integer)", cursor) @@ -90,7 +112,22 @@ func (d *uintDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } - d.op(p, d.parseUint(bytes)) + u64 := d.parseUint(bytes) + switch d.kind { + case reflect.Uint8: + if (1 << 8) <= u64 { + return d.typeError(bytes, s.totalOffset()) + } + case reflect.Uint16: + if (1 << 16) <= u64 { + return d.typeError(bytes, s.totalOffset()) + } + case reflect.Uint32: + if (1 << 32) <= u64 { + return d.typeError(bytes, s.totalOffset()) + } + } + d.op(p, u64) return nil } @@ -100,6 +137,21 @@ func (d *uintDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, return 0, err } cursor = c - d.op(p, d.parseUint(bytes)) + u64 := d.parseUint(bytes) + switch d.kind { + case reflect.Uint8: + if (1 << 8) <= u64 { + return 0, d.typeError(bytes, cursor) + } + case reflect.Uint16: + if (1 << 16) <= u64 { + return 0, d.typeError(bytes, cursor) + } + case reflect.Uint32: + if (1 << 32) <= u64 { + return 0, d.typeError(bytes, cursor) + } + } + d.op(p, u64) return cursor, nil } diff --git a/decode_wrapped_string.go b/decode_wrapped_string.go index 2a09870..d1b2cf4 100644 --- a/decode_wrapped_string.go +++ b/decode_wrapped_string.go @@ -1,6 +1,8 @@ package json -import "unsafe" +import ( + "unsafe" +) type wrappedStringDecoder struct { dec decoder