diff --git a/go.mod b/go.mod index 4f571a54..f6d78d1e 100644 --- a/go.mod +++ b/go.mod @@ -91,6 +91,7 @@ require ( github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/procfs v0.7.3 // indirect github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect + github.com/tidwall/expr v0.8.3 // indirect github.com/tidwall/geoindex v1.7.0 // indirect github.com/tidwall/grect v0.1.4 // indirect github.com/tidwall/rtred v0.1.2 // indirect diff --git a/go.sum b/go.sum index 8b5f9906..a68064e8 100644 --- a/go.sum +++ b/go.sum @@ -356,6 +356,8 @@ github.com/tidwall/buntdb v1.2.9 h1:XVz684P7X6HCTrdr385yDZWB1zt/n20ZNG3M1iGyFm4= github.com/tidwall/buntdb v1.2.9/go.mod h1:IwyGSvvDg6hnKSIhtdZ0AqhCZGH8ukdtCAzaP8fI1X4= github.com/tidwall/cities v0.1.0 h1:CVNkmMf7NEC9Bvokf5GoSsArHCKRMTgLuubRTHnH0mE= github.com/tidwall/cities v0.1.0/go.mod h1:lV/HDp2gCcRcHJWqgt6Di54GiDrTZwh1aG2ZUPNbqa4= +github.com/tidwall/expr v0.8.3 h1:hLaz3DmuXsat+LAO904UxjD1WHrHEbRYZgzzzcn7JB4= +github.com/tidwall/expr v0.8.3/go.mod h1:GnVpaS2R9wWV9Ft2u5TPDypJ+iQNxhAt9ISTUaUTlto= github.com/tidwall/geoindex v1.4.4/go.mod h1:rvVVNEFfkJVWGUdEfU8QaoOg/9zFX0h9ofWzA60mz1I= github.com/tidwall/geoindex v1.7.0 h1:jtk41sfgwIt8MEDyC3xyKSj75iXXf6rjReJGDNPtR5o= github.com/tidwall/geoindex v1.7.0/go.mod h1:rvVVNEFfkJVWGUdEfU8QaoOg/9zFX0h9ofWzA60mz1I= diff --git a/internal/server/expr.go b/internal/server/expr.go new file mode 100644 index 00000000..6252b129 --- /dev/null +++ b/internal/server/expr.go @@ -0,0 +1,140 @@ +package server + +import ( + "sync" + + "github.com/tidwall/expr" + "github.com/tidwall/geojson" + "github.com/tidwall/gjson" + "github.com/tidwall/tile38/internal/field" + "github.com/tidwall/tile38/internal/object" +) + +type exprPool struct { + pool *sync.Pool +} + +func typeForObject(o *object.Object) expr.Value { + switch o.Geo().(type) { + case *geojson.Point, *geojson.SimplePoint: + return expr.String("Point") + case *geojson.LineString: + return expr.String("LineString") + case *geojson.Polygon, *geojson.Circle, *geojson.Rect: + return expr.String("Polygon") + case *geojson.MultiPoint: + return expr.String("MultiPoint") + case *geojson.MultiLineString: + return expr.String("MultiLineString") + case *geojson.MultiPolygon: + return expr.String("MultiPolygon") + case *geojson.GeometryCollection: + return expr.String("GeometryCollection") + case *geojson.Feature: + return expr.String("Feature") + case *geojson.FeatureCollection: + return expr.String("FeatureCollection") + default: + return expr.Undefined + } +} + +func resultToValue(r gjson.Result) expr.Value { + if !r.Exists() { + return expr.Undefined + } + switch r.Type { + case gjson.String: + return expr.String(r.String()) + case gjson.False: + return expr.Bool(false) + case gjson.True: + return expr.Bool(true) + case gjson.Number: + return expr.Number(r.Float()) + case gjson.JSON: + return expr.String(r.String()) + default: + return expr.Null + } +} + +func newExprPool(s *Server) *exprPool { + ext := expr.NewExtender( + // ref + func(info expr.RefInfo, ctx *expr.Context) (expr.Value, error) { + o := ctx.UserData.(*object.Object) + if !info.Chain { + // root + if r := gjson.Get(o.Geo().Members(), info.Ident); r.Exists() { + return resultToValue(r), nil + } + switch info.Ident { + case "id": + return expr.String(o.ID()), nil + case "type": + return typeForObject(o), nil + default: + var rf field.Field + var ok bool + o.Fields().Scan(func(f field.Field) bool { + if f.Name() == info.Ident { + rf = f + ok = true + return false + } + return true + }) + if ok { + r := gjson.Parse(rf.Value().JSON()) + return resultToValue(r), nil + } + } + } else { + switch info.Value.Value().(type) { + case string: + r := gjson.Get(info.Value.String(), info.Ident) + return resultToValue(r), nil + } + } + return expr.Undefined, nil + }, + // call + func(info expr.CallInfo, ctx *expr.Context) (expr.Value, error) { + // No custom calls + return expr.Undefined, nil + }, + // op + func(info expr.OpInfo, ctx *expr.Context) (expr.Value, error) { + // No custom operations + return expr.Undefined, nil + }, + ) + return &exprPool{ + pool: &sync.Pool{ + New: func() any { + ctx := &expr.Context{ + Extender: ext, + } + return ctx + }, + }, + } +} + +func (p *exprPool) Get(o *object.Object) *expr.Context { + ctx := p.pool.Get().(*expr.Context) + ctx.UserData = o + return ctx +} + +func (p *exprPool) Put(ctx *expr.Context) { + p.pool.Put(ctx) +} + +func (where whereT) matchExpr(s *Server, o *object.Object) bool { + ctx := s.epool.Get(o) + res, _ := expr.Eval(where.name, ctx) + s.epool.Put(ctx) + return res.Bool() +} diff --git a/internal/server/scanner.go b/internal/server/scanner.go index ea5d3377..013ab40a 100644 --- a/internal/server/scanner.go +++ b/internal/server/scanner.go @@ -229,8 +229,14 @@ func getFieldValue(o *object.Object, name string) field.Value { func (sw *scanWriter) fieldMatch(o *object.Object) (bool, error) { for _, where := range sw.wheres { - if !where.match(getFieldValue(o, where.name)) { - return false, nil + if where.expr { + if !where.matchExpr(sw.s, o) { + return false, nil + } + } else { + if !where.matchField(getFieldValue(o, where.name)) { + return false, nil + } } } for _, wherein := range sw.whereins { diff --git a/internal/server/scanner_test.go b/internal/server/scanner_test.go index ae9fec31..fceccbe9 100644 --- a/internal/server/scanner_test.go +++ b/internal/server/scanner_test.go @@ -36,8 +36,8 @@ func BenchmarkFieldMatch(t *testing.B) { } sw := &scanWriter{ wheres: []whereT{ - {"foo", false, field.ValueOf("1"), false, field.ValueOf("3")}, - {"bar", false, field.ValueOf("10"), false, field.ValueOf("30")}, + {false, "foo", false, field.ValueOf("1"), false, field.ValueOf("3")}, + {false, "bar", false, field.ValueOf("10"), false, field.ValueOf("30")}, }, whereins: []whereinT{ {"foo", []field.Value{field.ValueOf("1"), field.ValueOf("2")}}, diff --git a/internal/server/server.go b/internal/server/server.go index bf792724..dc9b2952 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -79,6 +79,7 @@ type Server struct { started time.Time config *Config epc *endpoint.Manager + epool *exprPool lnmu sync.Mutex ln net.Listener // server listener @@ -222,7 +223,7 @@ func Serve(opts Options) error { hookExpires: btree.NewNonConcurrent(byHookExpires), opts: opts, } - + s.epool = newExprPool(s) s.epc = endpoint.NewManager(s) defer s.epc.Shutdown() s.luascripts = s.newScriptMap() diff --git a/internal/server/token.go b/internal/server/token.go index 253c3e8e..3ccf5f68 100644 --- a/internal/server/token.go +++ b/internal/server/token.go @@ -63,6 +63,7 @@ func lc(s1, s2 string) bool { } type whereT struct { + expr bool name string minx bool min field.Value @@ -76,7 +77,7 @@ func mGT(a, b field.Value) bool { return mLT(b, a) } func mGTE(a, b field.Value) bool { return !mLT(a, b) } func mEQ(a, b field.Value) bool { return a.Equals(b) } -func (where whereT) match(value field.Value) bool { +func (where whereT) matchField(value field.Value) bool { switch where.min.Data() { case "<": return mLT(value, where.max) @@ -275,45 +276,59 @@ func (s *Server) parseSearchScanBaseTokens( continue case "where": vs = nvs - var name, smin, smax string - if vs, name, ok = tokenval(vs); !ok { - err = errInvalidNumberOfArguments - return - } - if vs, smin, ok = tokenval(vs); !ok { - err = errInvalidNumberOfArguments - return - } - if vs, smax, ok = tokenval(vs); !ok { - err = errInvalidNumberOfArguments - return - } - var minx, maxx bool - smin = strings.ToLower(smin) - smax = strings.ToLower(smax) - if smax == "+inf" || smax == "inf" { - smax = "inf" - } - switch smin { - case "<", "<=", ">", ">=", "==", "!=": - default: - if strings.HasPrefix(smin, "(") { - minx = true - smin = smin[1:] + if detectExprToken(vs) { + // using expressions + // WHERE expr + var expr string + if vs, expr, ok = tokenval(vs); !ok { + err = errInvalidNumberOfArguments + return } - if strings.HasPrefix(smax, "(") { - maxx = true - smax = smax[1:] + t.wheres = append(t.wheres, whereT{name: expr, expr: true}) + continue + } else { + // using field filter + // WHERE min max + var name, smin, smax string + if vs, name, ok = tokenval(vs); !ok { + err = errInvalidNumberOfArguments + return } + if vs, smin, ok = tokenval(vs); !ok { + err = errInvalidNumberOfArguments + return + } + if vs, smax, ok = tokenval(vs); !ok { + err = errInvalidNumberOfArguments + return + } + var minx, maxx bool + smin = strings.ToLower(smin) + smax = strings.ToLower(smax) + if smax == "+inf" || smax == "inf" { + smax = "inf" + } + switch smin { + case "<", "<=", ">", ">=", "==", "!=": + default: + if strings.HasPrefix(smin, "(") { + minx = true + smin = smin[1:] + } + if strings.HasPrefix(smax, "(") { + maxx = true + smax = smax[1:] + } + } + t.wheres = append(t.wheres, whereT{ + name: strings.ToLower(name), + minx: minx, + min: field.ValueOf(smin), + maxx: maxx, + max: field.ValueOf(smax), + }) + continue } - t.wheres = append(t.wheres, whereT{ - name: strings.ToLower(name), - minx: minx, - min: field.ValueOf(smin), - maxx: maxx, - max: field.ValueOf(smax), - }) - continue case "wherein": vs = nvs var name, nvalsStr, valStr string @@ -675,6 +690,25 @@ func (s *Server) parseSearchScanBaseTokens( return } +func detectExprToken(vs []string) bool { + // Detect the kind of where, either: + // - expr + // - name min max + if len(vs) == 0 { + return false + } else if len(vs) == 1 || (len(vs) == 2 && len(vs[1]) == 0) { + return true + } + v := vs[1] + if (v[0] >= 'a' && v[0] <= 'z') || (v[0] >= 'A' && v[0] <= 'Z') { + if (v[0] == 'i' || v[0] == 'I') && strings.ToLower(v) == "inf" { + return false + } + return true + } + return false +} + type parentStack []*areaExpression func (ps *parentStack) isEmpty() bool { diff --git a/tests/keys_test.go b/tests/keys_test.go index 518d2af0..92ce18b3 100644 --- a/tests/keys_test.go +++ b/tests/keys_test.go @@ -463,7 +463,7 @@ func keys_FIELDS_test(mc *mockServer) error { // Do some GJSON queries. Do("SET", "fleet", "truck2", "FIELD", "hello", `{"world":"tom"}`, "POINT", "-112", "33").JSON().OK(), Do("SCAN", "fleet", "WHERE", "hello", `{"world":"tom"}`, `{"world":"tom"}`, "COUNT").JSON().Str(`{"ok":true,"count":1,"cursor":0}`), - Do("SCAN", "fleet", "WHERE", "hello.world", `tom`, `tom`, "COUNT").JSON().Str(`{"ok":true,"count":1,"cursor":0}`), + Do("SCAN", "fleet", "WHERE", "hello.world == 'tom'", "COUNT").JSON().Str(`{"ok":true,"count":1,"cursor":0}`), // The next scan does not match on anything, but since we're matching // on zeros, which is the default, then all (two) objects are returned. Do("SCAN", "fleet", "WHERE", "hello.world.1", `0`, `0`, "IDS").JSON().Str(`{"ok":true,"ids":["truck1","truck2"],"count":2,"cursor":0}`),