diff --git a/Makefile b/Makefile index 7a10530..5bbfc4c 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ cover-html: cover .PHONY: lint lint: golangci-lint - golangci-lint run + $(BIN_DIR)/golangci-lint run golangci-lint: | $(BIN_DIR) @{ \ diff --git a/benchmarks/path_test.go b/benchmarks/path_test.go new file mode 100644 index 0000000..3224b65 --- /dev/null +++ b/benchmarks/path_test.go @@ -0,0 +1,24 @@ +package benchmark + +import ( + "testing" + + gojson "github.com/goccy/go-json" +) + +func Benchmark_Decode_SmallStruct_UnmarshalPath_GoJson(b *testing.B) { + path, err := gojson.CreatePath("$.st") + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var v int + if err := path.Unmarshal(SmallFixture, &v); err != nil { + b.Fatal(err) + } + if v != 1 { + b.Fatal("failed to unmarshal path") + } + } +} diff --git a/decode.go b/decode.go index d99749d..b1512ce 100644 --- a/decode.go +++ b/decode.go @@ -83,6 +83,34 @@ func unmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs return validateEndBuf(src, cursor) } +var ( + pathDecoder = decoder.NewPathDecoder() +) + +func extractFromPath(path *Path, data []byte, optFuncs ...DecodeOptionFunc) ([][]byte, error) { + src := make([]byte, len(data)+1) // append nul byte to the end + copy(src, data) + + ctx := decoder.TakeRuntimeContext() + ctx.Buf = src + ctx.Option.Flags = 0 + ctx.Option.Flags |= decoder.PathOption + ctx.Option.Path = path.path + for _, optFunc := range optFuncs { + optFunc(ctx.Option) + } + paths, cursor, err := pathDecoder.DecodePath(ctx, 0, 0) + if err != nil { + decoder.ReleaseRuntimeContext(ctx) + return nil, err + } + decoder.ReleaseRuntimeContext(ctx) + if err := validateEndBuf(src, cursor); err != nil { + return nil, err + } + return paths, nil +} + func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { src := make([]byte, len(data)+1) // append nul byte to the end copy(src, data) diff --git a/error.go b/error.go index 94c1339..5b2dcee 100644 --- a/error.go +++ b/error.go @@ -37,3 +37,5 @@ type UnmarshalTypeError = errors.UnmarshalTypeError type UnsupportedTypeError = errors.UnsupportedTypeError type UnsupportedValueError = errors.UnsupportedValueError + +type PathError = errors.PathError diff --git a/internal/decoder/anonymous_field.go b/internal/decoder/anonymous_field.go index 030cb7a..b6876cf 100644 --- a/internal/decoder/anonymous_field.go +++ b/internal/decoder/anonymous_field.go @@ -35,3 +35,7 @@ func (d *anonymousFieldDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p = *(*unsafe.Pointer)(p) return d.dec.Decode(ctx, cursor, depth, unsafe.Pointer(uintptr(p)+d.offset)) } + +func (d *anonymousFieldDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return d.dec.DecodePath(ctx, cursor, depth) +} diff --git a/internal/decoder/array.go b/internal/decoder/array.go index 21f1fd5..8ef91cf 100644 --- a/internal/decoder/array.go +++ b/internal/decoder/array.go @@ -1,6 +1,7 @@ package decoder import ( + "fmt" "unsafe" "github.com/goccy/go-json/internal/errors" @@ -167,3 +168,7 @@ func (d *arrayDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe } } } + +func (d *arrayDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: array decoder does not support decode path") +} diff --git a/internal/decoder/assign.go b/internal/decoder/assign.go new file mode 100644 index 0000000..c53e6ad --- /dev/null +++ b/internal/decoder/assign.go @@ -0,0 +1,438 @@ +package decoder + +import ( + "fmt" + "reflect" + "strconv" +) + +var ( + nilValue = reflect.ValueOf(nil) +) + +func AssignValue(src, dst reflect.Value) error { + if dst.Type().Kind() != reflect.Ptr { + return fmt.Errorf("invalid dst type. required pointer type: %T", dst.Type()) + } + casted, err := castValue(dst.Elem().Type(), src) + if err != nil { + return err + } + dst.Elem().Set(casted) + return nil +} + +func castValue(t reflect.Type, v reflect.Value) (reflect.Value, error) { + switch t.Kind() { + case reflect.Int: + vv, err := castInt(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(int(vv.Int())), nil + case reflect.Int8: + vv, err := castInt(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(int8(vv.Int())), nil + case reflect.Int16: + vv, err := castInt(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(int16(vv.Int())), nil + case reflect.Int32: + vv, err := castInt(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(int32(vv.Int())), nil + case reflect.Int64: + return castInt(v) + case reflect.Uint: + vv, err := castUint(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(uint(vv.Uint())), nil + case reflect.Uint8: + vv, err := castUint(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(uint8(vv.Uint())), nil + case reflect.Uint16: + vv, err := castUint(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(uint16(vv.Uint())), nil + case reflect.Uint32: + vv, err := castUint(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(uint32(vv.Uint())), nil + case reflect.Uint64: + return castUint(v) + case reflect.Uintptr: + vv, err := castUint(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(uintptr(vv.Uint())), nil + case reflect.String: + return castString(v) + case reflect.Bool: + return castBool(v) + case reflect.Float32: + vv, err := castFloat(v) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(float32(vv.Float())), nil + case reflect.Float64: + return castFloat(v) + case reflect.Array: + return castArray(t, v) + case reflect.Slice: + return castSlice(t, v) + case reflect.Map: + return castMap(t, v) + case reflect.Struct: + return castStruct(t, v) + } + return v, nil +} + +func castInt(v reflect.Value) (reflect.Value, error) { + switch v.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return reflect.ValueOf(int64(v.Uint())), nil + case reflect.String: + i64, err := strconv.ParseInt(v.String(), 10, 64) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(i64), nil + case reflect.Bool: + if v.Bool() { + return reflect.ValueOf(int64(1)), nil + } + return reflect.ValueOf(int64(0)), nil + case reflect.Float32, reflect.Float64: + return reflect.ValueOf(int64(v.Float())), nil + case reflect.Array: + if v.Len() > 0 { + return castInt(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to int64 from empty array") + case reflect.Slice: + if v.Len() > 0 { + return castInt(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to int64 from empty slice") + case reflect.Interface: + return castInt(reflect.ValueOf(v.Interface())) + case reflect.Map: + return nilValue, fmt.Errorf("failed to cast to int64 from map") + case reflect.Struct: + return nilValue, fmt.Errorf("failed to cast to int64 from struct") + case reflect.Ptr: + return castInt(v.Elem()) + } + return nilValue, fmt.Errorf("failed to cast to int64 from %s", v.Type().Kind()) +} + +func castUint(v reflect.Value) (reflect.Value, error) { + switch v.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.ValueOf(uint64(v.Int())), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v, nil + case reflect.String: + u64, err := strconv.ParseUint(v.String(), 10, 64) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(u64), nil + case reflect.Bool: + if v.Bool() { + return reflect.ValueOf(uint64(1)), nil + } + return reflect.ValueOf(uint64(0)), nil + case reflect.Float32, reflect.Float64: + return reflect.ValueOf(uint64(v.Float())), nil + case reflect.Array: + if v.Len() > 0 { + return castUint(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to uint64 from empty array") + case reflect.Slice: + if v.Len() > 0 { + return castUint(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to uint64 from empty slice") + case reflect.Interface: + return castUint(reflect.ValueOf(v.Interface())) + case reflect.Map: + return nilValue, fmt.Errorf("failed to cast to uint64 from map") + case reflect.Struct: + return nilValue, fmt.Errorf("failed to cast to uint64 from struct") + case reflect.Ptr: + return castUint(v.Elem()) + } + return nilValue, fmt.Errorf("failed to cast to uint64 from %s", v.Type().Kind()) +} + +func castString(v reflect.Value) (reflect.Value, error) { + switch v.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.ValueOf(fmt.Sprint(v.Int())), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return reflect.ValueOf(fmt.Sprint(v.Uint())), nil + case reflect.String: + return v, nil + case reflect.Bool: + if v.Bool() { + return reflect.ValueOf("true"), nil + } + return reflect.ValueOf("false"), nil + case reflect.Float32, reflect.Float64: + return reflect.ValueOf(fmt.Sprint(v.Float())), nil + case reflect.Array: + if v.Len() > 0 { + return castString(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to string from empty array") + case reflect.Slice: + if v.Len() > 0 { + return castString(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to string from empty slice") + case reflect.Interface: + return castString(reflect.ValueOf(v.Interface())) + case reflect.Map: + return nilValue, fmt.Errorf("failed to cast to string from map") + case reflect.Struct: + return nilValue, fmt.Errorf("failed to cast to string from struct") + case reflect.Ptr: + return castString(v.Elem()) + } + return nilValue, fmt.Errorf("failed to cast to string from %s", v.Type().Kind()) +} + +func castBool(v reflect.Value) (reflect.Value, error) { + switch v.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v.Int() { + case 0: + return reflect.ValueOf(false), nil + case 1: + return reflect.ValueOf(true), nil + } + return nilValue, fmt.Errorf("failed to cast to bool from %d", v.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch v.Uint() { + case 0: + return reflect.ValueOf(false), nil + case 1: + return reflect.ValueOf(true), nil + } + return nilValue, fmt.Errorf("failed to cast to bool from %d", v.Uint()) + case reflect.String: + b, err := strconv.ParseBool(v.String()) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(b), nil + case reflect.Bool: + return v, nil + case reflect.Float32, reflect.Float64: + switch v.Float() { + case 0: + return reflect.ValueOf(false), nil + case 1: + return reflect.ValueOf(true), nil + } + return nilValue, fmt.Errorf("failed to cast to bool from %f", v.Float()) + case reflect.Array: + if v.Len() > 0 { + return castBool(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to string from empty array") + case reflect.Slice: + if v.Len() > 0 { + return castBool(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to string from empty slice") + case reflect.Interface: + return castBool(reflect.ValueOf(v.Interface())) + case reflect.Map: + return nilValue, fmt.Errorf("failed to cast to string from map") + case reflect.Struct: + return nilValue, fmt.Errorf("failed to cast to string from struct") + case reflect.Ptr: + return castBool(v.Elem()) + } + return nilValue, fmt.Errorf("failed to cast to bool from %s", v.Type().Kind()) +} + +func castFloat(v reflect.Value) (reflect.Value, error) { + switch v.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.ValueOf(float64(v.Int())), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return reflect.ValueOf(float64(v.Uint())), nil + case reflect.String: + f64, err := strconv.ParseFloat(v.String(), 64) + if err != nil { + return nilValue, err + } + return reflect.ValueOf(f64), nil + case reflect.Bool: + if v.Bool() { + return reflect.ValueOf(float64(1)), nil + } + return reflect.ValueOf(float64(0)), nil + case reflect.Float32, reflect.Float64: + return v, nil + case reflect.Array: + if v.Len() > 0 { + return castFloat(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to float64 from empty array") + case reflect.Slice: + if v.Len() > 0 { + return castFloat(v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to float64 from empty slice") + case reflect.Interface: + return castFloat(reflect.ValueOf(v.Interface())) + case reflect.Map: + return nilValue, fmt.Errorf("failed to cast to float64 from map") + case reflect.Struct: + return nilValue, fmt.Errorf("failed to cast to float64 from struct") + case reflect.Ptr: + return castFloat(v.Elem()) + } + return nilValue, fmt.Errorf("failed to cast to float64 from %s", v.Type().Kind()) +} + +func castArray(t reflect.Type, v reflect.Value) (reflect.Value, error) { + kind := v.Type().Kind() + if kind == reflect.Interface { + return castArray(t, reflect.ValueOf(v.Interface())) + } + if kind != reflect.Slice && kind != reflect.Array { + return nilValue, fmt.Errorf("failed to cast to array from %s", kind) + } + if t.Elem() == v.Type().Elem() { + return v, nil + } + if t.Len() != v.Len() { + return nilValue, fmt.Errorf("failed to cast [%d]array from slice of %d length", t.Len(), v.Len()) + } + ret := reflect.New(t).Elem() + for i := 0; i < v.Len(); i++ { + vv, err := castValue(t.Elem(), v.Index(i)) + if err != nil { + return nilValue, err + } + ret.Index(i).Set(vv) + } + return ret, nil +} + +func castSlice(t reflect.Type, v reflect.Value) (reflect.Value, error) { + kind := v.Type().Kind() + if kind == reflect.Interface { + return castSlice(t, reflect.ValueOf(v.Interface())) + } + if kind != reflect.Slice && kind != reflect.Array { + return nilValue, fmt.Errorf("failed to cast to slice from %s", kind) + } + if t.Elem() == v.Type().Elem() { + return v, nil + } + ret := reflect.MakeSlice(t, v.Len(), v.Len()) + for i := 0; i < v.Len(); i++ { + vv, err := castValue(t.Elem(), v.Index(i)) + if err != nil { + return nilValue, err + } + ret.Index(i).Set(vv) + } + return ret, nil +} + +func castMap(t reflect.Type, v reflect.Value) (reflect.Value, error) { + ret := reflect.MakeMap(t) + switch v.Type().Kind() { + case reflect.Map: + iter := v.MapRange() + for iter.Next() { + key, err := castValue(t.Key(), iter.Key()) + if err != nil { + return nilValue, err + } + value, err := castValue(t.Elem(), iter.Value()) + if err != nil { + return nilValue, err + } + ret.SetMapIndex(key, value) + } + return ret, nil + case reflect.Interface: + return castMap(t, reflect.ValueOf(v.Interface())) + case reflect.Slice: + if v.Len() > 0 { + return castMap(t, v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to map from empty slice") + } + return nilValue, fmt.Errorf("failed to cast to map from %s", v.Type().Kind()) +} + +func castStruct(t reflect.Type, v reflect.Value) (reflect.Value, error) { + ret := reflect.New(t).Elem() + switch v.Type().Kind() { + case reflect.Map: + iter := v.MapRange() + for iter.Next() { + key := iter.Key() + k, err := castString(key) + if err != nil { + return nilValue, err + } + fieldName := k.String() + field, ok := t.FieldByName(fieldName) + if ok { + value, err := castValue(field.Type, iter.Value()) + if err != nil { + return nilValue, err + } + ret.FieldByName(fieldName).Set(value) + } + } + return ret, nil + case reflect.Struct: + for i := 0; i < v.Type().NumField(); i++ { + name := v.Type().Field(i).Name + ret.FieldByName(name).Set(v.FieldByName(name)) + } + return ret, nil + case reflect.Interface: + return castStruct(t, reflect.ValueOf(v.Interface())) + case reflect.Slice: + if v.Len() > 0 { + return castStruct(t, v.Index(0)) + } + return nilValue, fmt.Errorf("failed to cast to struct from empty slice") + default: + return nilValue, fmt.Errorf("failed to cast to struct from %s", v.Type().Kind()) + } +} diff --git a/internal/decoder/bool.go b/internal/decoder/bool.go index 455042a..ba6cf5b 100644 --- a/internal/decoder/bool.go +++ b/internal/decoder/bool.go @@ -1,6 +1,7 @@ package decoder import ( + "fmt" "unsafe" "github.com/goccy/go-json/internal/errors" @@ -76,3 +77,7 @@ func (d *boolDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe. } return 0, errors.ErrUnexpectedEndOfJSON("bool", cursor) } + +func (d *boolDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: bool decoder does not support decode path") +} diff --git a/internal/decoder/bytes.go b/internal/decoder/bytes.go index 92c7dcf..939bf43 100644 --- a/internal/decoder/bytes.go +++ b/internal/decoder/bytes.go @@ -2,6 +2,7 @@ package decoder import ( "encoding/base64" + "fmt" "unsafe" "github.com/goccy/go-json/internal/errors" @@ -78,6 +79,10 @@ func (d *bytesDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe return cursor, nil } +func (d *bytesDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: []byte decoder does not support decode path") +} + func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) { c := s.skipWhiteSpace() if c == '[' { diff --git a/internal/decoder/float.go b/internal/decoder/float.go index dfb7168..9b2eb8b 100644 --- a/internal/decoder/float.go +++ b/internal/decoder/float.go @@ -156,3 +156,15 @@ func (d *floatDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe d.op(p, f64) return cursor, nil } + +func (d *floatDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + buf := ctx.Buf + bytes, c, err := d.decodeByte(buf, cursor) + if err != nil { + return nil, 0, err + } + if bytes == nil { + return [][]byte{nullbytes}, c, nil + } + return [][]byte{bytes}, c, nil +} diff --git a/internal/decoder/func.go b/internal/decoder/func.go index ee35637..4cc12ca 100644 --- a/internal/decoder/func.go +++ b/internal/decoder/func.go @@ -2,6 +2,7 @@ package decoder import ( "bytes" + "fmt" "unsafe" "github.com/goccy/go-json/internal/errors" @@ -139,3 +140,7 @@ func (d *funcDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe. } return cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor) } + +func (d *funcDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: func decoder does not support decode path") +} diff --git a/internal/decoder/int.go b/internal/decoder/int.go index 509b753..1a7f081 100644 --- a/internal/decoder/int.go +++ b/internal/decoder/int.go @@ -240,3 +240,7 @@ func (d *intDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P d.op(p, i64) return cursor, nil } + +func (d *intDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: int decoder does not support decode path") +} diff --git a/internal/decoder/interface.go b/internal/decoder/interface.go index 4dbb4be..45c69ab 100644 --- a/internal/decoder/interface.go +++ b/internal/decoder/interface.go @@ -94,6 +94,7 @@ func (d *interfaceDecoder) numDecoder(s *Stream) Decoder { var ( emptyInterfaceType = runtime.Type2RType(reflect.TypeOf((*interface{})(nil)).Elem()) + EmptyInterfaceType = emptyInterfaceType interfaceMapType = runtime.Type2RType( reflect.TypeOf((*map[string]interface{})(nil)).Elem(), ) @@ -456,3 +457,72 @@ func (d *interfaceDecoder) decodeEmptyInterface(ctx *RuntimeContext, cursor, dep } return cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor) } + +func NewPathDecoder() Decoder { + ifaceDecoder := &interfaceDecoder{ + typ: emptyInterfaceType, + structName: "", + fieldName: "", + floatDecoder: newFloatDecoder("", "", func(p unsafe.Pointer, v float64) { + *(*interface{})(p) = v + }), + numberDecoder: newNumberDecoder("", "", func(p unsafe.Pointer, v json.Number) { + *(*interface{})(p) = v + }), + stringDecoder: newStringDecoder("", ""), + } + ifaceDecoder.sliceDecoder = newSliceDecoder( + ifaceDecoder, + emptyInterfaceType, + emptyInterfaceType.Size(), + "", "", + ) + ifaceDecoder.mapDecoder = newMapDecoder( + interfaceMapType, + stringType, + ifaceDecoder.stringDecoder, + interfaceMapType.Elem(), + ifaceDecoder, + "", "", + ) + return ifaceDecoder +} + +var ( + truebytes = []byte("true") + falsebytes = []byte("false") +) + +func (d *interfaceDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + buf := ctx.Buf + cursor = skipWhiteSpace(buf, cursor) + switch buf[cursor] { + case '{': + return d.mapDecoder.DecodePath(ctx, cursor, depth) + case '[': + return d.sliceDecoder.DecodePath(ctx, cursor, depth) + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return d.floatDecoder.DecodePath(ctx, cursor, depth) + case '"': + return d.stringDecoder.DecodePath(ctx, cursor, depth) + case 't': + if err := validateTrue(buf, cursor); err != nil { + return nil, 0, err + } + cursor += 4 + return [][]byte{truebytes}, cursor, nil + case 'f': + if err := validateFalse(buf, cursor); err != nil { + return nil, 0, err + } + cursor += 5 + return [][]byte{falsebytes}, cursor, nil + case 'n': + if err := validateNull(buf, cursor); err != nil { + return nil, 0, err + } + cursor += 4 + return [][]byte{nullbytes}, cursor, nil + } + return nil, cursor, errors.ErrInvalidBeginningOfValue(buf[cursor], cursor) +} diff --git a/internal/decoder/invalid.go b/internal/decoder/invalid.go index 1ef50a7..4c9721b 100644 --- a/internal/decoder/invalid.go +++ b/internal/decoder/invalid.go @@ -43,3 +43,13 @@ func (d *invalidDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsa Field: d.fieldName, } } + +func (d *invalidDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, &errors.UnmarshalTypeError{ + Value: "object", + Type: runtime.RType2Type(d.typ), + Offset: cursor, + Struct: d.structName, + Field: d.fieldName, + } +} diff --git a/internal/decoder/map.go b/internal/decoder/map.go index cb55ef0..4ee8e5c 100644 --- a/internal/decoder/map.go +++ b/internal/decoder/map.go @@ -185,3 +185,95 @@ func (d *mapDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P cursor++ } } + +func (d *mapDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + buf := ctx.Buf + depth++ + if depth > maxDecodeNestingDepth { + return nil, 0, errors.ErrExceededMaxDepth(buf[cursor], cursor) + } + + cursor = skipWhiteSpace(buf, cursor) + buflen := int64(len(buf)) + if buflen < 2 { + return nil, 0, errors.ErrExpected("{} for map", cursor) + } + switch buf[cursor] { + case 'n': + if err := validateNull(buf, cursor); err != nil { + return nil, 0, err + } + cursor += 4 + return [][]byte{nullbytes}, cursor, nil + case '{': + default: + return nil, 0, errors.ErrExpected("{ character for map value", cursor) + } + cursor++ + cursor = skipWhiteSpace(buf, cursor) + if buf[cursor] == '}' { + cursor++ + return nil, cursor, nil + } + keyDecoder, ok := d.keyDecoder.(*stringDecoder) + if !ok { + return nil, 0, &errors.UnmarshalTypeError{ + Value: "string", + Type: reflect.TypeOf(""), + Offset: cursor, + Struct: d.structName, + Field: d.fieldName, + } + } + ret := [][]byte{} + for { + key, keyCursor, err := keyDecoder.decodeByte(buf, cursor) + if err != nil { + return nil, 0, err + } + cursor = skipWhiteSpace(buf, keyCursor) + if buf[cursor] != ':' { + return nil, 0, errors.ErrExpected("colon after object key", cursor) + } + cursor++ + child, found, err := ctx.Option.Path.Field(string(key)) + if err != nil { + return nil, 0, err + } + if found { + if child != nil { + oldPath := ctx.Option.Path.node + ctx.Option.Path.node = child + paths, c, err := d.valueDecoder.DecodePath(ctx, cursor, depth) + if err != nil { + return nil, 0, err + } + ctx.Option.Path.node = oldPath + ret = append(ret, paths...) + cursor = c + } else { + start := cursor + end, err := skipValue(buf, cursor, depth) + if err != nil { + return nil, 0, err + } + ret = append(ret, buf[start:end]) + cursor = end + } + } else { + c, err := skipValue(buf, cursor, depth) + if err != nil { + return nil, 0, err + } + cursor = skipWhiteSpace(buf, c) + } + if buf[cursor] == '}' { + cursor++ + return ret, cursor, nil + } + if buf[cursor] != ',' { + return nil, 0, errors.ErrExpected("comma after object value", cursor) + } + cursor++ + } +} diff --git a/internal/decoder/number.go b/internal/decoder/number.go index bf63773..10e5435 100644 --- a/internal/decoder/number.go +++ b/internal/decoder/number.go @@ -51,6 +51,17 @@ func (d *numberDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf return cursor, nil } +func (d *numberDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + bytes, c, err := d.decodeByte(ctx.Buf, cursor) + if err != nil { + return nil, 0, err + } + if bytes == nil { + return [][]byte{nullbytes}, c, nil + } + return [][]byte{bytes}, c, nil +} + func (d *numberDecoder) decodeStreamByte(s *Stream) ([]byte, error) { start := s.cursor for { diff --git a/internal/decoder/option.go b/internal/decoder/option.go index e41f876..502f772 100644 --- a/internal/decoder/option.go +++ b/internal/decoder/option.go @@ -7,9 +7,11 @@ type OptionFlags uint8 const ( FirstWinOption OptionFlags = 1 << iota ContextOption + PathOption ) type Option struct { Flags OptionFlags Context context.Context + Path *Path } diff --git a/internal/decoder/path.go b/internal/decoder/path.go new file mode 100644 index 0000000..33ddcd5 --- /dev/null +++ b/internal/decoder/path.go @@ -0,0 +1,653 @@ +package decoder + +import ( + "fmt" + "reflect" + "strconv" + + "github.com/goccy/go-json/internal/errors" + "github.com/goccy/go-json/internal/runtime" +) + +type PathString string + +func (s PathString) Build() (*Path, error) { + builder := new(PathBuilder) + return builder.Build([]rune(s)) +} + +type PathBuilder struct { + root PathNode + node PathNode + singleQuotePathSelector bool + doubleQuotePathSelector bool +} + +func (b *PathBuilder) Build(buf []rune) (*Path, error) { + node, err := b.build(buf) + if err != nil { + return nil, err + } + return &Path{ + node: node, + SingleQuotePathSelector: b.singleQuotePathSelector, + DoubleQuotePathSelector: b.doubleQuotePathSelector, + }, nil +} + +func (b *PathBuilder) build(buf []rune) (PathNode, error) { + if len(buf) == 0 { + return nil, errors.ErrEmptyPath() + } + if buf[0] != '$' { + return nil, errors.ErrInvalidPath("JSON Path must start with a $ character") + } + if len(buf) == 1 { + return nil, nil + } + buf = buf[1:] + offset, err := b.buildNext(buf) + if err != nil { + return nil, err + } + if len(buf) > offset { + return nil, errors.ErrInvalidPath("remain invalid path %q", buf[offset:]) + } + return b.root, nil +} + +func (b *PathBuilder) buildNextCharIfExists(buf []rune, cursor int) (int, error) { + if len(buf) > cursor+1 { + offset, err := b.buildNext(buf[cursor+1:]) + if err != nil { + return 0, err + } + return cursor + 1 + offset, nil + } + return cursor, nil +} + +func (b *PathBuilder) buildNext(buf []rune) (int, error) { + switch buf[0] { + case '.': + if len(buf) == 1 { + return 0, errors.ErrInvalidPath("JSON Path ends with dot character") + } + offset, err := b.buildSelector(buf[1:]) + if err != nil { + return 0, err + } + return offset + 1, nil + case '[': + if len(buf) == 1 { + return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character") + } + offset, err := b.buildIndex(buf[1:]) + if err != nil { + return 0, err + } + return offset + 1, nil + default: + return 0, errors.ErrInvalidPath("expect dot or left bracket character. but found %c character", buf[0]) + } +} + +func (b *PathBuilder) buildSelector(buf []rune) (int, error) { + switch buf[0] { + case '.': + if len(buf) == 1 { + return 0, errors.ErrInvalidPath("JSON Path ends with double dot character") + } + offset, err := b.buildPathRecursive(buf[1:]) + if err != nil { + return 0, err + } + return 1 + offset, nil + case '[', ']', '$', '*': + return 0, errors.ErrInvalidPath("found invalid path character %c after dot", buf[0]) + } + for cursor := 0; cursor < len(buf); cursor++ { + switch buf[cursor] { + case '$', '*', ']': + return 0, errors.ErrInvalidPath("found %c character in field selector context", buf[cursor]) + case '.': + if cursor+1 >= len(buf) { + return 0, errors.ErrInvalidPath("JSON Path ends with dot character") + } + selector := buf[:cursor] + b.addSelectorNode(string(selector)) + offset, err := b.buildSelector(buf[cursor+1:]) + if err != nil { + return 0, err + } + return cursor + 1 + offset, nil + case '[': + if cursor+1 >= len(buf) { + return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character") + } + selector := buf[:cursor] + b.addSelectorNode(string(selector)) + offset, err := b.buildIndex(buf[cursor+1:]) + if err != nil { + return 0, err + } + return cursor + 1 + offset, nil + case '"': + if cursor+1 >= len(buf) { + return 0, errors.ErrInvalidPath("JSON Path ends with double quote character") + } + offset, err := b.buildQuoteSelector(buf[cursor+1:], DoubleQuotePathSelector) + if err != nil { + return 0, err + } + return cursor + 1 + offset, nil + } + } + b.addSelectorNode(string(buf)) + return len(buf), nil +} + +func (b *PathBuilder) buildQuoteSelector(buf []rune, sel QuotePathSelector) (int, error) { + switch buf[0] { + case '[', ']', '$', '.', '*', '\'', '"': + return 0, errors.ErrInvalidPath("found invalid path character %c after quote", buf[0]) + } + for cursor := 0; cursor < len(buf); cursor++ { + switch buf[cursor] { + case '\'': + if sel != SingleQuotePathSelector { + return 0, errors.ErrInvalidPath("found double quote character in field selector with single quote context") + } + if len(buf) <= cursor+1 { + return 0, errors.ErrInvalidPath("JSON Path ends with single quote character in field selector context") + } + if buf[cursor+1] != ']' { + return 0, errors.ErrInvalidPath("expect right bracket for field selector with single quote but found %c", buf[cursor+1]) + } + selector := buf[:cursor] + b.addSelectorNode(string(selector)) + return b.buildNextCharIfExists(buf, cursor+1) + case '"': + if sel != DoubleQuotePathSelector { + return 0, errors.ErrInvalidPath("found single quote character in field selector with double quote context") + } + selector := buf[:cursor] + b.addSelectorNode(string(selector)) + return b.buildNextCharIfExists(buf, cursor) + } + } + return 0, errors.ErrInvalidPath("couldn't find quote character in selector quote path context") +} + +func (b *PathBuilder) buildPathRecursive(buf []rune) (int, error) { + switch buf[0] { + case '.', '[', ']', '$', '*': + return 0, errors.ErrInvalidPath("found invalid path character %c after double dot", buf[0]) + } + for cursor := 0; cursor < len(buf); cursor++ { + switch buf[cursor] { + case '$', '*', ']': + return 0, errors.ErrInvalidPath("found %c character in field selector context", buf[cursor]) + case '.': + if cursor+1 >= len(buf) { + return 0, errors.ErrInvalidPath("JSON Path ends with dot character") + } + selector := buf[:cursor] + b.addRecursiveNode(string(selector)) + offset, err := b.buildSelector(buf[cursor+1:]) + if err != nil { + return 0, err + } + return cursor + 1 + offset, nil + case '[': + if cursor+1 >= len(buf) { + return 0, errors.ErrInvalidPath("JSON Path ends with left bracket character") + } + selector := buf[:cursor] + b.addRecursiveNode(string(selector)) + offset, err := b.buildIndex(buf[cursor+1:]) + if err != nil { + return 0, err + } + return cursor + 1 + offset, nil + } + } + b.addRecursiveNode(string(buf)) + return len(buf), nil +} + +func (b *PathBuilder) buildIndex(buf []rune) (int, error) { + switch buf[0] { + case '.', '[', ']', '$': + return 0, errors.ErrInvalidPath("found invalid path character %c after left bracket", buf[0]) + case '\'': + if len(buf) == 1 { + return 0, errors.ErrInvalidPath("JSON Path ends with single quote character") + } + offset, err := b.buildQuoteSelector(buf[1:], SingleQuotePathSelector) + if err != nil { + return 0, err + } + return 1 + offset, nil + case '*': + if len(buf) == 1 { + return 0, errors.ErrInvalidPath("JSON Path ends with star character") + } + if buf[1] != ']' { + return 0, errors.ErrInvalidPath("expect right bracket character for index all path but found %c character", buf[1]) + } + b.addIndexAllNode() + offset := len("*]") + if len(buf) > 2 { + buildOffset, err := b.buildNext(buf[2:]) + if err != nil { + return 0, err + } + return offset + buildOffset, nil + } + return offset, nil + } + + for cursor := 0; cursor < len(buf); cursor++ { + switch buf[cursor] { + case ']': + index, err := strconv.ParseInt(string(buf[:cursor]), 10, 64) + if err != nil { + return 0, errors.ErrInvalidPath("%q is unexpected index path", buf[:cursor]) + } + b.addIndexNode(int(index)) + return b.buildNextCharIfExists(buf, cursor) + } + } + return 0, errors.ErrInvalidPath("couldn't find right bracket character in index path context") +} + +func (b *PathBuilder) addIndexAllNode() { + node := newPathIndexAllNode() + if b.root == nil { + b.root = node + b.node = node + } else { + b.node = b.node.chain(node) + } +} + +func (b *PathBuilder) addRecursiveNode(selector string) { + node := newPathRecursiveNode(selector) + if b.root == nil { + b.root = node + b.node = node + } else { + b.node = b.node.chain(node) + } +} + +func (b *PathBuilder) addSelectorNode(name string) { + node := newPathSelectorNode(name) + if b.root == nil { + b.root = node + b.node = node + } else { + b.node = b.node.chain(node) + } +} + +func (b *PathBuilder) addIndexNode(idx int) { + node := newPathIndexNode(idx) + if b.root == nil { + b.root = node + b.node = node + } else { + b.node = b.node.chain(node) + } +} + +type QuotePathSelector int + +const ( + SingleQuotePathSelector QuotePathSelector = 1 + DoubleQuotePathSelector QuotePathSelector = 2 +) + +type Path struct { + node PathNode + SingleQuotePathSelector bool + DoubleQuotePathSelector bool +} + +func (p *Path) Field(sel string) (PathNode, bool, error) { + return p.node.Field(sel) +} + +func (p *Path) Get(src, dst reflect.Value) error { + return p.node.Get(src, dst) +} + +type PathNode interface { + fmt.Stringer + Index(idx int) (PathNode, bool, error) + Field(fieldName string) (PathNode, bool, error) + Get(src, dst reflect.Value) error + chain(PathNode) PathNode + target() bool + single() bool +} + +type BasePathNode struct { + child PathNode +} + +func (n *BasePathNode) chain(node PathNode) PathNode { + n.child = node + return node +} + +func (n *BasePathNode) target() bool { + return n.child == nil +} + +func (n *BasePathNode) single() bool { + return true +} + +type PathSelectorNode struct { + *BasePathNode + selector string +} + +func newPathSelectorNode(selector string) *PathSelectorNode { + return &PathSelectorNode{ + BasePathNode: &BasePathNode{}, + selector: selector, + } +} + +func (n *PathSelectorNode) Index(idx int) (PathNode, bool, error) { + return nil, false, &errors.PathError{} +} + +func (n *PathSelectorNode) Field(fieldName string) (PathNode, bool, error) { + if n.selector == fieldName { + return n.child, true, nil + } + return nil, false, nil +} + +func (n *PathSelectorNode) Get(src, dst reflect.Value) error { + switch src.Type().Kind() { + case reflect.Map: + iter := src.MapRange() + for iter.Next() { + key, ok := iter.Key().Interface().(string) + if !ok { + return fmt.Errorf("invalid map key type %T", src.Type().Key()) + } + child, found, err := n.Field(key) + if err != nil { + return err + } + if found { + if child != nil { + return child.Get(iter.Value(), dst) + } + return AssignValue(iter.Value(), dst) + } + } + case reflect.Struct: + typ := src.Type() + for i := 0; i < typ.Len(); i++ { + tag := runtime.StructTagFromField(typ.Field(i)) + child, found, err := n.Field(tag.Key) + if err != nil { + return err + } + if found { + if child != nil { + return child.Get(src.Field(i), dst) + } + return AssignValue(src.Field(i), dst) + } + } + case reflect.Ptr: + return n.Get(src.Elem(), dst) + case reflect.Interface: + return n.Get(reflect.ValueOf(src.Interface()), dst) + case reflect.Float64, reflect.String, reflect.Bool: + return AssignValue(src, dst) + } + return fmt.Errorf("failed to get %s value from %s", n.selector, src.Type()) +} + +func (n *PathSelectorNode) String() string { + s := fmt.Sprintf(".%s", n.selector) + if n.child != nil { + s += n.child.String() + } + return s +} + +type PathIndexNode struct { + *BasePathNode + selector int +} + +func newPathIndexNode(selector int) *PathIndexNode { + return &PathIndexNode{ + BasePathNode: &BasePathNode{}, + selector: selector, + } +} + +func (n *PathIndexNode) Index(idx int) (PathNode, bool, error) { + if n.selector == idx { + return n.child, true, nil + } + return nil, false, nil +} + +func (n *PathIndexNode) Field(fieldName string) (PathNode, bool, error) { + return nil, false, &errors.PathError{} +} + +func (n *PathIndexNode) Get(src, dst reflect.Value) error { + switch src.Type().Kind() { + case reflect.Array, reflect.Slice: + if src.Len() > n.selector { + if n.child != nil { + return n.child.Get(src.Index(n.selector), dst) + } + return AssignValue(src.Index(n.selector), dst) + } + case reflect.Ptr: + return n.Get(src.Elem(), dst) + case reflect.Interface: + return n.Get(reflect.ValueOf(src.Interface()), dst) + } + return fmt.Errorf("failed to get [%d] value from %s", n.selector, src.Type()) +} + +func (n *PathIndexNode) String() string { + s := fmt.Sprintf("[%d]", n.selector) + if n.child != nil { + s += n.child.String() + } + return s +} + +type PathIndexAllNode struct { + *BasePathNode +} + +func newPathIndexAllNode() *PathIndexAllNode { + return &PathIndexAllNode{ + BasePathNode: &BasePathNode{}, + } +} + +func (n *PathIndexAllNode) Index(idx int) (PathNode, bool, error) { + return n.child, true, nil +} + +func (n *PathIndexAllNode) Field(fieldName string) (PathNode, bool, error) { + return nil, false, &errors.PathError{} +} + +func (n *PathIndexAllNode) Get(src, dst reflect.Value) error { + switch src.Type().Kind() { + case reflect.Array, reflect.Slice: + var arr []interface{} + for i := 0; i < src.Len(); i++ { + var v interface{} + rv := reflect.ValueOf(&v) + if n.child != nil { + if err := n.child.Get(src.Index(i), rv); err != nil { + return err + } + } else { + if err := AssignValue(src.Index(i), rv); err != nil { + return err + } + } + arr = append(arr, v) + } + if err := AssignValue(reflect.ValueOf(arr), dst); err != nil { + return err + } + return nil + case reflect.Ptr: + return n.Get(src.Elem(), dst) + case reflect.Interface: + return n.Get(reflect.ValueOf(src.Interface()), dst) + } + return fmt.Errorf("failed to get all value from %s", src.Type()) +} + +func (n *PathIndexAllNode) String() string { + s := "[*]" + if n.child != nil { + s += n.child.String() + } + return s +} + +type PathRecursiveNode struct { + *BasePathNode + selector string +} + +func newPathRecursiveNode(selector string) *PathRecursiveNode { + node := newPathSelectorNode(selector) + return &PathRecursiveNode{ + BasePathNode: &BasePathNode{ + child: node, + }, + selector: selector, + } +} + +func (n *PathRecursiveNode) Field(fieldName string) (PathNode, bool, error) { + if n.selector == fieldName { + return n.child, true, nil + } + return nil, false, nil +} + +func (n *PathRecursiveNode) Index(_ int) (PathNode, bool, error) { + return n, true, nil +} + +func valueToSliceValue(v interface{}) []interface{} { + rv := reflect.ValueOf(v) + ret := []interface{}{} + if rv.Type().Kind() == reflect.Slice || rv.Type().Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + ret = append(ret, rv.Index(i).Interface()) + } + return ret + } + return []interface{}{v} +} + +func (n *PathRecursiveNode) Get(src, dst reflect.Value) error { + if n.child == nil { + return fmt.Errorf("failed to get by recursive path ..%s", n.selector) + } + var arr []interface{} + switch src.Type().Kind() { + case reflect.Map: + iter := src.MapRange() + for iter.Next() { + key, ok := iter.Key().Interface().(string) + if !ok { + return fmt.Errorf("invalid map key type %T", src.Type().Key()) + } + child, found, err := n.Field(key) + if err != nil { + return err + } + if found { + var v interface{} + rv := reflect.ValueOf(&v) + _ = child.Get(iter.Value(), rv) + arr = append(arr, valueToSliceValue(v)...) + } else { + var v interface{} + rv := reflect.ValueOf(&v) + _ = n.Get(iter.Value(), rv) + if v != nil { + arr = append(arr, valueToSliceValue(v)...) + } + } + } + _ = AssignValue(reflect.ValueOf(arr), dst) + return nil + case reflect.Struct: + typ := src.Type() + for i := 0; i < typ.Len(); i++ { + tag := runtime.StructTagFromField(typ.Field(i)) + child, found, err := n.Field(tag.Key) + if err != nil { + return err + } + if found { + var v interface{} + rv := reflect.ValueOf(&v) + _ = child.Get(src.Field(i), rv) + arr = append(arr, valueToSliceValue(v)...) + } else { + var v interface{} + rv := reflect.ValueOf(&v) + _ = n.Get(src.Field(i), rv) + if v != nil { + arr = append(arr, valueToSliceValue(v)...) + } + } + } + _ = AssignValue(reflect.ValueOf(arr), dst) + return nil + case reflect.Array, reflect.Slice: + for i := 0; i < src.Len(); i++ { + var v interface{} + rv := reflect.ValueOf(&v) + _ = n.Get(src.Index(i), rv) + if v != nil { + arr = append(arr, valueToSliceValue(v)...) + } + } + _ = AssignValue(reflect.ValueOf(arr), dst) + return nil + case reflect.Ptr: + return n.Get(src.Elem(), dst) + case reflect.Interface: + return n.Get(reflect.ValueOf(src.Interface()), dst) + } + return fmt.Errorf("failed to get %s value from %s", n.selector, src.Type()) +} + +func (n *PathRecursiveNode) String() string { + s := fmt.Sprintf("..%s", n.selector) + if n.child != nil { + s += n.child.String() + } + return s +} diff --git a/internal/decoder/ptr.go b/internal/decoder/ptr.go index 2c83b9c..de12e10 100644 --- a/internal/decoder/ptr.go +++ b/internal/decoder/ptr.go @@ -1,6 +1,7 @@ package decoder import ( + "fmt" "unsafe" "github.com/goccy/go-json/internal/runtime" @@ -34,6 +35,10 @@ func (d *ptrDecoder) contentDecoder() Decoder { //go:linkname unsafe_New reflect.unsafe_New func unsafe_New(*runtime.Type) unsafe.Pointer +func UnsafeNew(t *runtime.Type) unsafe.Pointer { + return unsafe_New(t) +} + func (d *ptrDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error { if s.skipWhiteSpace() == nul { s.read() @@ -85,3 +90,7 @@ func (d *ptrDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P cursor = c return cursor, nil } + +func (d *ptrDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: ptr decoder does not support decode path") +} diff --git a/internal/decoder/slice.go b/internal/decoder/slice.go index 85b6e11..ec4dcd3 100644 --- a/internal/decoder/slice.go +++ b/internal/decoder/slice.go @@ -299,3 +299,80 @@ func (d *sliceDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe } } } + +func (d *sliceDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + buf := ctx.Buf + depth++ + if depth > maxDecodeNestingDepth { + return nil, 0, errors.ErrExceededMaxDepth(buf[cursor], cursor) + } + + ret := [][]byte{} + for { + switch buf[cursor] { + case ' ', '\n', '\t', '\r': + cursor++ + continue + case 'n': + if err := validateNull(buf, cursor); err != nil { + return nil, 0, err + } + cursor += 4 + return [][]byte{nullbytes}, cursor, nil + case '[': + cursor++ + cursor = skipWhiteSpace(buf, cursor) + if buf[cursor] == ']' { + cursor++ + return ret, cursor, nil + } + idx := 0 + for { + child, found, err := ctx.Option.Path.node.Index(idx) + if err != nil { + return nil, 0, err + } + if found { + if child != nil { + oldPath := ctx.Option.Path.node + ctx.Option.Path.node = child + paths, c, err := d.valueDecoder.DecodePath(ctx, cursor, depth) + if err != nil { + return nil, 0, err + } + ctx.Option.Path.node = oldPath + ret = append(ret, paths...) + cursor = c + } else { + start := cursor + end, err := skipValue(buf, cursor, depth) + if err != nil { + return nil, 0, err + } + ret = append(ret, buf[start:end]) + } + } else { + c, err := skipValue(buf, cursor, depth) + if err != nil { + return nil, 0, err + } + cursor = skipWhiteSpace(buf, c) + } + switch buf[cursor] { + case ']': + cursor++ + return ret, cursor, nil + case ',': + idx++ + default: + return nil, 0, errors.ErrInvalidCharacter(buf[cursor], "slice", cursor) + } + cursor++ + } + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return nil, 0, d.errNumber(cursor) + default: + return nil, 0, errors.ErrUnexpectedEndOfJSON("slice", cursor) + } + } +} diff --git a/internal/decoder/string.go b/internal/decoder/string.go index d07ad71..32602c9 100644 --- a/internal/decoder/string.go +++ b/internal/decoder/string.go @@ -60,6 +60,17 @@ func (d *stringDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf return cursor, nil } +func (d *stringDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + bytes, c, err := d.decodeByte(ctx.Buf, cursor) + if err != nil { + return nil, 0, err + } + if bytes == nil { + return [][]byte{nullbytes}, c, nil + } + return [][]byte{bytes}, c, nil +} + var ( hexToInt = [256]int{ '0': 0, diff --git a/internal/decoder/struct.go b/internal/decoder/struct.go index 2c64680..6d32654 100644 --- a/internal/decoder/struct.go +++ b/internal/decoder/struct.go @@ -817,3 +817,7 @@ func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf cursor++ } } + +func (d *structDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: struct decoder does not support decode path") +} diff --git a/internal/decoder/type.go b/internal/decoder/type.go index 70e9907..beaf3ab 100644 --- a/internal/decoder/type.go +++ b/internal/decoder/type.go @@ -10,6 +10,7 @@ import ( type Decoder interface { Decode(*RuntimeContext, int64, int64, unsafe.Pointer) (int64, error) + DecodePath(*RuntimeContext, int64, int64) ([][]byte, int64, error) DecodeStream(*Stream, int64, unsafe.Pointer) error } diff --git a/internal/decoder/uint.go b/internal/decoder/uint.go index a62c514..4131731 100644 --- a/internal/decoder/uint.go +++ b/internal/decoder/uint.go @@ -188,3 +188,7 @@ func (d *uintDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe. d.op(p, u64) return cursor, nil } + +func (d *uintDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: uint decoder does not support decode path") +} diff --git a/internal/decoder/unmarshal_json.go b/internal/decoder/unmarshal_json.go index e9b25c6..4cd6dbd 100644 --- a/internal/decoder/unmarshal_json.go +++ b/internal/decoder/unmarshal_json.go @@ -3,6 +3,7 @@ package decoder import ( "context" "encoding/json" + "fmt" "unsafe" "github.com/goccy/go-json/internal/errors" @@ -97,3 +98,7 @@ func (d *unmarshalJSONDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, } return end, nil } + +func (d *unmarshalJSONDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: unmarshal json decoder does not support decode path") +} diff --git a/internal/decoder/unmarshal_text.go b/internal/decoder/unmarshal_text.go index 1ef2877..6d37993 100644 --- a/internal/decoder/unmarshal_text.go +++ b/internal/decoder/unmarshal_text.go @@ -3,6 +3,7 @@ package decoder import ( "bytes" "encoding" + "fmt" "unicode" "unicode/utf16" "unicode/utf8" @@ -142,6 +143,10 @@ func (d *unmarshalTextDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, return end, nil } +func (d *unmarshalTextDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: unmarshal text decoder does not support decode path") +} + func unquoteBytes(s []byte) (t []byte, ok bool) { length := len(s) if length < 2 || s[0] != '"' || s[length-1] != '"' { diff --git a/internal/decoder/wrapped_string.go b/internal/decoder/wrapped_string.go index 66227ae..0c4e2e6 100644 --- a/internal/decoder/wrapped_string.go +++ b/internal/decoder/wrapped_string.go @@ -1,6 +1,7 @@ package decoder import ( + "fmt" "reflect" "unsafe" @@ -66,3 +67,7 @@ func (d *wrappedStringDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, ctx.Buf = oldBuf return c, nil } + +func (d *wrappedStringDecoder) DecodePath(ctx *RuntimeContext, cursor, depth int64) ([][]byte, int64, error) { + return nil, 0, fmt.Errorf("json: wrapped string decoder does not support decode path") +} diff --git a/internal/errors/error.go b/internal/errors/error.go index d58e39f..9207d0f 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -162,3 +162,22 @@ func ErrInvalidBeginningOfValue(c byte, cursor int64) *SyntaxError { Offset: cursor, } } + +type PathError struct { + msg string +} + +func (e *PathError) Error() string { + return fmt.Sprintf("json: invalid path format: %s", e.msg) +} + +func ErrInvalidPath(msg string, args ...interface{}) *PathError { + if len(args) != 0 { + return &PathError{msg: fmt.Sprintf(msg, args...)} + } + return &PathError{msg: msg} +} + +func ErrEmptyPath() *PathError { + return &PathError{msg: "path is empty"} +} diff --git a/path.go b/path.go new file mode 100644 index 0000000..0b9c68c --- /dev/null +++ b/path.go @@ -0,0 +1,73 @@ +package json + +import ( + "reflect" + + "github.com/goccy/go-json/internal/decoder" +) + +// CreatePath creates JSON Path. +// +// JSON Path rule +// $ : root object or element. The JSON Path format must start with this operator, which refers to the outermost level of the JSON-formatted string. +// . : child operator. You can identify child values using dot-notation. +// .. : recursive descent. +// [] : subscript operator. If the JSON object is an array, you can use brackets to specify the array index. +// [*] : all objects/elements for array. +// +// Reserved words must be properly escaped when included in Path. +// Escale Rule +// single quote style escape: e.g.) `$['a.b'].c` +// double quote style escape: e.g.) `$."a.b".c` +func CreatePath(p string) (*Path, error) { + path, err := decoder.PathString(p).Build() + if err != nil { + return nil, err + } + return &Path{path: path}, nil +} + +// Path represents JSON Path. +type Path struct { + path *decoder.Path +} + +// UsedSingleQuotePathSelector whether single quote-based escaping was done when building the JSON Path. +func (p *Path) UsedSingleQuotePathSelector() bool { + return p.path.SingleQuotePathSelector +} + +// UsedSingleQuotePathSelector whether double quote-based escaping was done when building the JSON Path. +func (p *Path) UsedDoubleQuotePathSelector() bool { + return p.path.DoubleQuotePathSelector +} + +// Extract extracts a specific JSON string. +func (p *Path) Extract(data []byte, optFuncs ...DecodeOptionFunc) ([][]byte, error) { + return extractFromPath(p, data, optFuncs...) +} + +// Unmarshal extract and decode the value of the part corresponding to JSON Path from the input data. +func (p *Path) Unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error { + contents, err := extractFromPath(p, data, optFuncs...) + if err != nil { + return err + } + results := make([]interface{}, 0, len(contents)) + for _, content := range contents { + var result interface{} + if err := Unmarshal(content, &result); err != nil { + return err + } + results = append(results, result) + } + if err := decoder.AssignValue(reflect.ValueOf(results), reflect.ValueOf(v)); err != nil { + return err + } + return nil +} + +// Get extract and substitute the value of the part corresponding to JSON Path from the input value. +func (p *Path) Get(src, dst interface{}) error { + return p.path.Get(reflect.ValueOf(src), reflect.ValueOf(dst)) +} diff --git a/path_test.go b/path_test.go new file mode 100644 index 0000000..4891eb8 --- /dev/null +++ b/path_test.go @@ -0,0 +1,234 @@ +package json_test + +import ( + "bytes" + "reflect" + "sort" + "testing" + + "github.com/goccy/go-json" +) + +func TestExtractPath(t *testing.T) { + src := []byte(`{"a":{"b":10,"c":true},"b":"text"}`) + t.Run("$.a.b", func(t *testing.T) { + path, err := json.CreatePath("$.a.b") + if err != nil { + t.Fatal(err) + } + contents, err := path.Extract(src) + if err != nil { + t.Fatal(err) + } + if len(contents) != 1 { + t.Fatal("failed to extract") + } + if !bytes.Equal(contents[0], []byte("10")) { + t.Fatal("failed to extract") + } + }) + t.Run("$.b", func(t *testing.T) { + path, err := json.CreatePath("$.b") + if err != nil { + t.Fatal(err) + } + contents, err := path.Extract(src) + if err != nil { + t.Fatal(err) + } + if len(contents) != 1 { + t.Fatal("failed to extract") + } + if !bytes.Equal(contents[0], []byte(`"text"`)) { + t.Fatal("failed to extract") + } + }) + t.Run("$.a", func(t *testing.T) { + path, err := json.CreatePath("$.a") + if err != nil { + t.Fatal(err) + } + contents, err := path.Extract(src) + if err != nil { + t.Fatal(err) + } + if len(contents) != 1 { + t.Fatal("failed to extract") + } + if !bytes.Equal(contents[0], []byte(`{"b":10,"c":true}`)) { + t.Fatal("failed to extract") + } + }) +} + +func TestUnmarshalPath(t *testing.T) { + t.Run("int", func(t *testing.T) { + src := []byte(`{"a":{"b":10,"c":true},"b":"text"}`) + t.Run("success", func(t *testing.T) { + path, err := json.CreatePath("$.a.b") + if err != nil { + t.Fatal(err) + } + var v int + if err := path.Unmarshal(src, &v); err != nil { + t.Fatal(err) + } + if v != 10 { + t.Fatal("failed to unmarshal path") + } + }) + t.Run("failure", func(t *testing.T) { + path, err := json.CreatePath("$.a.c") + if err != nil { + t.Fatal(err) + } + var v map[string]interface{} + if err := path.Unmarshal(src, &v); err == nil { + t.Fatal("expected error") + } + }) + }) + t.Run("bool", func(t *testing.T) { + src := []byte(`{"a":{"b":10,"c":true},"b":"text"}`) + t.Run("success", func(t *testing.T) { + path, err := json.CreatePath("$.a.c") + if err != nil { + t.Fatal(err) + } + var v bool + if err := path.Unmarshal(src, &v); err != nil { + t.Fatal(err) + } + if !v { + t.Fatal("failed to unmarshal path") + } + }) + t.Run("failure", func(t *testing.T) { + path, err := json.CreatePath("$.a.b") + if err != nil { + t.Fatal(err) + } + var v bool + if err := path.Unmarshal(src, &v); err == nil { + t.Fatal("expected error") + } + }) + }) + t.Run("map", func(t *testing.T) { + src := []byte(`{"a":{"b":10,"c":true},"b":"text"}`) + t.Run("success", func(t *testing.T) { + path, err := json.CreatePath("$.a") + if err != nil { + t.Fatal(err) + } + var v map[string]interface{} + if err := path.Unmarshal(src, &v); err != nil { + t.Fatal(err) + } + if len(v) != 2 { + t.Fatal("failed to decode map") + } + }) + }) + + t.Run("path with single quote selector", func(t *testing.T) { + path, err := json.CreatePath("$['a.b'].c") + if err != nil { + t.Fatal(err) + } + + var v string + if err := path.Unmarshal([]byte(`{"a.b": {"c": "world"}}`), &v); err != nil { + t.Fatal(err) + } + if v != "world" { + t.Fatal("failed to unmarshal path") + } + }) + t.Run("path with double quote selector", func(t *testing.T) { + path, err := json.CreatePath(`$."a.b".c`) + if err != nil { + t.Fatal(err) + } + + var v string + if err := path.Unmarshal([]byte(`{"a.b": {"c": "world"}}`), &v); err != nil { + t.Fatal(err) + } + if v != "world" { + t.Fatal("failed to unmarshal path") + } + }) +} + +func TestGetPath(t *testing.T) { + t.Run("selector", func(t *testing.T) { + var v interface{} + if err := json.Unmarshal([]byte(`{"a":{"b":10,"c":true},"b":"text"}`), &v); err != nil { + t.Fatal(err) + } + path, err := json.CreatePath("$.a.b") + if err != nil { + t.Fatal(err) + } + var b int + if err := path.Get(v, &b); err != nil { + t.Fatal(err) + } + if b != 10 { + t.Fatalf("failed to decode by json.Get") + } + }) + t.Run("index", func(t *testing.T) { + var v interface{} + if err := json.Unmarshal([]byte(`{"a":[{"b":10,"c":true},{"b":"text"}]}`), &v); err != nil { + t.Fatal(err) + } + path, err := json.CreatePath("$.a[0].b") + if err != nil { + t.Fatal(err) + } + var b int + if err := path.Get(v, &b); err != nil { + t.Fatal(err) + } + if b != 10 { + t.Fatalf("failed to decode by json.Get") + } + }) + t.Run("indexAll", func(t *testing.T) { + var v interface{} + if err := json.Unmarshal([]byte(`{"a":[{"b":1,"c":true},{"b":2},{"b":3}]}`), &v); err != nil { + t.Fatal(err) + } + path, err := json.CreatePath("$.a[*].b") + if err != nil { + t.Fatal(err) + } + var b []int + if err := path.Get(v, &b); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(b, []int{1, 2, 3}) { + t.Fatalf("failed to decode by json.Get") + } + }) + t.Run("recursive", func(t *testing.T) { + var v interface{} + if err := json.Unmarshal([]byte(`{"a":[{"b":1,"c":true},{"b":2},{"b":3}],"a2":{"b":4}}`), &v); err != nil { + t.Fatal(err) + } + path, err := json.CreatePath("$..b") + if err != nil { + t.Fatal(err) + } + var b []int + if err := path.Get(v, &b); err != nil { + t.Fatal(err) + } + sort.Ints(b) + if !reflect.DeepEqual(b, []int{1, 2, 3, 4}) { + t.Fatalf("failed to decode by json.Get") + } + }) +}