mirror of https://github.com/goccy/go-json.git
Merge pull request #185 from goccy/feature/fix-number
Fix decoder for invalid number value
This commit is contained in:
commit
29283c4c83
|
@ -39,15 +39,19 @@ var (
|
||||||
1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09,
|
1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09,
|
||||||
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18,
|
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18,
|
||||||
}
|
}
|
||||||
|
pow10i64Len = len(pow10i64)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (d *intDecoder) parseInt(b []byte) int64 {
|
func (d *intDecoder) parseInt(b []byte) (int64, error) {
|
||||||
isNegative := false
|
isNegative := false
|
||||||
if b[0] == '-' {
|
if b[0] == '-' {
|
||||||
b = b[1:]
|
b = b[1:]
|
||||||
isNegative = true
|
isNegative = true
|
||||||
}
|
}
|
||||||
maxDigit := len(b)
|
maxDigit := len(b)
|
||||||
|
if maxDigit > pow10i64Len {
|
||||||
|
return 0, fmt.Errorf("invalid length of number")
|
||||||
|
}
|
||||||
sum := int64(0)
|
sum := int64(0)
|
||||||
for i := 0; i < maxDigit; i++ {
|
for i := 0; i < maxDigit; i++ {
|
||||||
c := int64(b[i]) - 48
|
c := int64(b[i]) - 48
|
||||||
|
@ -55,9 +59,9 @@ func (d *intDecoder) parseInt(b []byte) int64 {
|
||||||
sum += c * digitValue
|
sum += c * digitValue
|
||||||
}
|
}
|
||||||
if isNegative {
|
if isNegative {
|
||||||
return -1 * sum
|
return -1 * sum, nil
|
||||||
}
|
}
|
||||||
return sum
|
return sum, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -100,7 +104,10 @@ func (d *intDecoder) decodeStreamByte(s *stream) ([]byte, error) {
|
||||||
goto ERROR
|
goto ERROR
|
||||||
}
|
}
|
||||||
return num, nil
|
return num, nil
|
||||||
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
case '0':
|
||||||
|
s.cursor++
|
||||||
|
return []byte{'0'}, nil
|
||||||
|
case '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
||||||
start := s.cursor
|
start := s.cursor
|
||||||
for {
|
for {
|
||||||
s.cursor++
|
s.cursor++
|
||||||
|
@ -141,7 +148,10 @@ func (d *intDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error)
|
||||||
case ' ', '\n', '\t', '\r':
|
case ' ', '\n', '\t', '\r':
|
||||||
cursor++
|
cursor++
|
||||||
continue
|
continue
|
||||||
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
case '0':
|
||||||
|
cursor++
|
||||||
|
return []byte{'0'}, cursor, nil
|
||||||
|
case '-', '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
||||||
start := cursor
|
start := cursor
|
||||||
cursor++
|
cursor++
|
||||||
LOOP:
|
LOOP:
|
||||||
|
@ -181,7 +191,10 @@ func (d *intDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) erro
|
||||||
if bytes == nil {
|
if bytes == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
i64 := d.parseInt(bytes)
|
i64, err := d.parseInt(bytes)
|
||||||
|
if err != nil {
|
||||||
|
return d.typeError(bytes, s.totalOffset())
|
||||||
|
}
|
||||||
switch d.kind {
|
switch d.kind {
|
||||||
case reflect.Int8:
|
case reflect.Int8:
|
||||||
if i64 <= -1*(1<<7) || (1<<7) <= i64 {
|
if i64 <= -1*(1<<7) || (1<<7) <= i64 {
|
||||||
|
@ -210,7 +223,11 @@ func (d *intDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
cursor = c
|
cursor = c
|
||||||
i64 := d.parseInt(bytes)
|
|
||||||
|
i64, err := d.parseInt(bytes)
|
||||||
|
if err != nil {
|
||||||
|
return 0, d.typeError(bytes, cursor)
|
||||||
|
}
|
||||||
switch d.kind {
|
switch d.kind {
|
||||||
case reflect.Int8:
|
case reflect.Int8:
|
||||||
if i64 <= -1*(1<<7) || (1<<7) <= i64 {
|
if i64 <= -1*(1<<7) || (1<<7) <= i64 {
|
||||||
|
|
138
decode_test.go
138
decode_test.go
|
@ -2951,3 +2951,141 @@ func TestInvalidTopLevelValue(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInvalidNumber(t *testing.T) {
|
||||||
|
t.Run("invalid length of number", func(t *testing.T) {
|
||||||
|
invalidNum := strings.Repeat("1", 30)
|
||||||
|
t.Run("int", func(t *testing.T) {
|
||||||
|
var v int64
|
||||||
|
stdErr := stdjson.Unmarshal([]byte(invalidNum), &v)
|
||||||
|
if stdErr == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(invalidNum), &v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if stdErr.Error() != err.Error() {
|
||||||
|
t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("uint", func(t *testing.T) {
|
||||||
|
var v uint64
|
||||||
|
stdErr := stdjson.Unmarshal([]byte(invalidNum), &v)
|
||||||
|
if stdErr == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(invalidNum), &v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if stdErr.Error() != err.Error() {
|
||||||
|
t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
t.Run("invalid number of zero", func(t *testing.T) {
|
||||||
|
t.Run("int", func(t *testing.T) {
|
||||||
|
invalidNum := strings.Repeat("0", 10)
|
||||||
|
var v int64
|
||||||
|
stdErr := stdjson.Unmarshal([]byte(invalidNum), &v)
|
||||||
|
if stdErr == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(invalidNum), &v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if stdErr.Error() != err.Error() {
|
||||||
|
t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("uint", func(t *testing.T) {
|
||||||
|
invalidNum := strings.Repeat("0", 10)
|
||||||
|
var v uint64
|
||||||
|
stdErr := stdjson.Unmarshal([]byte(invalidNum), &v)
|
||||||
|
if stdErr == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(invalidNum), &v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if stdErr.Error() != err.Error() {
|
||||||
|
t.Fatalf("unexpected error message. expected: %q but got %q", stdErr.Error(), err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("invalid number", func(t *testing.T) {
|
||||||
|
t.Run("int", func(t *testing.T) {
|
||||||
|
t.Run("-0", func(t *testing.T) {
|
||||||
|
var v int64
|
||||||
|
if err := stdjson.Unmarshal([]byte(`-0`), &v); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(`-0`), &v); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("+0", func(t *testing.T) {
|
||||||
|
var v int64
|
||||||
|
if err := stdjson.Unmarshal([]byte(`+0`), &v); err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(`+0`), &v); err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("uint", func(t *testing.T) {
|
||||||
|
t.Run("-0", func(t *testing.T) {
|
||||||
|
var v uint64
|
||||||
|
if err := stdjson.Unmarshal([]byte(`-0`), &v); err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(`-0`), &v); err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("+0", func(t *testing.T) {
|
||||||
|
var v uint64
|
||||||
|
if err := stdjson.Unmarshal([]byte(`+0`), &v); err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(`+0`), &v); err == nil {
|
||||||
|
t.Error("expected error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("float", func(t *testing.T) {
|
||||||
|
t.Run("0.0", func(t *testing.T) {
|
||||||
|
var f float64
|
||||||
|
if err := stdjson.Unmarshal([]byte(`0.0`), &f); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(`0.0`), &f); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("0.000000000", func(t *testing.T) {
|
||||||
|
var f float64
|
||||||
|
if err := stdjson.Unmarshal([]byte(`0.000000000`), &f); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(`0.000000000`), &f); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("repeat zero a lot with float value", func(t *testing.T) {
|
||||||
|
var f float64
|
||||||
|
if err := stdjson.Unmarshal([]byte("0."+strings.Repeat("0", 30)), &f); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte("0."+strings.Repeat("0", 30)), &f); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -32,20 +32,26 @@ func (d *uintDecoder) typeError(buf []byte, offset int64) *UnmarshalTypeError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var pow10u64 = [...]uint64{
|
var (
|
||||||
1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09,
|
pow10u64 = [...]uint64{
|
||||||
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19,
|
1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09,
|
||||||
}
|
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19,
|
||||||
|
}
|
||||||
|
pow10u64Len = len(pow10u64)
|
||||||
|
)
|
||||||
|
|
||||||
func (d *uintDecoder) parseUint(b []byte) uint64 {
|
func (d *uintDecoder) parseUint(b []byte) (uint64, error) {
|
||||||
maxDigit := len(b)
|
maxDigit := len(b)
|
||||||
|
if maxDigit > pow10u64Len {
|
||||||
|
return 0, fmt.Errorf("invalid length of number")
|
||||||
|
}
|
||||||
sum := uint64(0)
|
sum := uint64(0)
|
||||||
for i := 0; i < maxDigit; i++ {
|
for i := 0; i < maxDigit; i++ {
|
||||||
c := uint64(b[i]) - 48
|
c := uint64(b[i]) - 48
|
||||||
digitValue := pow10u64[maxDigit-i-1]
|
digitValue := pow10u64[maxDigit-i-1]
|
||||||
sum += c * digitValue
|
sum += c * digitValue
|
||||||
}
|
}
|
||||||
return sum
|
return sum, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) {
|
func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) {
|
||||||
|
@ -54,7 +60,10 @@ func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) {
|
||||||
case ' ', '\n', '\t', '\r':
|
case ' ', '\n', '\t', '\r':
|
||||||
s.cursor++
|
s.cursor++
|
||||||
continue
|
continue
|
||||||
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
case '0':
|
||||||
|
s.cursor++
|
||||||
|
return []byte{'0'}, nil
|
||||||
|
case '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
||||||
start := s.cursor
|
start := s.cursor
|
||||||
for {
|
for {
|
||||||
s.cursor++
|
s.cursor++
|
||||||
|
@ -93,7 +102,10 @@ func (d *uintDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error
|
||||||
switch buf[cursor] {
|
switch buf[cursor] {
|
||||||
case ' ', '\n', '\t', '\r':
|
case ' ', '\n', '\t', '\r':
|
||||||
continue
|
continue
|
||||||
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
case '0':
|
||||||
|
cursor++
|
||||||
|
return []byte{'0'}, cursor, nil
|
||||||
|
case '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
||||||
start := cursor
|
start := cursor
|
||||||
cursor++
|
cursor++
|
||||||
for ; cursor < buflen; cursor++ {
|
for ; cursor < buflen; cursor++ {
|
||||||
|
@ -135,7 +147,10 @@ func (d *uintDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) err
|
||||||
if bytes == nil {
|
if bytes == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
u64 := d.parseUint(bytes)
|
u64, err := d.parseUint(bytes)
|
||||||
|
if err != nil {
|
||||||
|
return d.typeError(bytes, s.totalOffset())
|
||||||
|
}
|
||||||
switch d.kind {
|
switch d.kind {
|
||||||
case reflect.Uint8:
|
case reflect.Uint8:
|
||||||
if (1 << 8) <= u64 {
|
if (1 << 8) <= u64 {
|
||||||
|
@ -163,7 +178,10 @@ func (d *uintDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer)
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
cursor = c
|
cursor = c
|
||||||
u64 := d.parseUint(bytes)
|
u64, err := d.parseUint(bytes)
|
||||||
|
if err != nil {
|
||||||
|
return 0, d.typeError(bytes, cursor)
|
||||||
|
}
|
||||||
switch d.kind {
|
switch d.kind {
|
||||||
case reflect.Uint8:
|
case reflect.Uint8:
|
||||||
if (1 << 8) <= u64 {
|
if (1 << 8) <= u64 {
|
||||||
|
|
Loading…
Reference in New Issue