Fix decoder for invalid number value

This commit is contained in:
Masaaki Goshima 2021-04-11 15:18:08 +09:00
parent deec2b2b0d
commit 50bf5148f3
3 changed files with 190 additions and 17 deletions

View File

@ -39,15 +39,19 @@ var (
1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09,
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18,
}
pow10i64Len = len(pow10i64)
)
func (d *intDecoder) parseInt(b []byte) int64 {
func (d *intDecoder) parseInt(b []byte) (int64, error) {
isNegative := false
if b[0] == '-' {
b = b[1:]
isNegative = true
}
maxDigit := len(b)
if maxDigit > pow10i64Len {
return 0, fmt.Errorf("invalid length of number")
}
sum := int64(0)
for i := 0; i < maxDigit; i++ {
c := int64(b[i]) - 48
@ -55,9 +59,9 @@ func (d *intDecoder) parseInt(b []byte) int64 {
sum += c * digitValue
}
if isNegative {
return -1 * sum
return -1 * sum, nil
}
return sum
return sum, nil
}
var (
@ -100,7 +104,10 @@ func (d *intDecoder) decodeStreamByte(s *stream) ([]byte, error) {
goto ERROR
}
return num, nil
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
case '0':
s.cursor++
return []byte{'0'}, nil
case '1', '2', '3', '4', '5', '6', '7', '8', '9':
start := s.cursor
for {
s.cursor++
@ -141,7 +148,10 @@ func (d *intDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error)
case ' ', '\n', '\t', '\r':
cursor++
continue
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
case '0':
cursor++
return []byte{'0'}, cursor, nil
case '-', '1', '2', '3', '4', '5', '6', '7', '8', '9':
start := cursor
cursor++
LOOP:
@ -181,7 +191,10 @@ func (d *intDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) erro
if bytes == nil {
return nil
}
i64 := d.parseInt(bytes)
i64, err := d.parseInt(bytes)
if err != nil {
return d.typeError(bytes, s.totalOffset())
}
switch d.kind {
case reflect.Int8:
if i64 <= -1*(1<<7) || (1<<7) <= i64 {
@ -210,7 +223,11 @@ func (d *intDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (
return c, nil
}
cursor = c
i64 := d.parseInt(bytes)
i64, err := d.parseInt(bytes)
if err != nil {
return 0, d.typeError(bytes, cursor)
}
switch d.kind {
case reflect.Int8:
if i64 <= -1*(1<<7) || (1<<7) <= i64 {

View File

@ -2951,3 +2951,141 @@ func TestInvalidTopLevelValue(t *testing.T) {
}
})
}
func TestInvalidNumber(t *testing.T) {
t.Run("invalid length of number", func(t *testing.T) {
invalidNum := strings.Repeat("1", 30)
t.Run("int", func(t *testing.T) {
var v int64
stdErr := stdjson.Unmarshal([]byte(invalidNum), &v)
if stdErr == nil {
t.Fatal("expected error")
}
err := json.Unmarshal([]byte(invalidNum), &v)
if err == nil {
t.Fatal("expected error")
}
if stdErr.Error() != err.Error() {
t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error())
}
})
t.Run("uint", func(t *testing.T) {
var v uint64
stdErr := stdjson.Unmarshal([]byte(invalidNum), &v)
if stdErr == nil {
t.Fatal("expected error")
}
err := json.Unmarshal([]byte(invalidNum), &v)
if err == nil {
t.Fatal("expected error")
}
if stdErr.Error() != err.Error() {
t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error())
}
})
})
t.Run("invalid number of zero", func(t *testing.T) {
t.Run("int", func(t *testing.T) {
invalidNum := strings.Repeat("0", 10)
var v int64
stdErr := stdjson.Unmarshal([]byte(invalidNum), &v)
if stdErr == nil {
t.Fatal("expected error")
}
err := json.Unmarshal([]byte(invalidNum), &v)
if err == nil {
t.Fatal("expected error")
}
if stdErr.Error() != err.Error() {
t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error())
}
})
t.Run("uint", func(t *testing.T) {
invalidNum := strings.Repeat("0", 10)
var v uint64
stdErr := stdjson.Unmarshal([]byte(invalidNum), &v)
if stdErr == nil {
t.Fatal("expected error")
}
err := json.Unmarshal([]byte(invalidNum), &v)
if err == nil {
t.Fatal("expected error")
}
if stdErr.Error() != err.Error() {
t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error())
}
})
})
t.Run("invalid number", func(t *testing.T) {
t.Run("int", func(t *testing.T) {
t.Run("-0", func(t *testing.T) {
var v int64
if err := stdjson.Unmarshal([]byte(`-0`), &v); err != nil {
t.Fatal(err)
}
if err := json.Unmarshal([]byte(`-0`), &v); err != nil {
t.Fatal(err)
}
})
t.Run("+0", func(t *testing.T) {
var v int64
if err := stdjson.Unmarshal([]byte(`+0`), &v); err == nil {
t.Error("expected error")
}
if err := json.Unmarshal([]byte(`+0`), &v); err == nil {
t.Error("expected error")
}
})
})
t.Run("uint", func(t *testing.T) {
t.Run("-0", func(t *testing.T) {
var v uint64
if err := stdjson.Unmarshal([]byte(`-0`), &v); err == nil {
t.Error("expected error")
}
if err := json.Unmarshal([]byte(`-0`), &v); err == nil {
t.Error("expected error")
}
})
t.Run("+0", func(t *testing.T) {
var v uint64
if err := stdjson.Unmarshal([]byte(`+0`), &v); err == nil {
t.Error("expected error")
}
if err := json.Unmarshal([]byte(`+0`), &v); err == nil {
t.Error("expected error")
}
})
})
t.Run("float", func(t *testing.T) {
t.Run("0.0", func(t *testing.T) {
var f float64
if err := stdjson.Unmarshal([]byte(`0.0`), &f); err != nil {
t.Fatal(err)
}
if err := json.Unmarshal([]byte(`0.0`), &f); err != nil {
t.Fatal(err)
}
})
t.Run("0.000000000", func(t *testing.T) {
var f float64
if err := stdjson.Unmarshal([]byte(`0.000000000`), &f); err != nil {
t.Fatal(err)
}
if err := json.Unmarshal([]byte(`0.000000000`), &f); err != nil {
t.Fatal(err)
}
})
t.Run("repeat zero a lot with float value", func(t *testing.T) {
var f float64
if err := stdjson.Unmarshal([]byte("0."+strings.Repeat("0", 30)), &f); err != nil {
t.Fatal(err)
}
if err := json.Unmarshal([]byte("0."+strings.Repeat("0", 30)), &f); err != nil {
t.Fatal(err)
}
})
})
})
}

View File

@ -32,20 +32,26 @@ func (d *uintDecoder) typeError(buf []byte, offset int64) *UnmarshalTypeError {
}
}
var pow10u64 = [...]uint64{
var (
pow10u64 = [...]uint64{
1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09,
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19,
}
pow10u64Len = len(pow10u64)
)
func (d *uintDecoder) parseUint(b []byte) uint64 {
func (d *uintDecoder) parseUint(b []byte) (uint64, error) {
maxDigit := len(b)
if maxDigit > pow10u64Len {
return 0, fmt.Errorf("invalid length of number")
}
sum := uint64(0)
for i := 0; i < maxDigit; i++ {
c := uint64(b[i]) - 48
digitValue := pow10u64[maxDigit-i-1]
sum += c * digitValue
}
return sum
return sum, nil
}
func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) {
@ -54,7 +60,10 @@ func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) {
case ' ', '\n', '\t', '\r':
s.cursor++
continue
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
case '0':
s.cursor++
return []byte{'0'}, nil
case '1', '2', '3', '4', '5', '6', '7', '8', '9':
start := s.cursor
for {
s.cursor++
@ -93,7 +102,10 @@ func (d *uintDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error
switch buf[cursor] {
case ' ', '\n', '\t', '\r':
continue
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
case '0':
cursor++
return []byte{'0'}, cursor, nil
case '1', '2', '3', '4', '5', '6', '7', '8', '9':
start := cursor
cursor++
for ; cursor < buflen; cursor++ {
@ -135,7 +147,10 @@ func (d *uintDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) err
if bytes == nil {
return nil
}
u64 := d.parseUint(bytes)
u64, err := d.parseUint(bytes)
if err != nil {
return d.typeError(bytes, s.totalOffset())
}
switch d.kind {
case reflect.Uint8:
if (1 << 8) <= u64 {
@ -163,7 +178,10 @@ func (d *uintDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
return c, nil
}
cursor = c
u64 := d.parseUint(bytes)
u64, err := d.parseUint(bytes)
if err != nil {
return 0, d.typeError(bytes, cursor)
}
switch d.kind {
case reflect.Uint8:
if (1 << 8) <= u64 {