diff --git a/decode_int.go b/decode_int.go index 4ea6fa5..8db98a9 100644 --- a/decode_int.go +++ b/decode_int.go @@ -39,15 +39,19 @@ var ( 1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, } + pow10i64Len = len(pow10i64) ) -func (d *intDecoder) parseInt(b []byte) int64 { +func (d *intDecoder) parseInt(b []byte) (int64, error) { isNegative := false if b[0] == '-' { b = b[1:] isNegative = true } maxDigit := len(b) + if maxDigit > pow10i64Len { + return 0, fmt.Errorf("invalid length of number") + } sum := int64(0) for i := 0; i < maxDigit; i++ { c := int64(b[i]) - 48 @@ -55,9 +59,9 @@ func (d *intDecoder) parseInt(b []byte) int64 { sum += c * digitValue } if isNegative { - return -1 * sum + return -1 * sum, nil } - return sum + return sum, nil } var ( @@ -100,7 +104,10 @@ func (d *intDecoder) decodeStreamByte(s *stream) ([]byte, error) { goto ERROR } return num, nil - case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + case '0': + s.cursor++ + return []byte{'0'}, nil + case '1', '2', '3', '4', '5', '6', '7', '8', '9': start := s.cursor for { s.cursor++ @@ -141,7 +148,10 @@ func (d *intDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error) case ' ', '\n', '\t', '\r': cursor++ continue - case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + case '0': + cursor++ + return []byte{'0'}, cursor, nil + case '-', '1', '2', '3', '4', '5', '6', '7', '8', '9': start := cursor cursor++ LOOP: @@ -181,7 +191,10 @@ func (d *intDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) erro if bytes == nil { return nil } - i64 := d.parseInt(bytes) + i64, err := d.parseInt(bytes) + if err != nil { + return d.typeError(bytes, s.totalOffset()) + } switch d.kind { case reflect.Int8: if i64 <= -1*(1<<7) || (1<<7) <= i64 { @@ -210,7 +223,11 @@ func (d *intDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) ( return c, nil } cursor = c - i64 := d.parseInt(bytes) + + i64, err := d.parseInt(bytes) + if err != nil { + return 0, d.typeError(bytes, cursor) + } switch d.kind { case reflect.Int8: if i64 <= -1*(1<<7) || (1<<7) <= i64 { diff --git a/decode_test.go b/decode_test.go index 356a754..574341f 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2951,3 +2951,141 @@ func TestInvalidTopLevelValue(t *testing.T) { } }) } + +func TestInvalidNumber(t *testing.T) { + t.Run("invalid length of number", func(t *testing.T) { + invalidNum := strings.Repeat("1", 30) + t.Run("int", func(t *testing.T) { + var v int64 + stdErr := stdjson.Unmarshal([]byte(invalidNum), &v) + if stdErr == nil { + t.Fatal("expected error") + } + err := json.Unmarshal([]byte(invalidNum), &v) + if err == nil { + t.Fatal("expected error") + } + if stdErr.Error() != err.Error() { + t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error()) + } + }) + t.Run("uint", func(t *testing.T) { + var v uint64 + stdErr := stdjson.Unmarshal([]byte(invalidNum), &v) + if stdErr == nil { + t.Fatal("expected error") + } + err := json.Unmarshal([]byte(invalidNum), &v) + if err == nil { + t.Fatal("expected error") + } + if stdErr.Error() != err.Error() { + t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error()) + } + }) + + }) + t.Run("invalid number of zero", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + invalidNum := strings.Repeat("0", 10) + var v int64 + stdErr := stdjson.Unmarshal([]byte(invalidNum), &v) + if stdErr == nil { + t.Fatal("expected error") + } + err := json.Unmarshal([]byte(invalidNum), &v) + if err == nil { + t.Fatal("expected error") + } + if stdErr.Error() != err.Error() { + t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error()) + } + }) + t.Run("uint", func(t *testing.T) { + invalidNum := strings.Repeat("0", 10) + var v uint64 + stdErr := stdjson.Unmarshal([]byte(invalidNum), &v) + if stdErr == nil { + t.Fatal("expected error") + } + err := json.Unmarshal([]byte(invalidNum), &v) + if err == nil { + t.Fatal("expected error") + } + if stdErr.Error() != err.Error() { + t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error()) + } + }) + }) + t.Run("invalid number", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + t.Run("-0", func(t *testing.T) { + var v int64 + if err := stdjson.Unmarshal([]byte(`-0`), &v); err != nil { + t.Fatal(err) + } + if err := json.Unmarshal([]byte(`-0`), &v); err != nil { + t.Fatal(err) + } + }) + t.Run("+0", func(t *testing.T) { + var v int64 + if err := stdjson.Unmarshal([]byte(`+0`), &v); err == nil { + t.Error("expected error") + } + if err := json.Unmarshal([]byte(`+0`), &v); err == nil { + t.Error("expected error") + } + }) + }) + t.Run("uint", func(t *testing.T) { + t.Run("-0", func(t *testing.T) { + var v uint64 + if err := stdjson.Unmarshal([]byte(`-0`), &v); err == nil { + t.Error("expected error") + } + if err := json.Unmarshal([]byte(`-0`), &v); err == nil { + t.Error("expected error") + } + }) + t.Run("+0", func(t *testing.T) { + var v uint64 + if err := stdjson.Unmarshal([]byte(`+0`), &v); err == nil { + t.Error("expected error") + } + if err := json.Unmarshal([]byte(`+0`), &v); err == nil { + t.Error("expected error") + } + }) + }) + t.Run("float", func(t *testing.T) { + t.Run("0.0", func(t *testing.T) { + var f float64 + if err := stdjson.Unmarshal([]byte(`0.0`), &f); err != nil { + t.Fatal(err) + } + if err := json.Unmarshal([]byte(`0.0`), &f); err != nil { + t.Fatal(err) + } + }) + t.Run("0.000000000", func(t *testing.T) { + var f float64 + if err := stdjson.Unmarshal([]byte(`0.000000000`), &f); err != nil { + t.Fatal(err) + } + if err := json.Unmarshal([]byte(`0.000000000`), &f); err != nil { + t.Fatal(err) + } + }) + t.Run("repeat zero a lot with float value", func(t *testing.T) { + var f float64 + if err := stdjson.Unmarshal([]byte("0."+strings.Repeat("0", 30)), &f); err != nil { + t.Fatal(err) + } + if err := json.Unmarshal([]byte("0."+strings.Repeat("0", 30)), &f); err != nil { + t.Fatal(err) + } + }) + }) + }) +} diff --git a/decode_uint.go b/decode_uint.go index 4c55bac..2d06e06 100644 --- a/decode_uint.go +++ b/decode_uint.go @@ -32,20 +32,26 @@ func (d *uintDecoder) typeError(buf []byte, offset int64) *UnmarshalTypeError { } } -var pow10u64 = [...]uint64{ - 1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09, - 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, -} +var ( + pow10u64 = [...]uint64{ + 1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09, + 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, + } + pow10u64Len = len(pow10u64) +) -func (d *uintDecoder) parseUint(b []byte) uint64 { +func (d *uintDecoder) parseUint(b []byte) (uint64, error) { maxDigit := len(b) + if maxDigit > pow10u64Len { + return 0, fmt.Errorf("invalid length of number") + } sum := uint64(0) for i := 0; i < maxDigit; i++ { c := uint64(b[i]) - 48 digitValue := pow10u64[maxDigit-i-1] sum += c * digitValue } - return sum + return sum, nil } func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) { @@ -54,7 +60,10 @@ func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) { case ' ', '\n', '\t', '\r': s.cursor++ continue - case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + case '0': + s.cursor++ + return []byte{'0'}, nil + case '1', '2', '3', '4', '5', '6', '7', '8', '9': start := s.cursor for { s.cursor++ @@ -93,7 +102,10 @@ func (d *uintDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error switch buf[cursor] { case ' ', '\n', '\t', '\r': continue - case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + case '0': + cursor++ + return []byte{'0'}, cursor, nil + case '1', '2', '3', '4', '5', '6', '7', '8', '9': start := cursor cursor++ for ; cursor < buflen; cursor++ { @@ -135,7 +147,10 @@ func (d *uintDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) err if bytes == nil { return nil } - u64 := d.parseUint(bytes) + u64, err := d.parseUint(bytes) + if err != nil { + return d.typeError(bytes, s.totalOffset()) + } switch d.kind { case reflect.Uint8: if (1 << 8) <= u64 { @@ -163,7 +178,10 @@ func (d *uintDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) return c, nil } cursor = c - u64 := d.parseUint(bytes) + u64, err := d.parseUint(bytes) + if err != nil { + return 0, d.typeError(bytes, cursor) + } switch d.kind { case reflect.Uint8: if (1 << 8) <= u64 {