mirror of https://github.com/tidwall/tile38.git
Add WHEREEVAL clause to scan/search commands.
This commit is contained in:
parent
2088b5d2d2
commit
29f4737e5d
|
@ -97,6 +97,7 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD
|
|||
break
|
||||
}
|
||||
s, err := c.cmdSearchArgs(cmdlc, vs, types)
|
||||
defer s.Close()
|
||||
if err != nil {
|
||||
return server.NOMessage, d, err
|
||||
}
|
||||
|
@ -132,7 +133,9 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD
|
|||
hook.cond = sync.NewCond(&hook.mu)
|
||||
|
||||
var wr bytes.Buffer
|
||||
hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
|
||||
hook.ScanWriter, err = c.newScanWriter(
|
||||
&wr, cmsg, s.key, s.output, s.precision, s.glob, false,
|
||||
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
|
||||
if err != nil {
|
||||
return server.NOMessage, d, err
|
||||
}
|
||||
|
|
|
@ -87,7 +87,9 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.PipelineReade
|
|||
lb.key = s.key
|
||||
lb.fence = &s
|
||||
c.mu.RLock()
|
||||
sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
|
||||
sw, err = c.newScanWriter(
|
||||
&wr, msg, s.key, s.output, s.precision, s.glob, false,
|
||||
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
|
||||
c.mu.RUnlock()
|
||||
}
|
||||
// everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS
|
||||
|
|
|
@ -2,6 +2,7 @@ package controller
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/resp"
|
||||
|
@ -10,8 +11,8 @@ import (
|
|||
"github.com/tidwall/tile38/geojson"
|
||||
)
|
||||
|
||||
func cmdScanArgs(vs []resp.Value) (s liveFenceSwitches, err error) {
|
||||
if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens("scan", vs); err != nil {
|
||||
func (c *Controller) cmdScanArgs(vs []resp.Value) (s liveFenceSwitches, err error) {
|
||||
if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens("scan", vs); err != nil {
|
||||
return
|
||||
}
|
||||
if len(vs) != 0 {
|
||||
|
@ -25,12 +26,22 @@ func (c *Controller) cmdScan(msg *server.Message) (res resp.Value, err error) {
|
|||
start := time.Now()
|
||||
vs := msg.Values[1:]
|
||||
|
||||
s, err := cmdScanArgs(vs)
|
||||
s, err := c.cmdScanArgs(vs)
|
||||
defer s.Close()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
res = server.NOMessage
|
||||
err = errors.New(r.(string))
|
||||
return
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return server.NOMessage, err
|
||||
}
|
||||
wr := &bytes.Buffer{}
|
||||
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
|
||||
sw, err := c.newScanWriter(
|
||||
wr, msg, s.key, s.output, s.precision, s.glob, false,
|
||||
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
|
||||
if err != nil {
|
||||
return server.NOMessage, err
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ type scanWriter struct {
|
|||
output outputT
|
||||
wheres []whereT
|
||||
whereins []whereinT
|
||||
whereevals []whereevalT
|
||||
numberItems uint64
|
||||
nofields bool
|
||||
cursor uint64
|
||||
|
@ -68,7 +69,7 @@ type ScanWriterParams struct {
|
|||
func (c *Controller) newScanWriter(
|
||||
wr *bytes.Buffer, msg *server.Message, key string, output outputT,
|
||||
precision uint64, globPattern string, matchValues bool,
|
||||
cursor, limit uint64, wheres []whereT, whereins []whereinT, nofields bool,
|
||||
cursor, limit uint64, wheres []whereT, whereins []whereinT, whereevals []whereevalT, nofields bool,
|
||||
) (
|
||||
*scanWriter, error,
|
||||
) {
|
||||
|
@ -92,6 +93,7 @@ func (c *Controller) newScanWriter(
|
|||
limit: limit,
|
||||
wheres: wheres,
|
||||
whereins: whereins,
|
||||
whereevals: whereevals,
|
||||
output: output,
|
||||
nofields: nofields,
|
||||
precision: precision,
|
||||
|
@ -186,9 +188,10 @@ func (sw *scanWriter) writeFoot() {
|
|||
}
|
||||
}
|
||||
|
||||
func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, bool) {
|
||||
func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) (fvals []float64, match bool) {
|
||||
var z float64
|
||||
var gotz bool
|
||||
fvals = sw.fvals
|
||||
if !sw.hasFieldsOutput() || sw.fullFields {
|
||||
for _, where := range sw.wheres {
|
||||
if where.field == "z" {
|
||||
|
@ -196,7 +199,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
|
|||
z = o.CalculatedPoint().Z
|
||||
}
|
||||
if !where.match(z) {
|
||||
return sw.fvals, false
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
@ -208,7 +211,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
|
|||
}
|
||||
}
|
||||
if !where.match(value) {
|
||||
return sw.fvals, false
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, wherein := range sw.whereins {
|
||||
|
@ -220,7 +223,20 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
|
|||
}
|
||||
}
|
||||
if !wherein.match(value) {
|
||||
return sw.fvals, false
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, whereval := range sw.whereevals {
|
||||
fieldsWithNames := make(map[string]float64)
|
||||
for field, idx := range sw.fmap {
|
||||
if idx < len(fields) {
|
||||
fieldsWithNames[field] = fields[idx]
|
||||
} else {
|
||||
fieldsWithNames[field] = 0
|
||||
}
|
||||
}
|
||||
if !whereval.match(fieldsWithNames) {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -237,7 +253,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
|
|||
z = o.CalculatedPoint().Z
|
||||
}
|
||||
if !where.match(z) {
|
||||
return sw.fvals, false
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
@ -247,7 +263,7 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
|
|||
value = sw.fvals[idx]
|
||||
}
|
||||
if !where.match(value) {
|
||||
return sw.fvals, false
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, wherein := range sw.whereins {
|
||||
|
@ -257,11 +273,25 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
|
|||
value = sw.fvals[idx]
|
||||
}
|
||||
if !wherein.match(value) {
|
||||
return sw.fvals, false
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, whereval := range sw.whereevals {
|
||||
fieldsWithNames := make(map[string]float64)
|
||||
for field, idx := range sw.fmap {
|
||||
if idx < len(fields) {
|
||||
fieldsWithNames[field] = fields[idx]
|
||||
} else {
|
||||
fieldsWithNames[field] = 0
|
||||
}
|
||||
}
|
||||
if !whereval.match(fieldsWithNames) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return sw.fvals, true
|
||||
match = true
|
||||
return
|
||||
}
|
||||
|
||||
//id string, o geojson.Object, fields []float64, noLock bool
|
||||
|
|
|
@ -2,6 +2,7 @@ package controller
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -40,8 +41,14 @@ func (s liveFenceSwitches) Error() string {
|
|||
return "going live"
|
||||
}
|
||||
|
||||
func (s liveFenceSwitches) Close() {
|
||||
for _, whereeval := range s.searchScanBaseTokens.whereevals {
|
||||
whereeval.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) cmdSearchArgs(cmd string, vs []resp.Value, types []string) (s liveFenceSwitches, err error) {
|
||||
if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens(cmd, vs); err != nil {
|
||||
if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens(cmd, vs); err != nil {
|
||||
return
|
||||
}
|
||||
var typ string
|
||||
|
@ -288,6 +295,14 @@ func (c *Controller) cmdNearby(msg *server.Message) (res resp.Value, err error)
|
|||
vs := msg.Values[1:]
|
||||
wr := &bytes.Buffer{}
|
||||
s, err := c.cmdSearchArgs("nearby", vs, nearbyTypes)
|
||||
defer s.Close()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
res = server.NOMessage
|
||||
err = errors.New(r.(string))
|
||||
return
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return server.NOMessage, err
|
||||
}
|
||||
|
@ -296,7 +311,9 @@ func (c *Controller) cmdNearby(msg *server.Message) (res resp.Value, err error)
|
|||
return server.NOMessage, s
|
||||
}
|
||||
minZ, maxZ := zMinMaxFromWheres(s.wheres)
|
||||
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
|
||||
sw, err := c.newScanWriter(
|
||||
wr, msg, s.key, s.output, s.precision, s.glob, false,
|
||||
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
|
||||
if err != nil {
|
||||
return server.NOMessage, err
|
||||
}
|
||||
|
@ -385,6 +402,15 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
|
|||
|
||||
wr := &bytes.Buffer{}
|
||||
s, err := c.cmdSearchArgs(cmd, vs, withinOrIntersectsTypes)
|
||||
defer s.Close()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
res = server.NOMessage
|
||||
err = errors.New(r.(string))
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
return server.NOMessage, err
|
||||
}
|
||||
|
@ -392,7 +418,9 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
|
|||
if s.fence {
|
||||
return server.NOMessage, s
|
||||
}
|
||||
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
|
||||
sw, err := c.newScanWriter(
|
||||
wr, msg, s.key, s.output, s.precision, s.glob, false,
|
||||
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
|
||||
if err != nil {
|
||||
return server.NOMessage, err
|
||||
}
|
||||
|
@ -440,8 +468,8 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
|
|||
return sw.respOut, nil
|
||||
}
|
||||
|
||||
func cmdSeachValuesArgs(vs []resp.Value) (s liveFenceSwitches, err error) {
|
||||
if vs, s.searchScanBaseTokens, err = parseSearchScanBaseTokens("search", vs); err != nil {
|
||||
func (c *Controller) cmdSeachValuesArgs(vs []resp.Value) (s liveFenceSwitches, err error) {
|
||||
if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens("search", vs); err != nil {
|
||||
return
|
||||
}
|
||||
if len(vs) != 0 {
|
||||
|
@ -456,11 +484,21 @@ func (c *Controller) cmdSearch(msg *server.Message) (res resp.Value, err error)
|
|||
vs := msg.Values[1:]
|
||||
|
||||
wr := &bytes.Buffer{}
|
||||
s, err := cmdSeachValuesArgs(vs)
|
||||
s, err := c.cmdSeachValuesArgs(vs)
|
||||
defer s.Close()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
res = server.NOMessage
|
||||
err = errors.New(r.(string))
|
||||
return
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return server.NOMessage, err
|
||||
}
|
||||
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
|
||||
sw, err := c.newScanWriter(
|
||||
wr, msg, s.key, s.output, s.precision, s.glob, true,
|
||||
s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields)
|
||||
if err != nil {
|
||||
return server.NOMessage, err
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/tidwall/resp"
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
const defaultSearchOutput = outputObjects
|
||||
|
@ -166,6 +167,64 @@ func (wherein whereinT) match(value float64) bool {
|
|||
return ok
|
||||
}
|
||||
|
||||
type whereevalT struct {
|
||||
c *Controller
|
||||
luaState *lua.LState
|
||||
fn *lua.LFunction
|
||||
}
|
||||
|
||||
func (whereeval whereevalT) Close() {
|
||||
luaSetRawGlobals(
|
||||
whereeval.luaState, map[string]lua.LValue{
|
||||
"ARGV": lua.LNil,
|
||||
})
|
||||
whereeval.c.luapool.Put(whereeval.luaState)
|
||||
}
|
||||
|
||||
func (whereeval whereevalT) match(fieldsWithNames map[string]float64) bool {
|
||||
fieldsTbl := whereeval.luaState.CreateTable(0, len(fieldsWithNames))
|
||||
for field, val := range fieldsWithNames {
|
||||
fieldsTbl.RawSetString(field, lua.LNumber(val))
|
||||
}
|
||||
|
||||
luaSetRawGlobals(
|
||||
whereeval.luaState, map[string]lua.LValue{
|
||||
"FIELDS": fieldsTbl,
|
||||
})
|
||||
defer luaSetRawGlobals(
|
||||
whereeval.luaState, map[string]lua.LValue{
|
||||
"FIELDS": lua.LNil,
|
||||
})
|
||||
|
||||
whereeval.luaState.Push(whereeval.fn)
|
||||
if err := whereeval.luaState.PCall(0, 1, nil); err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
ret := whereeval.luaState.Get(-1)
|
||||
whereeval.luaState.Pop(1)
|
||||
|
||||
// Make bool out of returned lua value
|
||||
switch ret.Type() {
|
||||
case lua.LTNil:
|
||||
return false
|
||||
case lua.LTBool:
|
||||
return ret == lua.LTrue
|
||||
case lua.LTNumber:
|
||||
return float64(ret.(lua.LNumber)) != 0
|
||||
case lua.LTString:
|
||||
return ret.String() != ""
|
||||
case lua.LTTable:
|
||||
tbl := ret.(*lua.LTable)
|
||||
if tbl.Len() != 0 {
|
||||
return true
|
||||
}
|
||||
var match bool
|
||||
tbl.ForEach(func(lk lua.LValue, lv lua.LValue) {match = true})
|
||||
return match
|
||||
}
|
||||
panic(fmt.Sprintf("Script returned value of type %s", ret.Type()))
|
||||
}
|
||||
|
||||
type searchScanBaseTokens struct {
|
||||
key string
|
||||
cursor uint64
|
||||
|
@ -179,6 +238,7 @@ type searchScanBaseTokens struct {
|
|||
glob string
|
||||
wheres []whereT
|
||||
whereins []whereinT
|
||||
whereevals []whereevalT
|
||||
nofields bool
|
||||
ulimit bool
|
||||
limit uint64
|
||||
|
@ -187,7 +247,7 @@ type searchScanBaseTokens struct {
|
|||
desc bool
|
||||
}
|
||||
|
||||
func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, t searchScanBaseTokens, err error) {
|
||||
func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, t searchScanBaseTokens, err error) {
|
||||
var ok bool
|
||||
if vs, t.key, ok = tokenval(vs); !ok || t.key == "" {
|
||||
err = errInvalidNumberOfArguments
|
||||
|
@ -288,6 +348,76 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value,
|
|||
}
|
||||
t.whereins = append(t.whereins, whereinT{field, valMap})
|
||||
continue
|
||||
} else if (wtok[0] == 'W' || wtok[0] == 'w') && strings.Contains(strings.ToLower(wtok), "whereeval") {
|
||||
scriptIsSha := strings.ToLower(wtok) == "whereevalsha"
|
||||
vs = nvs
|
||||
var script, nargsStr, arg string
|
||||
if vs, script, ok = tokenval(vs); !ok || script == "" {
|
||||
err = errInvalidNumberOfArguments
|
||||
return
|
||||
}
|
||||
if vs, nargsStr, ok = tokenval(vs); !ok || nargsStr == "" {
|
||||
err = errInvalidNumberOfArguments
|
||||
return
|
||||
}
|
||||
|
||||
var i, nargs uint64
|
||||
if nargs, err = strconv.ParseUint(nargsStr, 10, 64); err != nil {
|
||||
err = errInvalidArgument(nargsStr)
|
||||
return
|
||||
}
|
||||
|
||||
var luaState *lua.LState
|
||||
luaState, err = c.luapool.Get()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
argsTbl := luaState.CreateTable(len(vs), 0)
|
||||
for i = 0; i < nargs; i++ {
|
||||
if vs, arg, ok = tokenval(vs); !ok || arg == "" {
|
||||
err = errInvalidNumberOfArguments
|
||||
return
|
||||
}
|
||||
argsTbl.Append(lua.LString(arg))
|
||||
}
|
||||
|
||||
var shaSum string
|
||||
if scriptIsSha {
|
||||
shaSum = script
|
||||
} else {
|
||||
shaSum = Sha1Sum(script)
|
||||
}
|
||||
|
||||
luaSetRawGlobals(
|
||||
luaState, map[string]lua.LValue{
|
||||
"ARGV": argsTbl,
|
||||
})
|
||||
|
||||
compiled, ok := c.luascripts.Get(shaSum)
|
||||
var fn *lua.LFunction
|
||||
if ok {
|
||||
fn = &lua.LFunction{
|
||||
IsG: false,
|
||||
Env: luaState.Env,
|
||||
|
||||
Proto: compiled,
|
||||
GFunction: nil,
|
||||
Upvalues: make([]*lua.Upvalue, 0),
|
||||
}
|
||||
} else if scriptIsSha {
|
||||
err = errShaNotFound
|
||||
return
|
||||
} else {
|
||||
fn, err = luaState.Load(strings.NewReader(script), "f_"+shaSum)
|
||||
if err != nil {
|
||||
err = makeSafeErr(err)
|
||||
return
|
||||
}
|
||||
c.luascripts.Put(shaSum, fn.Proto)
|
||||
}
|
||||
t.whereevals = append(t.whereevals, whereevalT{c,luaState, fn})
|
||||
continue
|
||||
} else if (wtok[0] == 'N' || wtok[0] == 'n') && strings.ToLower(wtok) == "nofields" {
|
||||
vs = nvs
|
||||
if t.nofields {
|
||||
|
|
Loading…
Reference in New Issue