diff --git a/encode.go b/encode.go index 30bf31f..5c86ec8 100644 --- a/encode.go +++ b/encode.go @@ -142,6 +142,9 @@ func (e *Encoder) encode(v reflect.Value, ptr unsafe.Pointer) ([]byte, error) { if err != nil { return nil, err } + if op == nil { + return nil, nil + } if name != "" { cachedEncodeOp[name] = op } @@ -161,6 +164,8 @@ func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) { return e.compileStruct(v) case reflect.Map: return e.compileMap(v) + case reflect.Array: + return e.compileArray(v) case reflect.Int: return e.compileInt() case reflect.Int8: @@ -189,6 +194,20 @@ func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) { return e.compileString() case reflect.Bool: return e.compileBool() + case reflect.Interface: + return nil, ErrCompileSlowPath + case reflect.Func: + return nil, nil + case reflect.Chan: + return nil, nil + case reflect.UnsafePointer: + return nil, nil + case reflect.Uintptr: + return nil, nil + case reflect.Complex64: + return nil, nil + case reflect.Complex128: + return nil, nil } return nil, xerrors.Errorf("failed to compile %s: %w", v.Type(), ErrUnknownType) } @@ -198,6 +217,9 @@ func (e *Encoder) compilePtr(v reflect.Value) (EncodeOp, error) { if err != nil { return nil, err } + if op == nil { + return nil, nil + } return func(enc *Encoder, p uintptr) { op(enc, e.ptrToPtr(p)) }, nil @@ -259,17 +281,16 @@ func (e *Encoder) compileBool() (EncodeOp, error) { return func(enc *Encoder, p uintptr) { enc.EncodeBool(e.ptrToBool(p)) }, nil } -func (e *Encoder) zeroValue(typ reflect.Type) reflect.Value { - return reflect.New(typ).Elem() -} - func (e *Encoder) compileSlice(v reflect.Value) (EncodeOp, error) { typ := v.Type() size := typ.Elem().Size() - op, err := e.compile(e.zeroValue(typ.Elem())) + op, err := e.compile(reflect.Zero(typ.Elem())) if err != nil { return nil, err } + if op == nil { + return nil, nil + } return func(enc *Encoder, base uintptr) { if base == 0 { enc.EncodeString("null") @@ -286,7 +307,33 @@ func (e *Encoder) compileSlice(v reflect.Value) (EncodeOp, error) { } enc.EncodeByte(']') }, nil +} +func (e *Encoder) compileArray(v reflect.Value) (EncodeOp, error) { + typ := v.Type() + alen := typ.Len() + size := typ.Elem().Size() + op, err := e.compile(reflect.Zero(typ.Elem())) + if err != nil { + return nil, err + } + if op == nil { + return nil, nil + } + return func(enc *Encoder, base uintptr) { + if base == 0 { + enc.EncodeString("null") + return + } + enc.EncodeByte('[') + for i := 0; i < alen; i++ { + if i != 0 { + enc.EncodeByte(',') + } + op(enc, base+uintptr(i)*size) + } + enc.EncodeByte(']') + }, nil } func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) { @@ -307,6 +354,9 @@ func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) { if err != nil { return nil, err } + if op == nil { + continue + } key := fmt.Sprintf(`"%s":`, keyName) opQueue = append(opQueue, func(enc *Encoder, base uintptr) { enc.EncodeString(key) @@ -352,14 +402,20 @@ type valueType struct { func (e *Encoder) compileMap(v reflect.Value) (EncodeOp, error) { mapType := (*valueType)(unsafe.Pointer(&v)).typ - keyOp, err := e.compile(e.zeroValue(v.Type().Key())) + keyOp, err := e.compile(reflect.Zero(v.Type().Key())) if err != nil { return nil, err } - valueOp, err := e.compile(e.zeroValue(v.Type().Elem())) + if keyOp == nil { + return nil, nil + } + valueOp, err := e.compile(reflect.Zero(v.Type().Elem())) if err != nil { return nil, err } + if valueOp == nil { + return nil, nil + } return func(enc *Encoder, base uintptr) { if base == 0 { enc.EncodeString("null") diff --git a/encode_test.go b/encode_test.go index 12ddd6b..f8fa4c0 100644 --- a/encode_test.go +++ b/encode_test.go @@ -6,132 +6,90 @@ import ( "github.com/goccy/go-json" ) +func assertErr(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("%+v", err) + } +} + +func assertEq(t *testing.T, msg string, exp interface{}, act interface{}) { + t.Helper() + if exp != act { + t.Fatalf("failed to encode %s. exp=[%v] but act=[%v]", msg, exp, act) + } +} + func Test_Encoder(t *testing.T) { t.Run("int", func(t *testing.T) { bytes, err := json.Marshal(-10) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `-10` { - t.Fatal("failed to encode int") - } + assertErr(t, err) + assertEq(t, "int", `-10`, string(bytes)) }) t.Run("int8", func(t *testing.T) { bytes, err := json.Marshal(int8(-11)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `-11` { - t.Fatal("failed to encode int8") - } + assertErr(t, err) + assertEq(t, "int8", `-11`, string(bytes)) }) t.Run("int16", func(t *testing.T) { bytes, err := json.Marshal(int16(-12)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `-12` { - t.Fatal("failed to encode int16") - } + assertErr(t, err) + assertEq(t, "int16", `-12`, string(bytes)) }) t.Run("int32", func(t *testing.T) { bytes, err := json.Marshal(int32(-13)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `-13` { - t.Fatal("failed to encode int32") - } + assertErr(t, err) + assertEq(t, "int32", `-13`, string(bytes)) }) t.Run("int64", func(t *testing.T) { bytes, err := json.Marshal(int64(-14)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `-14` { - t.Fatal("failed to encode int64") - } + assertErr(t, err) + assertEq(t, "int64", `-14`, string(bytes)) }) t.Run("uint", func(t *testing.T) { bytes, err := json.Marshal(uint(10)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `10` { - t.Fatal("failed to encode uint") - } + assertErr(t, err) + assertEq(t, "uint", `10`, string(bytes)) }) t.Run("uint8", func(t *testing.T) { bytes, err := json.Marshal(uint8(11)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `11` { - t.Fatal("failed to encode uint8") - } + assertErr(t, err) + assertEq(t, "uint8", `11`, string(bytes)) }) t.Run("uint16", func(t *testing.T) { bytes, err := json.Marshal(uint16(12)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `12` { - t.Fatal("failed to encode uint16") - } + assertErr(t, err) + assertEq(t, "uint16", `12`, string(bytes)) }) t.Run("uint32", func(t *testing.T) { bytes, err := json.Marshal(uint32(13)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `13` { - t.Fatal("failed to encode uint32") - } + assertErr(t, err) + assertEq(t, "uint32", `13`, string(bytes)) }) t.Run("uint64", func(t *testing.T) { bytes, err := json.Marshal(uint64(14)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `14` { - t.Fatal("failed to encode uint64") - } + assertErr(t, err) + assertEq(t, "uint64", `14`, string(bytes)) }) t.Run("float32", func(t *testing.T) { bytes, err := json.Marshal(float32(3.14)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `3.14` { - t.Fatal("failed to encode float32") - } + assertErr(t, err) + assertEq(t, "float32", `3.14`, string(bytes)) }) t.Run("float64", func(t *testing.T) { bytes, err := json.Marshal(float64(3.14)) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `3.14` { - t.Fatal("failed to encode float64") - } + assertErr(t, err) + assertEq(t, "float64", `3.14`, string(bytes)) }) t.Run("bool", func(t *testing.T) { bytes, err := json.Marshal(true) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `true` { - t.Fatal("failed to encode bool") - } + assertErr(t, err) + assertEq(t, "bool", `true`, string(bytes)) }) t.Run("string", func(t *testing.T) { bytes, err := json.Marshal("hello world") - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `"hello world"` { - t.Fatal("failed to encode string") - } + assertErr(t, err) + assertEq(t, "string", `"hello world"`, string(bytes)) }) t.Run("struct", func(t *testing.T) { bytes, err := json.Marshal(struct { @@ -143,21 +101,18 @@ func Test_Encoder(t *testing.T) { B: 1, C: "hello world", }) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `{"a":-1,"b":1,"c":"hello world"}` { - t.Fatal("failed to encode struct") - } + assertErr(t, err) + assertEq(t, "struct", `{"a":-1,"b":1,"c":"hello world"}`, string(bytes)) }) t.Run("slice", func(t *testing.T) { bytes, err := json.Marshal([]int{1, 2, 3, 4}) - if err != nil { - t.Fatalf("%+v", err) - } - if string(bytes) != `[1,2,3,4]` { - t.Fatal("failed to encode slice of int") - } + assertErr(t, err) + assertEq(t, "slice", `[1,2,3,4]`, string(bytes)) + }) + t.Run("array", func(t *testing.T) { + bytes, err := json.Marshal([4]int{1, 2, 3, 4}) + assertErr(t, err) + assertEq(t, "array", `[1,2,3,4]`, string(bytes)) }) t.Run("map", func(t *testing.T) { bytes, err := json.Marshal(map[string]int{ @@ -166,11 +121,7 @@ func Test_Encoder(t *testing.T) { "c": 3, "d": 4, }) - if err != nil { - t.Fatalf("%+v", err) - } - if len(string(bytes)) != len(`{"a":1,"b":2,"c":3,"d":4}`) { - t.Fatal("failed to encode map of string/int") - } + assertErr(t, err) + assertEq(t, "map", len(`{"a":1,"b":2,"c":3,"d":4}`), len(string(bytes))) }) } diff --git a/error.go b/error.go index cf30f67..31ee265 100644 --- a/error.go +++ b/error.go @@ -3,5 +3,6 @@ package json import "errors" var ( - ErrUnknownType = errors.New("unknown type name") + ErrUnknownType = errors.New("unknown type name") + ErrCompileSlowPath = errors.New("detect dynamic type ( interface{} ) and compile with slow path") )