diff --git a/encode.go b/encode.go index 434e7b4..555d2fd 100644 --- a/encode.go +++ b/encode.go @@ -117,36 +117,42 @@ func (e *Encoder) EncodeByte(b byte) { func (e *Encoder) Encode(v interface{}) ([]byte, error) { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr { - rv = rv.Addr() + newV := reflect.New(rv.Type()) + newV.Elem().Set(rv) + rv = newV } return e.encode(rv) } func (e *Encoder) encode(v reflect.Value) ([]byte, error) { - name := v.Type().Name() + name := v.Type().String() if op, exists := cachedEncodeOp[name]; exists { op(e, v.Pointer()) copied := make([]byte, len(e.buf)) copy(copied, e.buf) return copied, nil } - op, err := e.compile(v) + op, err := e.compile(v.Type()) if err != nil { return nil, err } - cachedEncodeOp[name] = op + if name != "" { + cachedEncodeOp[name] = op + } op(e, v.Pointer()) copied := make([]byte, len(e.buf)) copy(copied, e.buf) return copied, nil } -func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) { - switch v.Type().Kind() { +func (e *Encoder) compile(typ reflect.Type) (EncodeOp, error) { + switch typ.Kind() { case reflect.Ptr: - return e.compile(v.Elem()) + return e.compile(typ.Elem()) + case reflect.Slice: + return e.compileSlice(typ) case reflect.Struct: - return e.compileStruct(v) + return e.compileStruct(typ) case reflect.Int: return e.compileInt() case reflect.Int8: @@ -157,6 +163,16 @@ func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) { return e.compileInt32() case reflect.Int64: return e.compileInt64() + case reflect.Uint: + return e.compileUint() + case reflect.Uint8: + return e.compileUint8() + case reflect.Uint16: + return e.compileUint16() + case reflect.Uint32: + return e.compileUint32() + case reflect.Uint64: + return e.compileUint64() case reflect.Float32: return e.compileFloat32() case reflect.Float64: @@ -166,7 +182,7 @@ func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) { case reflect.Bool: return e.compileBool() } - return nil, xerrors.Errorf("failed to compile %s: %w", v.Type(), ErrUnknownType) + return nil, xerrors.Errorf("failed to compile %s: %w", typ, ErrUnknownType) } func (e *Encoder) compileInt() (EncodeOp, error) { @@ -225,9 +241,29 @@ func (e *Encoder) compileBool() (EncodeOp, error) { return func(enc *Encoder, p uintptr) { enc.EncodeBool(e.ptrToBool(p)) }, nil } -func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) { - typ := v.Type() - fieldNum := v.NumField() +func (e *Encoder) compileSlice(typ reflect.Type) (EncodeOp, error) { + size := typ.Elem().Size() + op, err := e.compile(typ.Elem()) + if err != nil { + return nil, err + } + return func(enc *Encoder, base uintptr) { + enc.EncodeByte('[') + slice := (*reflect.SliceHeader)(unsafe.Pointer(base)) + num := slice.Len + for i := 0; i < num; i++ { + op(enc, slice.Data+uintptr(i)*size) + if i != num-1 { + enc.EncodeByte(',') + } + } + enc.EncodeByte(']') + }, nil + +} + +func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) { + fieldNum := typ.NumField() opQueue := make([]EncodeOp, 0, fieldNum) for i := 0; i < fieldNum; i++ { field := typ.Field(i) @@ -239,7 +275,7 @@ func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) { keyName = opts[0] } } - op, err := e.compile(v.Field(i)) + op, err := e.compile(typ.Field(i).Type) if err != nil { return nil, err } diff --git a/encode_test.go b/encode_test.go new file mode 100644 index 0000000..06a27f6 --- /dev/null +++ b/encode_test.go @@ -0,0 +1,162 @@ +package json_test + +import ( + "testing" + + "github.com/goccy/go-json" +) + +func Test_Encoder(t *testing.T) { + t.Run("int", func(t *testing.T) { + bytes, err := json.Marshal(-10) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `-10` { + t.Fatal("failed to encode int") + } + }) + t.Run("int8", func(t *testing.T) { + bytes, err := json.Marshal(int8(-11)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `-11` { + t.Fatal("failed to encode int8") + } + }) + t.Run("int16", func(t *testing.T) { + bytes, err := json.Marshal(int16(-12)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `-12` { + t.Fatal("failed to encode int16") + } + }) + t.Run("int32", func(t *testing.T) { + bytes, err := json.Marshal(int32(-13)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `-13` { + t.Fatal("failed to encode int32") + } + }) + t.Run("int64", func(t *testing.T) { + bytes, err := json.Marshal(int64(-14)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `-14` { + t.Fatal("failed to encode int64") + } + }) + t.Run("uint", func(t *testing.T) { + bytes, err := json.Marshal(uint(10)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `10` { + t.Fatal("failed to encode uint") + } + }) + t.Run("uint8", func(t *testing.T) { + bytes, err := json.Marshal(uint8(11)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `11` { + t.Fatal("failed to encode uint8") + } + }) + t.Run("uint16", func(t *testing.T) { + bytes, err := json.Marshal(uint16(12)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `12` { + t.Fatal("failed to encode uint16") + } + }) + t.Run("uint32", func(t *testing.T) { + bytes, err := json.Marshal(uint32(13)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `13` { + t.Fatal("failed to encode uint32") + } + }) + t.Run("uint64", func(t *testing.T) { + bytes, err := json.Marshal(uint64(14)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `14` { + t.Fatal("failed to encode uint64") + } + }) + t.Run("float32", func(t *testing.T) { + bytes, err := json.Marshal(float32(3.14)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `3.14` { + t.Fatal("failed to encode float32") + } + }) + t.Run("float64", func(t *testing.T) { + bytes, err := json.Marshal(float64(3.14)) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `3.14` { + t.Fatal("failed to encode float64") + } + }) + t.Run("bool", func(t *testing.T) { + bytes, err := json.Marshal(true) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `true` { + t.Fatal("failed to encode bool") + } + }) + t.Run("string", func(t *testing.T) { + bytes, err := json.Marshal("hello world") + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `"hello world"` { + t.Fatal("failed to encode string") + } + }) + t.Run("struct", func(t *testing.T) { + bytes, err := json.Marshal(struct { + A int `json:"a"` + B uint `json:"b"` + C string `json:"c"` + }{ + A: -1, + B: 1, + C: "hello world", + }) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `{"a":-1,"b":1,"c":"hello world"}` { + t.Fatal("failed to encode struct") + } + }) + t.Run("slice", func(t *testing.T) { + bytes, err := json.Marshal([]int{1, 2, 3, 4}) + if err != nil { + t.Fatalf("%+v", err) + } + if string(bytes) != `[1,2,3,4]` { + t.Fatal("failed to encode slice of int") + } + }) +}