Add WHEREEVAL clause to scan/search commands.

This commit is contained in:
Alex Roitman 2018-02-15 11:08:27 -08:00
parent 2088b5d2d2
commit 29f4737e5d
6 changed files with 237 additions and 23 deletions

View File

@ -97,6 +97,7 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD
break
}
s, err := c.cmdSearchArgs(cmdlc, vs, types)
defer s.Close()
if err != nil {
return server.NOMessage, d, err
}
@ -132,7 +133,9 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD
hook.cond = sync.NewCond(&hook.mu)
var wr bytes.Buffer
hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
hook.ScanWriter, err = c.newScanWriter(
&wr, cmsg, s.key, s.output, s.precision, s.glob, false,
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
if err != nil {
return server.NOMessage, d, err
}

View File

@ -87,7 +87,9 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.PipelineReade
lb.key = s.key
lb.fence = &s
c.mu.RLock()
sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
sw, err = c.newScanWriter(
&wr, msg, s.key, s.output, s.precision, s.glob, false,
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
c.mu.RUnlock()
}
// everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS

View File

@ -2,6 +2,7 @@ package controller
import (
"bytes"
"errors"
"time"
"github.com/tidwall/resp"
@ -10,8 +11,8 @@ import (
"github.com/tidwall/tile38/geojson"
)
func cmdScanArgs(vs []resp.Value) (s liveFenceSwitches, err error) {
if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens("scan", vs); err != nil {
func (c *Controller) cmdScanArgs(vs []resp.Value) (s liveFenceSwitches, err error) {
if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens("scan", vs); err != nil {
return
}
if len(vs) != 0 {
@ -25,12 +26,22 @@ func (c *Controller) cmdScan(msg *server.Message) (res resp.Value, err error) {
start := time.Now()
vs := msg.Values[1:]
s, err := cmdScanArgs(vs)
s, err := c.cmdScanArgs(vs)
defer s.Close()
defer func() {
if r := recover(); r != nil {
res = server.NOMessage
err = errors.New(r.(string))
return
}
}()
if err != nil {
return server.NOMessage, err
}
wr := &bytes.Buffer{}
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
sw, err := c.newScanWriter(
wr, msg, s.key, s.output, s.precision, s.glob, false,
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
if err != nil {
return server.NOMessage, err
}

View File

@ -40,6 +40,7 @@ type scanWriter struct {
output outputT
wheres []whereT
whereins []whereinT
whereevals []whereevalT
numberItems uint64
nofields bool
cursor uint64
@ -68,7 +69,7 @@ type ScanWriterParams struct {
func (c *Controller) newScanWriter(
wr *bytes.Buffer, msg *server.Message, key string, output outputT,
precision uint64, globPattern string, matchValues bool,
cursor, limit uint64, wheres []whereT, whereins []whereinT, nofields bool,
cursor, limit uint64, wheres []whereT, whereins []whereinT, whereevals []whereevalT, nofields bool,
) (
*scanWriter, error,
) {
@ -92,6 +93,7 @@ func (c *Controller) newScanWriter(
limit: limit,
wheres: wheres,
whereins: whereins,
whereevals: whereevals,
output: output,
nofields: nofields,
precision: precision,
@ -186,9 +188,10 @@ func (sw *scanWriter) writeFoot() {
}
}
func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, bool) {
func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) (fvals []float64, match bool) {
var z float64
var gotz bool
fvals = sw.fvals
if !sw.hasFieldsOutput() || sw.fullFields {
for _, where := range sw.wheres {
if where.field == "z" {
@ -196,7 +199,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
z = o.CalculatedPoint().Z
}
if !where.match(z) {
return sw.fvals, false
return
}
continue
}
@ -208,7 +211,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
}
}
if !where.match(value) {
return sw.fvals, false
return
}
}
for _, wherein := range sw.whereins {
@ -220,7 +223,20 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
}
}
if !wherein.match(value) {
return sw.fvals, false
return
}
}
for _, whereval := range sw.whereevals {
fieldsWithNames := make(map[string]float64)
for field, idx := range sw.fmap {
if idx < len(fields) {
fieldsWithNames[field] = fields[idx]
} else {
fieldsWithNames[field] = 0
}
}
if !whereval.match(fieldsWithNames) {
return
}
}
} else {
@ -237,7 +253,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
z = o.CalculatedPoint().Z
}
if !where.match(z) {
return sw.fvals, false
return
}
continue
}
@ -247,7 +263,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
value = sw.fvals[idx]
}
if !where.match(value) {
return sw.fvals, false
return
}
}
for _, wherein := range sw.whereins {
@ -257,11 +273,25 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
value = sw.fvals[idx]
}
if !wherein.match(value) {
return sw.fvals, false
return
}
}
for _, whereval := range sw.whereevals {
fieldsWithNames := make(map[string]float64)
for field, idx := range sw.fmap {
if idx < len(fields) {
fieldsWithNames[field] = fields[idx]
} else {
fieldsWithNames[field] = 0
}
}
if !whereval.match(fieldsWithNames) {
return
}
}
}
return sw.fvals, true
match = true
return
}
//id string, o geojson.Object, fields []float64, noLock bool

