Support encoding for Array type

This commit is contained in:
Masaaki Goshima 2020-04-21 13:19:53 +09:00
parent ed5ee07fdf
commit a573d86121
3 changed files with 118 additions and 110 deletions

View File

@ -142,6 +142,9 @@ func (e *Encoder) encode(v reflect.Value, ptr unsafe.Pointer) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if op == nil {
return nil, nil
}
if name != "" { if name != "" {
cachedEncodeOp[name] = op cachedEncodeOp[name] = op
} }
@ -161,6 +164,8 @@ func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) {
return e.compileStruct(v) return e.compileStruct(v)
case reflect.Map: case reflect.Map:
return e.compileMap(v) return e.compileMap(v)
case reflect.Array:
return e.compileArray(v)
case reflect.Int: case reflect.Int:
return e.compileInt() return e.compileInt()
case reflect.Int8: case reflect.Int8:
@ -189,6 +194,20 @@ func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) {
return e.compileString() return e.compileString()
case reflect.Bool: case reflect.Bool:
return e.compileBool() return e.compileBool()
case reflect.Interface:
return nil, ErrCompileSlowPath
case reflect.Func:
return nil, nil
case reflect.Chan:
return nil, nil
case reflect.UnsafePointer:
return nil, nil
case reflect.Uintptr:
return nil, nil
case reflect.Complex64:
return nil, nil
case reflect.Complex128:
return nil, nil
} }
return nil, xerrors.Errorf("failed to compile %s: %w", v.Type(), ErrUnknownType) return nil, xerrors.Errorf("failed to compile %s: %w", v.Type(), ErrUnknownType)
} }
@ -198,6 +217,9 @@ func (e *Encoder) compilePtr(v reflect.Value) (EncodeOp, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if op == nil {
return nil, nil
}
return func(enc *Encoder, p uintptr) { return func(enc *Encoder, p uintptr) {
op(enc, e.ptrToPtr(p)) op(enc, e.ptrToPtr(p))
}, nil }, nil
@ -259,17 +281,16 @@ func (e *Encoder) compileBool() (EncodeOp, error) {
return func(enc *Encoder, p uintptr) { enc.EncodeBool(e.ptrToBool(p)) }, nil return func(enc *Encoder, p uintptr) { enc.EncodeBool(e.ptrToBool(p)) }, nil
} }
func (e *Encoder) zeroValue(typ reflect.Type) reflect.Value {
return reflect.New(typ).Elem()
}
func (e *Encoder) compileSlice(v reflect.Value) (EncodeOp, error) { func (e *Encoder) compileSlice(v reflect.Value) (EncodeOp, error) {
typ := v.Type() typ := v.Type()
size := typ.Elem().Size() size := typ.Elem().Size()
op, err := e.compile(e.zeroValue(typ.Elem())) op, err := e.compile(reflect.Zero(typ.Elem()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if op == nil {
return nil, nil
}
return func(enc *Encoder, base uintptr) { return func(enc *Encoder, base uintptr) {
if base == 0 { if base == 0 {
enc.EncodeString("null") enc.EncodeString("null")
@ -286,7 +307,33 @@ func (e *Encoder) compileSlice(v reflect.Value) (EncodeOp, error) {
} }
enc.EncodeByte(']') enc.EncodeByte(']')
}, nil }, nil
}
func (e *Encoder) compileArray(v reflect.Value) (EncodeOp, error) {
typ := v.Type()
alen := typ.Len()
size := typ.Elem().Size()
op, err := e.compile(reflect.Zero(typ.Elem()))
if err != nil {
return nil, err
}
if op == nil {
return nil, nil
}
return func(enc *Encoder, base uintptr) {
if base == 0 {
enc.EncodeString("null")
return
}
enc.EncodeByte('[')
for i := 0; i < alen; i++ {
if i != 0 {
enc.EncodeByte(',')
}
op(enc, base+uintptr(i)*size)
}
enc.EncodeByte(']')
}, nil
} }
func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) { func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) {
@ -307,6 +354,9 @@ func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if op == nil {
continue
}
key := fmt.Sprintf(`"%s":`, keyName) key := fmt.Sprintf(`"%s":`, keyName)
opQueue = append(opQueue, func(enc *Encoder, base uintptr) { opQueue = append(opQueue, func(enc *Encoder, base uintptr) {
enc.EncodeString(key) enc.EncodeString(key)
@ -352,14 +402,20 @@ type valueType struct {
func (e *Encoder) compileMap(v reflect.Value) (EncodeOp, error) { func (e *Encoder) compileMap(v reflect.Value) (EncodeOp, error) {
mapType := (*valueType)(unsafe.Pointer(&v)).typ mapType := (*valueType)(unsafe.Pointer(&v)).typ
keyOp, err := e.compile(e.zeroValue(v.Type().Key())) keyOp, err := e.compile(reflect.Zero(v.Type().Key()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
valueOp, err := e.compile(e.zeroValue(v.Type().Elem())) if keyOp == nil {
return nil, nil
}
valueOp, err := e.compile(reflect.Zero(v.Type().Elem()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if valueOp == nil {
return nil, nil
}
return func(enc *Encoder, base uintptr) { return func(enc *Encoder, base uintptr) {
if base == 0 { if base == 0 {
enc.EncodeString("null") enc.EncodeString("null")

View File

@ -6,132 +6,90 @@ import (
"github.com/goccy/go-json" "github.com/goccy/go-json"
) )
func assertErr(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("%+v", err)
}
}
func assertEq(t *testing.T, msg string, exp interface{}, act interface{}) {
t.Helper()
if exp != act {
t.Fatalf("failed to encode %s. exp=[%v] but act=[%v]", msg, exp, act)
}
}
func Test_Encoder(t *testing.T) { func Test_Encoder(t *testing.T) {
t.Run("int", func(t *testing.T) { t.Run("int", func(t *testing.T) {
bytes, err := json.Marshal(-10) bytes, err := json.Marshal(-10)
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "int", `-10`, string(bytes))
}
if string(bytes) != `-10` {
t.Fatal("failed to encode int")
}
}) })
t.Run("int8", func(t *testing.T) { t.Run("int8", func(t *testing.T) {
bytes, err := json.Marshal(int8(-11)) bytes, err := json.Marshal(int8(-11))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "int8", `-11`, string(bytes))
}
if string(bytes) != `-11` {
t.Fatal("failed to encode int8")
}
}) })
t.Run("int16", func(t *testing.T) { t.Run("int16", func(t *testing.T) {
bytes, err := json.Marshal(int16(-12)) bytes, err := json.Marshal(int16(-12))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "int16", `-12`, string(bytes))
}
if string(bytes) != `-12` {
t.Fatal("failed to encode int16")
}
}) })
t.Run("int32", func(t *testing.T) { t.Run("int32", func(t *testing.T) {
bytes, err := json.Marshal(int32(-13)) bytes, err := json.Marshal(int32(-13))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "int32", `-13`, string(bytes))
}
if string(bytes) != `-13` {
t.Fatal("failed to encode int32")
}
}) })
t.Run("int64", func(t *testing.T) { t.Run("int64", func(t *testing.T) {
bytes, err := json.Marshal(int64(-14)) bytes, err := json.Marshal(int64(-14))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "int64", `-14`, string(bytes))
}
if string(bytes) != `-14` {
t.Fatal("failed to encode int64")
}
}) })
t.Run("uint", func(t *testing.T) { t.Run("uint", func(t *testing.T) {
bytes, err := json.Marshal(uint(10)) bytes, err := json.Marshal(uint(10))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "uint", `10`, string(bytes))
}
if string(bytes) != `10` {
t.Fatal("failed to encode uint")
}
}) })
t.Run("uint8", func(t *testing.T) { t.Run("uint8", func(t *testing.T) {
bytes, err := json.Marshal(uint8(11)) bytes, err := json.Marshal(uint8(11))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "uint8", `11`, string(bytes))
}
if string(bytes) != `11` {
t.Fatal("failed to encode uint8")
}
}) })
t.Run("uint16", func(t *testing.T) { t.Run("uint16", func(t *testing.T) {
bytes, err := json.Marshal(uint16(12)) bytes, err := json.Marshal(uint16(12))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "uint16", `12`, string(bytes))
}
if string(bytes) != `12` {
t.Fatal("failed to encode uint16")
}
}) })
t.Run("uint32", func(t *testing.T) { t.Run("uint32", func(t *testing.T) {
bytes, err := json.Marshal(uint32(13)) bytes, err := json.Marshal(uint32(13))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "uint32", `13`, string(bytes))
}
if string(bytes) != `13` {
t.Fatal("failed to encode uint32")
}
}) })
t.Run("uint64", func(t *testing.T) { t.Run("uint64", func(t *testing.T) {
bytes, err := json.Marshal(uint64(14)) bytes, err := json.Marshal(uint64(14))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "uint64", `14`, string(bytes))
}
if string(bytes) != `14` {
t.Fatal("failed to encode uint64")
}
}) })
t.Run("float32", func(t *testing.T) { t.Run("float32", func(t *testing.T) {
bytes, err := json.Marshal(float32(3.14)) bytes, err := json.Marshal(float32(3.14))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "float32", `3.14`, string(bytes))
}
if string(bytes) != `3.14` {
t.Fatal("failed to encode float32")
}
}) })
t.Run("float64", func(t *testing.T) { t.Run("float64", func(t *testing.T) {
bytes, err := json.Marshal(float64(3.14)) bytes, err := json.Marshal(float64(3.14))
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "float64", `3.14`, string(bytes))
}
if string(bytes) != `3.14` {
t.Fatal("failed to encode float64")
}
}) })
t.Run("bool", func(t *testing.T) { t.Run("bool", func(t *testing.T) {
bytes, err := json.Marshal(true) bytes, err := json.Marshal(true)
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "bool", `true`, string(bytes))
}
if string(bytes) != `true` {
t.Fatal("failed to encode bool")
}
}) })
t.Run("string", func(t *testing.T) { t.Run("string", func(t *testing.T) {
bytes, err := json.Marshal("hello world") bytes, err := json.Marshal("hello world")
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "string", `"hello world"`, string(bytes))
}
if string(bytes) != `"hello world"` {
t.Fatal("failed to encode string")
}
}) })
t.Run("struct", func(t *testing.T) { t.Run("struct", func(t *testing.T) {
bytes, err := json.Marshal(struct { bytes, err := json.Marshal(struct {
@ -143,21 +101,18 @@ func Test_Encoder(t *testing.T) {
B: 1, B: 1,
C: "hello world", C: "hello world",
}) })
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "struct", `{"a":-1,"b":1,"c":"hello world"}`, string(bytes))
}
if string(bytes) != `{"a":-1,"b":1,"c":"hello world"}` {
t.Fatal("failed to encode struct")
}
}) })
t.Run("slice", func(t *testing.T) { t.Run("slice", func(t *testing.T) {
bytes, err := json.Marshal([]int{1, 2, 3, 4}) bytes, err := json.Marshal([]int{1, 2, 3, 4})
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "slice", `[1,2,3,4]`, string(bytes))
} })
if string(bytes) != `[1,2,3,4]` { t.Run("array", func(t *testing.T) {
t.Fatal("failed to encode slice of int") bytes, err := json.Marshal([4]int{1, 2, 3, 4})
} assertErr(t, err)
assertEq(t, "array", `[1,2,3,4]`, string(bytes))
}) })
t.Run("map", func(t *testing.T) { t.Run("map", func(t *testing.T) {
bytes, err := json.Marshal(map[string]int{ bytes, err := json.Marshal(map[string]int{
@ -166,11 +121,7 @@ func Test_Encoder(t *testing.T) {
"c": 3, "c": 3,
"d": 4, "d": 4,
}) })
if err != nil { assertErr(t, err)
t.Fatalf("%+v", err) assertEq(t, "map", len(`{"a":1,"b":2,"c":3,"d":4}`), len(string(bytes)))
}
if len(string(bytes)) != len(`{"a":1,"b":2,"c":3,"d":4}`) {
t.Fatal("failed to encode map of string/int")
}
}) })
} }

View File

@ -4,4 +4,5 @@ import "errors"
var ( var (
ErrUnknownType = errors.New("unknown type name") ErrUnknownType = errors.New("unknown type name")
ErrCompileSlowPath = errors.New("detect dynamic type ( interface{} ) and compile with slow path")
) )