diff --git a/encode.go b/encode.go index 376fc5d..f41114e 100644 --- a/encode.go +++ b/encode.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding" "io" + "math" "reflect" "strconv" "sync" @@ -245,11 +246,29 @@ func (e *Encoder) encodeUint64(v uint64) { } func (e *Encoder) encodeFloat32(v float32) { - e.buf = strconv.AppendFloat(e.buf, float64(v), 'f', -1, 32) + f64 := float64(v) + abs := math.Abs(f64) + fmt := byte('f') + // Note: Must use float32 comparisons for underlying float32 value to get precise cutoffs right. + if abs != 0 { + f32 := float32(abs) + if f32 < 1e-6 || f32 >= 1e21 { + fmt = 'e' + } + } + e.buf = strconv.AppendFloat(e.buf, f64, fmt, -1, 32) } func (e *Encoder) encodeFloat64(v float64) { - e.buf = strconv.AppendFloat(e.buf, v, 'f', -1, 64) + abs := math.Abs(v) + fmt := byte('f') + // Note: Must use float32 comparisons for underlying float32 value to get precise cutoffs right. + if abs != 0 { + if abs < 1e-6 || abs >= 1e21 { + fmt = 'e' + } + } + e.buf = strconv.AppendFloat(e.buf, v, fmt, -1, 64) } func (e *Encoder) encodeBool(v bool) { diff --git a/encode_test.go b/encode_test.go index 8033bd3..448c3b2 100644 --- a/encode_test.go +++ b/encode_test.go @@ -3,7 +3,11 @@ package json_test import ( "errors" "fmt" + "log" + "math" "reflect" + "regexp" + "strconv" "testing" "time" @@ -862,3 +866,110 @@ func TestMarshalerError(t *testing.T) { } } } + +var re = regexp.MustCompile + +// syntactic checks on form of marshaled floating point numbers. +var badFloatREs = []*regexp.Regexp{ + re(`p`), // no binary exponential notation + re(`^\+`), // no leading + sign + re(`^-?0[^.]`), // no unnecessary leading zeros + re(`^-?\.`), // leading zero required before decimal point + re(`\.(e|$)`), // no trailing decimal + re(`\.[0-9]+0(e|$)`), // no trailing zero in fraction + re(`^-?(0|[0-9]{2,})\..*e`), // exponential notation must have normalized mantissa + re(`e[0-9]`), // positive exponent must be signed + //re(`e[+-]0`), // exponent must not have leading zeros + re(`e-[1-6]$`), // not tiny enough for exponential notation + re(`e+(.|1.|20)$`), // not big enough for exponential notation + re(`^-?0\.0000000`), // too tiny, should use exponential notation + re(`^-?[0-9]{22}`), // too big, should use exponential notation + re(`[1-9][0-9]{16}[1-9]`), // too many significant digits in integer + re(`[1-9][0-9.]{17}[1-9]`), // too many significant digits in decimal + // below here for float32 only + re(`[1-9][0-9]{8}[1-9]`), // too many significant digits in integer + re(`[1-9][0-9.]{9}[1-9]`), // too many significant digits in decimal +} + +func TestMarshalFloat(t *testing.T) { + t.Parallel() + nfail := 0 + test := func(f float64, bits int) { + vf := interface{}(f) + if bits == 32 { + f = float64(float32(f)) // round + vf = float32(f) + } + bout, err := json.Marshal(vf) + if err != nil { + t.Errorf("Marshal(%T(%g)): %v", vf, vf, err) + nfail++ + return + } + out := string(bout) + + // result must convert back to the same float + g, err := strconv.ParseFloat(out, bits) + if err != nil { + t.Errorf("Marshal(%T(%g)) = %q, cannot parse back: %v", vf, vf, out, err) + nfail++ + return + } + if f != g || fmt.Sprint(f) != fmt.Sprint(g) { // fmt.Sprint handles ±0 + t.Errorf("Marshal(%T(%g)) = %q (is %g, not %g)", vf, vf, out, float32(g), vf) + nfail++ + return + } + + bad := badFloatREs + if bits == 64 { + bad = bad[:len(bad)-2] + } + for _, re := range bad { + if re.MatchString(out) { + t.Errorf("Marshal(%T(%g)) = %q, must not match /%s/", vf, vf, out, re) + nfail++ + return + } + } + } + + var ( + bigger = math.Inf(+1) + smaller = math.Inf(-1) + ) + + var digits = "1.2345678901234567890123" + for i := len(digits); i >= 2; i-- { + if testing.Short() && i < len(digits)-4 { + break + } + for exp := -30; exp <= 30; exp++ { + for _, sign := range "+-" { + for bits := 32; bits <= 64; bits += 32 { + s := fmt.Sprintf("%c%se%d", sign, digits[:i], exp) + f, err := strconv.ParseFloat(s, bits) + if err != nil { + log.Fatal(err) + } + next := math.Nextafter + if bits == 32 { + next = func(g, h float64) float64 { + return float64(math.Nextafter32(float32(g), float32(h))) + } + } + test(f, bits) + test(next(f, bigger), bits) + test(next(f, smaller), bits) + if nfail > 50 { + t.Fatalf("stopping test early") + } + } + } + } + } + test(0, 64) + test(math.Copysign(0, -1), 64) + test(0, 32) + test(math.Copysign(0, -1), 32) +}