View File

@ -2,6 +2,7 @@ package controller
import (
"bytes"
"errors"
"sort"
"strconv"
"strings"
@ -40,8 +41,14 @@ func (s liveFenceSwitches) Error() string {
return "going live"
}
func (s liveFenceSwitches) Close() {
for _, whereeval := range s.searchScanBaseTokens.whereevals {
whereeval.Close()
}
}
func (c *Controller) cmdSearchArgs(cmd string, vs []resp.Value, types []string) (s liveFenceSwitches, err error) {
if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens(cmd, vs); err != nil {
if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens(cmd, vs); err != nil {
return
}
var typ string
@ -288,6 +295,14 @@ func (c *Controller) cmdNearby(msg *server.Message) (res resp.Value, err error)
vs := msg.Values[1:]
wr := &bytes.Buffer{}
s, err := c.cmdSearchArgs("nearby", vs, nearbyTypes)
defer s.Close()
defer func() {
if r := recover(); r != nil {
res = server.NOMessage
err = errors.New(r.(string))
return
}
}()
if err != nil {
return server.NOMessage, err
}
@ -296,7 +311,9 @@ func (c *Controller) cmdNearby(msg *server.Message) (res resp.Value, err error)
return server.NOMessage, s
}
minZ, maxZ := zMinMaxFromWheres(s.wheres)
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
sw, err := c.newScanWriter(
wr, msg, s.key, s.output, s.precision, s.glob, false,
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
if err != nil {
return server.NOMessage, err
}
@ -385,6 +402,15 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
wr := &bytes.Buffer{}
s, err := c.cmdSearchArgs(cmd, vs, withinOrIntersectsTypes)
defer s.Close()
defer func() {
if r := recover(); r != nil {
res = server.NOMessage
err = errors.New(r.(string))
return
}
}()
if err != nil {
return server.NOMessage, err
}
@ -392,7 +418,9 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
if s.fence {
return server.NOMessage, s
}
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
sw, err := c.newScanWriter(
wr, msg, s.key, s.output, s.precision, s.glob, false,
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
if err != nil {
return server.NOMessage, err
}
@ -440,8 +468,8 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
return sw.respOut, nil
}
func cmdSeachValuesArgs(vs []resp.Value) (s liveFenceSwitches, err error) {
if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens("search", vs); err != nil {
func (c *Controller) cmdSeachValuesArgs(vs []resp.Value) (s liveFenceSwitches, err error) {
if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens("search", vs); err != nil {
return
}
if len(vs) != 0 {
@ -456,11 +484,21 @@ func (c *Controller) cmdSearch(msg *server.Message) (res resp.Value, err error)
vs := msg.Values[1:]
wr := &bytes.Buffer{}
s, err := cmdSeachValuesArgs(vs)
s, err := c.cmdSeachValuesArgs(vs)
defer s.Close()
defer func() {
if r := recover(); r != nil {
res = server.NOMessage
err = errors.New(r.(string))
return
}
}()
if err != nil {
return server.NOMessage, err
}
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
sw, err := c.newScanWriter(
wr, msg, s.key, s.output, s.precision, s.glob, true,
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
if err != nil {
return server.NOMessage, err
}

View File

@ -8,6 +8,7 @@ import (
"strings"
"github.com/tidwall/resp"
"github.com/yuin/gopher-lua"
)
const defaultSearchOutput = outputObjects
@ -166,6 +167,64 @@ func (wherein whereinT) match(value float64) bool {
return ok
}
type whereevalT struct {
c *Controller
luaState *lua.LState
fn *lua.LFunction
}
func (whereeval whereevalT) Close() {
luaSetRawGlobals(
whereeval.luaState, map[string]lua.LValue{
"ARGV": lua.LNil,
})
whereeval.c.luapool.Put(whereeval.luaState)
}
func (whereeval whereevalT) match(fieldsWithNames map[string]float64) bool {
fieldsTbl := whereeval.luaState.CreateTable(0, len(fieldsWithNames))
for field, val := range fieldsWithNames {
fieldsTbl.RawSetString(field, lua.LNumber(val))
}
luaSetRawGlobals(
whereeval.luaState, map[string]lua.LValue{
"FIELDS": fieldsTbl,
})
defer luaSetRawGlobals(
whereeval.luaState, map[string]lua.LValue{
"FIELDS": lua.LNil,
})
whereeval.luaState.Push(whereeval.fn)
if err := whereeval.luaState.PCall(0, 1, nil); err != nil {
panic(err.Error())
}
ret := whereeval.luaState.Get(-1)
whereeval.luaState.Pop(1)
// Make bool out of returned lua value
switch ret.Type() {
case lua.LTNil:
return false
case lua.LTBool:
return ret == lua.LTrue
case lua.LTNumber:
return float64(ret.(lua.LNumber)) != 0
case lua.LTString:
return ret.String() != ""
case lua.LTTable:
tbl := ret.(*lua.LTable)
if tbl.Len() != 0 {
return true
}
var match bool
tbl.ForEach(func(lk lua.LValue, lv lua.LValue) {match = true})
return match
}
panic(fmt.Sprintf("Script returned value of type %s", ret.Type()))
}
type searchScanBaseTokens struct {
key string
cursor uint64
@ -179,6 +238,7 @@ type searchScanBaseTokens struct {
glob string
wheres []whereT
whereins []whereinT
whereevals []whereevalT
nofields bool
ulimit bool
limit uint64
@ -187,7 +247,7 @@ type searchScanBaseTokens struct {
desc bool
}
func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, t searchScanBaseTokens, err error) {
func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, t searchScanBaseTokens, err error) {
var ok bool
if vs, t.key, ok = tokenval(vs); !ok || t.key == "" {
err = errInvalidNumberOfArguments
@ -288,6 +348,76 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value,
}
t.whereins = append(t.whereins, whereinT{field, valMap})
continue
} else if (wtok[0] == 'W' || wtok[0] == 'w') && strings.Contains(strings.ToLower(wtok), "whereeval") {
scriptIsSha := strings.ToLower(wtok) == "whereevalsha"
vs = nvs
var script, nargsStr, arg string
if vs, script, ok = tokenval(vs); !ok || script == "" {
err = errInvalidNumberOfArguments
return
}
if vs, nargsStr, ok = tokenval(vs); !ok || nargsStr == "" {
err = errInvalidNumberOfArguments
return
}
var i, nargs uint64
if nargs, err = strconv.ParseUint(nargsStr, 10, 64); err != nil {
err = errInvalidArgument(nargsStr)
return
}
var luaState *lua.LState
luaState, err = c.luapool.Get()
if err != nil {
return
}
argsTbl := luaState.CreateTable(len(vs), 0)
for i = 0; i < nargs; i++ {
if vs, arg, ok = tokenval(vs); !ok || arg == "" {
err = errInvalidNumberOfArguments
return
}
argsTbl.Append(lua.LString(arg))
}
var shaSum string
if scriptIsSha {
shaSum = script
} else {
shaSum = Sha1Sum(script)
}
luaSetRawGlobals(
luaState, map[string]lua.LValue{
"ARGV": argsTbl,
})
compiled, ok := c.luascripts.Get(shaSum)
var fn *lua.LFunction
if ok {
fn = &lua.LFunction{
IsG: false,
Env: luaState.Env,
Proto: compiled,
GFunction: nil,
Upvalues: make([]*lua.Upvalue, 0),
}
} else if scriptIsSha {
err = errShaNotFound
return
} else {
fn, err = luaState.Load(strings.NewReader(script), "f_"+shaSum)
if err != nil {
err = makeSafeErr(err)
return
}
c.luascripts.Put(shaSum, fn.Proto)
}
t.whereevals = append(t.whereevals, whereevalT{c,luaState, fn})
continue
} else if (wtok[0] == 'N' || wtok[0] == 'n') && strings.ToLower(wtok) == "nofields" {
vs = nvs
if t.nofields {