From 29f4737e5d257c94c0b4273cada5db58b2972407 Mon Sep 17 00:00:00 2001 From: Alex Roitman Date: Thu, 15 Feb 2018 11:08:27 -0800 Subject: [PATCH] Add WHEREEVAL clause to scan/search commands. --- controller/hooks.go | 5 +- controller/live.go | 4 +- controller/scan.go | 19 ++++-- controller/scanner.go | 48 ++++++++++++--- controller/search.go | 52 ++++++++++++++--- controller/token.go | 132 +++++++++++++++++++++++++++++++++++++++++- 6 files changed, 237 insertions(+), 23 deletions(-) diff --git a/controller/hooks.go b/controller/hooks.go index dc5e5201..cceddaac 100644 --- a/controller/hooks.go +++ b/controller/hooks.go @@ -97,6 +97,7 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD break } s, err := c.cmdSearchArgs(cmdlc, vs, types) + defer s.Close() if err != nil { return server.NOMessage, d, err } @@ -132,7 +133,9 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD hook.cond = sync.NewCond(&hook.mu) var wr bytes.Buffer - hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) + hook.ScanWriter, err = c.newScanWriter( + &wr, cmsg, s.key, s.output, s.precision, s.glob, false, + s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields) if err != nil { return server.NOMessage, d, err } diff --git a/controller/live.go b/controller/live.go index 38c157e9..a49a09ca 100644 --- a/controller/live.go +++ b/controller/live.go @@ -87,7 +87,9 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.PipelineReade lb.key = s.key lb.fence = &s c.mu.RLock() - sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) + sw, err = c.newScanWriter( + &wr, msg, s.key, s.output, s.precision, s.glob, false, + s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields) c.mu.RUnlock() } // everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS diff --git a/controller/scan.go b/controller/scan.go index 0227d85c..db62ee24 100644 --- a/controller/scan.go +++ b/controller/scan.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "errors" "time" "github.com/tidwall/resp" @@ -10,8 +11,8 @@ import ( "github.com/tidwall/tile38/geojson" ) -func cmdScanArgs(vs []resp.Value) (s liveFenceSwitches, err error) { - if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens("scan", vs); err != nil { +func (c *Controller) cmdScanArgs(vs []resp.Value) (s liveFenceSwitches, err error) { + if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens("scan", vs); err != nil { return } if len(vs) != 0 { @@ -25,12 +26,22 @@ func (c *Controller) cmdScan(msg *server.Message) (res resp.Value, err error) { start := time.Now() vs := msg.Values[1:] - s, err := cmdScanArgs(vs) + s, err := c.cmdScanArgs(vs) + defer s.Close() + defer func() { + if r := recover(); r != nil { + res = server.NOMessage + err = errors.New(r.(string)) + return + } + }() if err != nil { return server.NOMessage, err } wr := &bytes.Buffer{} - sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) + sw, err := c.newScanWriter( + wr, msg, s.key, s.output, s.precision, s.glob, false, + s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields) if err != nil { return server.NOMessage, err } diff --git a/controller/scanner.go b/controller/scanner.go index 790906e4..dedf004b 100644 --- a/controller/scanner.go +++ b/controller/scanner.go @@ -40,6 +40,7 @@ type scanWriter struct { output outputT wheres []whereT whereins []whereinT + whereevals []whereevalT numberItems uint64 nofields bool cursor uint64 @@ -68,7 +69,7 @@ type ScanWriterParams struct { func (c *Controller) newScanWriter( wr *bytes.Buffer, msg *server.Message, key string, output outputT, precision uint64, globPattern string, matchValues bool, - cursor, limit uint64, wheres []whereT, whereins []whereinT, nofields bool, + cursor, limit uint64, wheres []whereT, whereins []whereinT, whereevals []whereevalT, nofields bool, ) ( *scanWriter, error, ) { @@ -92,6 +93,7 @@ func (c *Controller) newScanWriter( limit: limit, wheres: wheres, whereins: whereins, + whereevals: whereevals, output: output, nofields: nofields, precision: precision, @@ -186,9 +188,10 @@ func (sw *scanWriter) writeFoot() { } } -func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, bool) { +func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) (fvals []float64, match bool) { var z float64 var gotz bool + fvals = sw.fvals if !sw.hasFieldsOutput() || sw.fullFields { for _, where := range sw.wheres { if where.field == "z" { @@ -196,7 +199,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, z = o.CalculatedPoint().Z } if !where.match(z) { - return sw.fvals, false + return } continue } @@ -208,7 +211,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, } } if !where.match(value) { - return sw.fvals, false + return } } for _, wherein := range sw.whereins { @@ -220,7 +223,20 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, } } if !wherein.match(value) { - return sw.fvals, false + return + } + } + for _, whereval := range sw.whereevals { + fieldsWithNames := make(map[string]float64) + for field, idx := range sw.fmap { + if idx < len(fields) { + fieldsWithNames[field] = fields[idx] + } else { + fieldsWithNames[field] = 0 + } + } + if !whereval.match(fieldsWithNames) { + return } } } else { @@ -237,7 +253,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, z = o.CalculatedPoint().Z } if !where.match(z) { - return sw.fvals, false + return } continue } @@ -247,7 +263,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, value = sw.fvals[idx] } if !where.match(value) { - return sw.fvals, false + return } } for _, wherein := range sw.whereins { @@ -257,11 +273,25 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, value = sw.fvals[idx] } if !wherein.match(value) { - return sw.fvals, false + return + } + } + for _, whereval := range sw.whereevals { + fieldsWithNames := make(map[string]float64) + for field, idx := range sw.fmap { + if idx < len(fields) { + fieldsWithNames[field] = fields[idx] + } else { + fieldsWithNames[field] = 0 + } + } + if !whereval.match(fieldsWithNames) { + return } } } - return sw.fvals, true + match = true + return } //id string, o geojson.Object, fields []float64, noLock bool diff --git a/controller/search.go b/controller/search.go index 8a656779..71225aaf 100644 --- a/controller/search.go +++ b/controller/search.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "errors" "sort" "strconv" "strings" @@ -40,8 +41,14 @@ func (s liveFenceSwitches) Error() string { return "going live" } +func (s liveFenceSwitches) Close() { + for _, whereeval := range s.searchScanBaseTokens.whereevals { + whereeval.Close() + } +} + func (c *Controller) cmdSearchArgs(cmd string, vs []resp.Value, types []string) (s liveFenceSwitches, err error) { - if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens(cmd, vs); err != nil { + if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens(cmd, vs); err != nil { return } var typ string @@ -288,6 +295,14 @@ func (c *Controller) cmdNearby(msg *server.Message) (res resp.Value, err error) vs := msg.Values[1:] wr := &bytes.Buffer{} s, err := c.cmdSearchArgs("nearby", vs, nearbyTypes) + defer s.Close() + defer func() { + if r := recover(); r != nil { + res = server.NOMessage + err = errors.New(r.(string)) + return + } + }() if err != nil { return server.NOMessage, err } @@ -296,7 +311,9 @@ func (c *Controller) cmdNearby(msg *server.Message) (res resp.Value, err error) return server.NOMessage, s } minZ, maxZ := zMinMaxFromWheres(s.wheres) - sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) + sw, err := c.newScanWriter( + wr, msg, s.key, s.output, s.precision, s.glob, false, + s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields) if err != nil { return server.NOMessage, err } @@ -385,6 +402,15 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res wr := &bytes.Buffer{} s, err := c.cmdSearchArgs(cmd, vs, withinOrIntersectsTypes) + defer s.Close() + defer func() { + if r := recover(); r != nil { + res = server.NOMessage + err = errors.New(r.(string)) + return + } + }() + if err != nil { return server.NOMessage, err } @@ -392,7 +418,9 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res if s.fence { return server.NOMessage, s } - sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) + sw, err := c.newScanWriter( + wr, msg, s.key, s.output, s.precision, s.glob, false, + s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields) if err != nil { return server.NOMessage, err } @@ -440,8 +468,8 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res return sw.respOut, nil } -func cmdSeachValuesArgs(vs []resp.Value) (s liveFenceSwitches, err error) { - if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens("search", vs); err != nil { +func (c *Controller) cmdSeachValuesArgs(vs []resp.Value) (s liveFenceSwitches, err error) { + if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens("search", vs); err != nil { return } if len(vs) != 0 { @@ -456,11 +484,21 @@ func (c *Controller) cmdSearch(msg *server.Message) (res resp.Value, err error) vs := msg.Values[1:] wr := &bytes.Buffer{} - s, err := cmdSeachValuesArgs(vs) + s, err := c.cmdSeachValuesArgs(vs) + defer s.Close() + defer func() { + if r := recover(); r != nil { + res = server.NOMessage + err = errors.New(r.(string)) + return + } + }() if err != nil { return server.NOMessage, err } - sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) + sw, err := c.newScanWriter( + wr, msg, s.key, s.output, s.precision, s.glob, true, + s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields) if err != nil { return server.NOMessage, err } diff --git a/controller/token.go b/controller/token.go index a8b4924b..a57f8033 100644 --- a/controller/token.go +++ b/controller/token.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/tidwall/resp" + "github.com/yuin/gopher-lua" ) const defaultSearchOutput = outputObjects @@ -166,6 +167,64 @@ func (wherein whereinT) match(value float64) bool { return ok } +type whereevalT struct { + c *Controller + luaState *lua.LState + fn *lua.LFunction +} + +func (whereeval whereevalT) Close() { + luaSetRawGlobals( + whereeval.luaState, map[string]lua.LValue{ + "ARGV": lua.LNil, + }) + whereeval.c.luapool.Put(whereeval.luaState) +} + +func (whereeval whereevalT) match(fieldsWithNames map[string]float64) bool { + fieldsTbl := whereeval.luaState.CreateTable(0, len(fieldsWithNames)) + for field, val := range fieldsWithNames { + fieldsTbl.RawSetString(field, lua.LNumber(val)) + } + + luaSetRawGlobals( + whereeval.luaState, map[string]lua.LValue{ + "FIELDS": fieldsTbl, + }) + defer luaSetRawGlobals( + whereeval.luaState, map[string]lua.LValue{ + "FIELDS": lua.LNil, + }) + + whereeval.luaState.Push(whereeval.fn) + if err := whereeval.luaState.PCall(0, 1, nil); err != nil { + panic(err.Error()) + } + ret := whereeval.luaState.Get(-1) + whereeval.luaState.Pop(1) + + // Make bool out of returned lua value + switch ret.Type() { + case lua.LTNil: + return false + case lua.LTBool: + return ret == lua.LTrue + case lua.LTNumber: + return float64(ret.(lua.LNumber)) != 0 + case lua.LTString: + return ret.String() != "" + case lua.LTTable: + tbl := ret.(*lua.LTable) + if tbl.Len() != 0 { + return true + } + var match bool + tbl.ForEach(func(lk lua.LValue, lv lua.LValue) {match = true}) + return match + } + panic(fmt.Sprintf("Script returned value of type %s", ret.Type())) +} + type searchScanBaseTokens struct { key string cursor uint64 @@ -179,6 +238,7 @@ type searchScanBaseTokens struct { glob string wheres []whereT whereins []whereinT + whereevals []whereevalT nofields bool ulimit bool limit uint64 @@ -187,7 +247,7 @@ type searchScanBaseTokens struct { desc bool } -func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, t searchScanBaseTokens, err error) { +func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, t searchScanBaseTokens, err error) { var ok bool if vs, t.key, ok = tokenval(vs); !ok || t.key == "" { err = errInvalidNumberOfArguments @@ -288,6 +348,76 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, } t.whereins = append(t.whereins, whereinT{field, valMap}) continue + } else if (wtok[0] == 'W' || wtok[0] == 'w') && strings.Contains(strings.ToLower(wtok), "whereeval") { + scriptIsSha := strings.ToLower(wtok) == "whereevalsha" + vs = nvs + var script, nargsStr, arg string + if vs, script, ok = tokenval(vs); !ok || script == "" { + err = errInvalidNumberOfArguments + return + } + if vs, nargsStr, ok = tokenval(vs); !ok || nargsStr == "" { + err = errInvalidNumberOfArguments + return + } + + var i, nargs uint64 + if nargs, err = strconv.ParseUint(nargsStr, 10, 64); err != nil { + err = errInvalidArgument(nargsStr) + return + } + + var luaState *lua.LState + luaState, err = c.luapool.Get() + if err != nil { + return + } + + argsTbl := luaState.CreateTable(len(vs), 0) + for i = 0; i < nargs; i++ { + if vs, arg, ok = tokenval(vs); !ok || arg == "" { + err = errInvalidNumberOfArguments + return + } + argsTbl.Append(lua.LString(arg)) + } + + var shaSum string + if scriptIsSha { + shaSum = script + } else { + shaSum = Sha1Sum(script) + } + + luaSetRawGlobals( + luaState, map[string]lua.LValue{ + "ARGV": argsTbl, + }) + + compiled, ok := c.luascripts.Get(shaSum) + var fn *lua.LFunction + if ok { + fn = &lua.LFunction{ + IsG: false, + Env: luaState.Env, + + Proto: compiled, + GFunction: nil, + Upvalues: make([]*lua.Upvalue, 0), + } + } else if scriptIsSha { + err = errShaNotFound + return + } else { + fn, err = luaState.Load(strings.NewReader(script), "f_"+shaSum) + if err != nil { + err = makeSafeErr(err) + return + } + c.luascripts.Put(shaSum, fn.Proto) + } + t.whereevals = append(t.whereevals, whereevalT{c,luaState, fn}) + continue } else if (wtok[0] == 'N' || wtok[0] == 'n') && strings.ToLower(wtok) == "nofields" { vs = nvs if t.nofields {