tile38/internal/server/token.go

846 lines
19 KiB
Go

package server
import (
"errors"
"fmt"
"math"
"strconv"
"strings"
"github.com/tidwall/tile38/internal/field"
lua "github.com/yuin/gopher-lua"
)
// defaultSearchOutput is the output format assigned by
// parseSearchScanBaseTokens before it looks for an explicit output selector.
const defaultSearchOutput = outputObjects
// Package-level sentinel errors. Each one carries a fixed message.
var (
	errInvalidNumberOfArguments = errors.New("invalid number of arguments")
	errKeyNotFound              = errors.New("key not found")
	errIDNotFound               = errors.New("id not found")
	errIDAlreadyExists          = errors.New("id already exists")
	errPathNotFound             = errors.New("path not found")
	errKeyHasHooksSet           = errors.New("key has hooks set")
	errNotRectangle             = errors.New("not a rectangle")
)
// errInvalidArgument builds the error reported when arg is not an acceptable
// value; arg is embedded in the message between single quotes.
func errInvalidArgument(arg string) error {
	const format = "invalid argument '%s'"
	return fmt.Errorf(format, arg)
}
// errDuplicateArgument builds the error reported when arg is supplied more
// than once; arg is embedded in the message between single quotes.
func errDuplicateArgument(arg string) error {
	const format = "duplicate argument '%s'"
	return fmt.Errorf(format, arg)
}
// token splits line at its first space. It returns the text before the space
// as token and the text after it as newLine. When line contains no space the
// whole line is the token and newLine is empty.
func token(line string) (newLine, token string) {
	for pos, ch := range []byte(line) {
		if ch == ' ' {
			return line[pos+1:], line[:pos]
		}
	}
	return "", line
}
// tokenval pops the first element off vs. It returns the remaining elements,
// the popped element, and ok=true. For an empty vs it returns a nil slice, an
// empty token, and ok=false.
func tokenval(vs []string) (nvs []string, token string, ok bool) {
	if len(vs) == 0 {
		return nil, "", false
	}
	return vs[1:], vs[0], true
}
// lc reports whether s1, with its ASCII upper-case letters folded to lower
// case, equals s2 byte for byte. Only s1 is folded; s2 is compared as given.
func lc(s1, s2 string) bool {
	if len(s1) != len(s2) {
		return false
	}
	for i := 0; i < len(s1); i++ {
		folded := s1[i]
		if 'A' <= folded && folded <= 'Z' {
			folded += 'a' - 'A'
		}
		if folded != s2[i] {
			return false
		}
	}
	return true
}
// whereT is one parsed WHERE clause. It takes one of three shapes:
//   - expression: expr is true and name holds the expression text;
//   - operator:   min holds one of "<", "<=", ">", ">=", "==", "!=" and max
//     holds the operand to compare against;
//   - range:      min and max bound the field value, with minx/maxx marking
//     a bound as exclusive.
type whereT struct {
// expr is true when name is an expression rather than a field name.
expr bool
// name is the field name, or the expression text when expr is true.
name string
// minx marks min as exclusive (the token was written with a leading "(").
minx bool
// min is the lower bound, or the comparison operator in operator form.
min field.Value
// maxx marks max as exclusive (the token was written with a leading "(").
maxx bool
// max is the upper bound, or the operand in operator form.
max field.Value
}
// Comparison helpers over field.Value. The ordering helpers are all built on
// Value.Less; equality uses Value.Equals.

// mLT reports a < b.
func mLT(a, b field.Value) bool {
	return a.Less(b)
}

// mLTE reports a <= b, expressed as !(b < a).
func mLTE(a, b field.Value) bool {
	return !b.Less(a)
}

// mGT reports a > b, expressed as b < a.
func mGT(a, b field.Value) bool {
	return b.Less(a)
}

// mGTE reports a >= b, expressed as !(a < b).
func mGTE(a, b field.Value) bool {
	return !a.Less(b)
}

// mEQ reports a == b.
func mEQ(a, b field.Value) bool {
	return a.Equals(b)
}
// matchField reports whether value satisfies this WHERE clause.
//
// When where.min holds a comparison operator ("<", "<=", ">", ">=", "==",
// "!="), value is compared against where.max with that operator. Otherwise
// value must lie between where.min and where.max, where minx/maxx make the
// respective bound exclusive.
func (where whereT) matchField(value field.Value) bool {
	// Operator form: min is the operator, max is the operand.
	switch op := where.min.Data(); op {
	case "==":
		return mEQ(value, where.max)
	case "!=":
		return !mEQ(value, where.max)
	case "<":
		return mLT(value, where.max)
	case "<=":
		return mLTE(value, where.max)
	case ">":
		return mGT(value, where.max)
	case ">=":
		return mGTE(value, where.max)
	}

	// Range form: reject values below the lower bound first.
	var belowMin bool
	if where.minx {
		belowMin = mLTE(value, where.min) // exclusive: value <= min fails
	} else {
		belowMin = mLT(value, where.min) // inclusive: value < min fails
	}
	if belowMin {
		return false
	}

	// Then the upper bound decides.
	if where.maxx {
		return !mGTE(value, where.max) // exclusive: value >= max fails
	}
	return !mGT(value, where.max) // inclusive: value > max fails
}
// whereinT is one parsed WHEREIN clause: a field name and the set of values
// the field is allowed to equal.
type whereinT struct {
// name is the field name (the parser stores it lower-cased).
name string
// valArr holds the accepted values in the order they were given.
valArr []field.Value
}
// match reports whether value equals any of the values listed in the WHEREIN
// clause. An empty list never matches.
func (wherein whereinT) match(value field.Value) bool {
	candidates := wherein.valArr
	for i := range candidates {
		if mEQ(candidates[i], value) {
			return true
		}
	}
	return false
}
// whereevalT is one parsed WHEREEVAL / WHEREEVALSHA clause: a Lua function
// together with the Lua state it runs in. Close must be called to hand the
// state back to the server's pool.
type whereevalT struct {
// c is the server whose luapool receives luaState on Close.
c *Server
// luaState is the Lua state taken from the pool for this clause.
luaState *lua.LState
// fn is the function that match pushes and calls.
fn *lua.LFunction
}
// Close clears the ARGV global in the clause's Lua state and returns the
// state to the server's Lua pool.
func (whereeval whereevalT) Close() {
	state := whereeval.luaState
	cleared := map[string]lua.LValue{"ARGV": lua.LNil}
	luaSetRawGlobals(state, cleared)
	whereeval.c.luapool.Put(state)
}
// luaSetField stores val in tbl under name, converted to a Lua value:
// null becomes nil, booleans become Lua booleans, numbers become Lua numbers,
// and every other kind is stored as a Lua string of the value's raw data.
func luaSetField(tbl *lua.LTable, name string, val field.Value) {
	var converted lua.LValue
	switch kind := val.Kind(); kind {
	case field.Number:
		converted = lua.LNumber(val.Num())
	case field.True:
		converted = lua.LTrue
	case field.False:
		converted = lua.LFalse
	case field.Null:
		converted = lua.LNil
	default:
		converted = lua.LString(val.Data())
	}
	tbl.RawSetString(name, converted)
}
// match runs the clause's Lua function with the object's fields exposed as
// the FIELDS global table and converts the single return value to a bool:
//
//	nil     -> false
//	boolean -> its value
//	number  -> true unless zero
//	string  -> true unless empty
//	table   -> true unless it has no entries at all
//
// Any other return type, or a failure while running the function, is
// returned as an error. FIELDS is reset to nil before match returns.
func (whereeval whereevalT) match(fieldsWithNames map[string]field.Value) (bool, error) {
	state := whereeval.luaState

	// Expose the fields to the script for the duration of this call.
	fieldsTbl := state.CreateTable(0, len(fieldsWithNames))
	for name, val := range fieldsWithNames {
		luaSetField(fieldsTbl, name, val)
	}
	luaSetRawGlobals(state, map[string]lua.LValue{"FIELDS": fieldsTbl})
	defer luaSetRawGlobals(state, map[string]lua.LValue{"FIELDS": lua.LNil})

	state.Push(whereeval.fn)
	if err := state.PCall(0, 1, nil); err != nil {
		return false, err
	}
	ret := state.Get(-1)
	state.Pop(1)

	// Make bool out of returned lua value
	switch ret.Type() {
	case lua.LTNil:
		return false, nil
	case lua.LTBool:
		return ret == lua.LTrue, nil
	case lua.LTNumber:
		return float64(ret.(lua.LNumber)) != 0, nil
	case lua.LTString:
		return ret.String() != "", nil
	case lua.LTTable:
		tbl := ret.(*lua.LTable)
		nonEmpty := tbl.Len() != 0
		if !nonEmpty {
			// Len can be zero while the table still holds entries, so fall
			// back to checking whether iteration visits anything.
			tbl.ForEach(func(_ lua.LValue, _ lua.LValue) { nonEmpty = true })
		}
		return nonEmpty, nil
	}
	return false, fmt.Errorf("script returned value of type %s", ret.Type())
}
// searchScanBaseTokens holds the options parsed by parseSearchScanBaseTokens:
// the key plus every optional modifier and the selected output format.
type searchScanBaseTokens struct {
// key is the first argument of the command.
key string
// cursor is the parsed CURSOR value (zero when not given).
cursor uint64
// output is the selected output format; defaults to defaultSearchOutput.
output outputT
// precision is the HASHES precision, in the range 1..12 when set.
precision uint64
// fence is true when FENCE was given or preset by the caller.
fence bool
// distance is true when DISTANCE was given.
distance bool
// nodwell is true when NODWELL was given.
nodwell bool
// detect is the set of DETECT kinds (inside, outside, enter, exit, cross);
// nil when DETECT was not given.
detect map[string]bool
// accept is the set of lower-cased names listed after COMMANDS; nil when
// COMMANDS was not given.
accept map[string]bool
// globs are the MATCH patterns in the order given.
globs []string
// wheres, whereins and whereevals are the parsed WHERE, WHEREIN and
// WHEREEVAL/WHEREEVALSHA clauses in the order given.
wheres []whereT
whereins []whereinT
whereevals []whereevalT
// nofields is true when NOFIELDS was given.
nofields bool
// ulimit is true when LIMIT was given; limit is its non-zero value.
// When SPARSE is given, limit is set to math.MaxUint64 instead.
ulimit bool
limit uint64
// usparse is true when SPARSE was given; sparse is its value (1..8).
usparse bool
sparse uint8
// desc is true when DESC was given.
desc bool
// clip is true when CLIP was given.
clip bool
// buffer is the BUFFER value (finite and non-negative); hasbuffer is true
// when BUFFER was given.
buffer float64
hasbuffer bool
}
// parseSearchScanBaseTokens parses the key and the option tokens that follow
// it: any mix of BUFFER, CURSOR, WHERE, WHEREIN, WHEREEVAL, WHEREEVALSHA,
// NOFIELDS, LIMIT, SPARSE, FENCE, COMMANDS, DISTANCE, DETECT, NODWELL, DESC,
// ASC, MATCH and CLIP, followed by an optional output selector (COUNT,
// OBJECTS, POINTS, HASHES <precision>, BOUNDS or IDS).
//
// cmd is the command name; it is compared against "scan" and "search" to
// decide which options conflict. t carries values preset by the caller (for
// example t.fence), and vs is the argument list starting at the key.
//
// On success it returns the unconsumed arguments and the populated tokens.
// On failure err is non-nil and tout is the zero value. Because the caller
// then has no handle on the parsed clauses, every Lua state this call took
// from the pool is returned to the pool before the function returns.
func (s *Server) parseSearchScanBaseTokens(
	cmd string, t searchScanBaseTokens, vs []string,
) (
	vsout []string, tout searchScanBaseTokens, err error,
) {
	var ok bool
	if vs, t.key, ok = tokenval(vs); !ok || t.key == "" {
		err = errInvalidNumberOfArguments
		return
	}

	// Release Lua states acquired by this call when parsing fails. Clauses
	// that were already present in t on entry belong to the caller and are
	// left alone.
	ownedFrom := len(t.whereevals)
	var pendingLua *lua.LState // acquired but not yet stored in t.whereevals
	defer func() {
		if err == nil {
			return
		}
		if pendingLua != nil {
			whereevalT{c: s, luaState: pendingLua}.Close()
		}
		for _, we := range t.whereevals[ownedFrom:] {
			we.Close()
		}
	}()

	// FENCE may be preset by the caller; that must not count as a duplicate.
	fromFence := t.fence

	var slimit string
	var ssparse string
	var scursor string
	var asc bool
	for {
		nvs, wtok, ok := tokenval(vs)
		if ok && len(wtok) > 0 {
			switch strings.ToLower(wtok) {
			case "buffer":
				vs = nvs
				var sbuf string
				if vs, sbuf, ok = tokenval(vs); !ok || sbuf == "" {
					err = errInvalidNumberOfArguments
					return
				}
				var buf float64
				buf, err = strconv.ParseFloat(sbuf, 64)
				if err != nil || buf < 0 || math.IsInf(buf, 0) || math.IsNaN(buf) {
					err = errInvalidArgument(sbuf)
					return
				}
				t.buffer = buf
				t.hasbuffer = true
				continue
			case "cursor":
				vs = nvs
				if scursor != "" {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				if vs, scursor, ok = tokenval(vs); !ok || scursor == "" {
					err = errInvalidNumberOfArguments
					return
				}
				continue
			case "where":
				vs = nvs
				if detectExprToken(vs) {
					// using expressions
					// WHERE expr
					var expr string
					if vs, expr, ok = tokenval(vs); !ok {
						err = errInvalidNumberOfArguments
						return
					}
					t.wheres = append(t.wheres, whereT{name: expr, expr: true})
					continue
				} else {
					// using field filter
					// WHERE min max
					var name, smin, smax string
					if vs, name, ok = tokenval(vs); !ok {
						err = errInvalidNumberOfArguments
						return
					}
					if vs, smin, ok = tokenval(vs); !ok {
						err = errInvalidNumberOfArguments
						return
					}
					if vs, smax, ok = tokenval(vs); !ok {
						err = errInvalidNumberOfArguments
						return
					}
					var minx, maxx bool
					smin = strings.ToLower(smin)
					smax = strings.ToLower(smax)
					if smax == "+inf" || smax == "inf" {
						smax = "inf"
					}
					switch smin {
					case "<", "<=", ">", ">=", "==", "!=":
						// operator form: smin is the operator, smax the operand
					default:
						// range form: a leading "(" marks an exclusive bound
						if strings.HasPrefix(smin, "(") {
							minx = true
							smin = smin[1:]
						}
						if strings.HasPrefix(smax, "(") {
							maxx = true
							smax = smax[1:]
						}
					}
					t.wheres = append(t.wheres, whereT{
						name: name,
						minx: minx,
						min:  field.ValueOf(smin),
						maxx: maxx,
						max:  field.ValueOf(smax),
					})
					continue
				}
			case "wherein":
				vs = nvs
				var name, nvalsStr, valStr string
				if vs, name, ok = tokenval(vs); !ok {
					err = errInvalidNumberOfArguments
					return
				}
				if vs, nvalsStr, ok = tokenval(vs); !ok {
					err = errInvalidNumberOfArguments
					return
				}
				var i, nvals uint64
				if nvals, err = strconv.ParseUint(nvalsStr, 10, 64); err != nil {
					err = errInvalidArgument(nvalsStr)
					return
				}
				// The count comes from the client. Reject counts larger than
				// the remaining tokens before allocating, so an oversized
				// count cannot trigger a huge allocation or a make() panic.
				// The error is the same one the loop below would produce.
				if nvals > uint64(len(vs)) {
					err = errInvalidNumberOfArguments
					return
				}
				valArr := make([]field.Value, nvals)
				for i = 0; i < nvals; i++ {
					if vs, valStr, ok = tokenval(vs); !ok {
						err = errInvalidNumberOfArguments
						return
					}
					valArr[i] = field.ValueOf(valStr)
				}
				t.whereins = append(t.whereins, whereinT{
					name:   strings.ToLower(name),
					valArr: valArr,
				})
				continue
			case "whereevalsha":
				fallthrough
			case "whereeval":
				scriptIsSha := strings.ToLower(wtok) == "whereevalsha"
				vs = nvs
				var script, nargsStr, arg string
				if vs, script, ok = tokenval(vs); !ok || script == "" {
					err = errInvalidNumberOfArguments
					return
				}
				if vs, nargsStr, ok = tokenval(vs); !ok || nargsStr == "" {
					err = errInvalidNumberOfArguments
					return
				}
				var i, nargs uint64
				if nargs, err = strconv.ParseUint(nargsStr, 10, 64); err != nil {
					err = errInvalidArgument(nargsStr)
					return
				}
				var luaState *lua.LState
				luaState, err = s.luapool.Get()
				if err != nil {
					return
				}
				// From here until the clause is stored in t.whereevals, the
				// deferred cleanup releases this state on any error return.
				pendingLua = luaState
				argsTbl := luaState.CreateTable(len(vs), 0)
				for i = 0; i < nargs; i++ {
					if vs, arg, ok = tokenval(vs); !ok || arg == "" {
						err = errInvalidNumberOfArguments
						return
					}
					argsTbl.Append(lua.LString(arg))
				}
				var shaSum string
				if scriptIsSha {
					shaSum = script
				} else {
					shaSum = Sha1Sum(script)
				}
				luaSetRawGlobals(
					luaState, map[string]lua.LValue{
						"ARGV": argsTbl,
					})
				compiled, ok := s.luascripts.Get(shaSum)
				var fn *lua.LFunction
				if ok {
					fn = &lua.LFunction{
						IsG:       false,
						Env:       luaState.Env,
						Proto:     compiled,
						GFunction: nil,
						Upvalues:  make([]*lua.Upvalue, 0),
					}
				} else if scriptIsSha {
					err = errShaNotFound
					return
				} else {
					fn, err = luaState.Load(strings.NewReader(script), "f_"+shaSum)
					if err != nil {
						err = makeSafeErr(err)
						return
					}
					s.luascripts.PutLRU(shaSum, fn.Proto)
				}
				t.whereevals = append(t.whereevals, whereevalT{
					c: s, luaState: luaState, fn: fn,
				})
				pendingLua = nil // now tracked through t.whereevals
				continue
			case "nofields":
				vs = nvs
				if t.nofields {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				t.nofields = true
				continue
			case "limit":
				vs = nvs
				if slimit != "" {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				if vs, slimit, ok = tokenval(vs); !ok || slimit == "" {
					err = errInvalidNumberOfArguments
					return
				}
				continue
			case "sparse":
				vs = nvs
				if ssparse != "" {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				if vs, ssparse, ok = tokenval(vs); !ok || ssparse == "" {
					err = errInvalidNumberOfArguments
					return
				}
				continue
			case "fence":
				vs = nvs
				if t.fence && !fromFence {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				t.fence = true
				continue
			case "commands":
				vs = nvs
				if t.accept != nil {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				t.accept = make(map[string]bool)
				var peek string
				if vs, peek, ok = tokenval(vs); !ok || peek == "" {
					err = errInvalidNumberOfArguments
					return
				}
				for _, s := range strings.Split(peek, ",") {
					part := strings.TrimSpace(strings.ToLower(s))
					if t.accept[part] {
						err = errDuplicateArgument(s)
						return
					}
					t.accept[part] = true
				}
				if len(t.accept) == 0 {
					t.accept = nil
				}
				continue
			case "distance":
				vs = nvs
				if t.distance {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				t.distance = true
				continue
			case "detect":
				vs = nvs
				if t.detect != nil {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				t.detect = make(map[string]bool)
				var peek string
				if vs, peek, ok = tokenval(vs); !ok || peek == "" {
					err = errInvalidNumberOfArguments
					return
				}
				for _, s := range strings.Split(peek, ",") {
					part := strings.TrimSpace(strings.ToLower(s))
					switch part {
					default:
						err = errInvalidArgument(peek)
						return
					case "inside", "outside", "enter", "exit", "cross":
					}
					if t.detect[part] {
						err = errDuplicateArgument(s)
						return
					}
					t.detect[part] = true
				}
				if len(t.detect) == 0 {
					t.detect = map[string]bool{
						"inside":  true,
						"outside": true,
						"enter":   true,
						"exit":    true,
						"cross":   true,
					}
				}
				continue
			case "nodwell":
				vs = nvs
				// A repeated NODWELL is the duplicate here; DESC/ASC are
				// unrelated to this flag.
				if t.nodwell {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				t.nodwell = true
				continue
			case "desc":
				vs = nvs
				if t.desc || asc {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				t.desc = true
				continue
			case "asc":
				vs = nvs
				if t.desc || asc {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				asc = true
				continue
			case "match":
				vs = nvs
				var glob string
				if vs, glob, ok = tokenval(vs); !ok || glob == "" {
					err = errInvalidNumberOfArguments
					return
				}
				t.globs = append(t.globs, glob)
				continue
			case "clip":
				vs = nvs
				if t.clip {
					err = errDuplicateArgument(strings.ToUpper(wtok))
					return
				}
				t.clip = true
				continue
			}
		}
		// Not a recognized option: leave the token for the output selector
		// or for the caller.
		break
	}

	// check to make sure that there aren't any conflicts
	if cmd == "scan" || cmd == "search" {
		if ssparse != "" {
			err = errors.New("SPARSE is not allowed for " + strings.ToUpper(cmd))
			return
		}
		if t.fence {
			err = errors.New("FENCE is not allowed for " + strings.ToUpper(cmd))
			return
		}
	} else {
		if t.desc {
			err = errors.New("DESC is not allowed for " + strings.ToUpper(cmd))
			return
		}
		if asc {
			err = errors.New("ASC is not allowed for " + strings.ToUpper(cmd))
			return
		}
	}
	if ssparse != "" && slimit != "" {
		err = errors.New("LIMIT is not allowed when SPARSE is specified")
		return
	}
	if scursor != "" && ssparse != "" {
		err = errors.New("CURSOR is not allowed when SPARSE is specified")
		return
	}
	if scursor != "" && t.fence {
		err = errors.New("CURSOR is not allowed when FENCE is specified")
		return
	}
	if t.detect != nil && !t.fence {
		err = errors.New("DETECT is not allowed when FENCE is not specified")
		return
	}

	// Optional output selector. An unrecognized token is an error for SCAN;
	// for every other command it is left unconsumed for the caller.
	t.output = defaultSearchOutput
	var nvs []string
	var sprecision string
	var which string
	if nvs, which, ok = tokenval(vs); ok && which != "" {
		updline := true
		switch strings.ToLower(which) {
		default:
			if cmd == "scan" {
				err = errInvalidArgument(which)
				return
			}
			updline = false
		case "count":
			t.output = outputCount
		case "objects":
			t.output = outputObjects
		case "points":
			t.output = outputPoints
		case "hashes":
			t.output = outputHashes
			if nvs, sprecision, ok = tokenval(nvs); !ok || sprecision == "" {
				err = errInvalidNumberOfArguments
				return
			}
		case "bounds":
			t.output = outputBounds
		case "ids":
			t.output = outputIDs
		}
		if updline {
			vs = nvs
		}
	}

	// Convert the collected string values now that conflicts are ruled out.
	if scursor != "" {
		if t.cursor, err = strconv.ParseUint(scursor, 10, 64); err != nil {
			err = errInvalidArgument(scursor)
			return
		}
	}
	if sprecision != "" {
		t.precision, err = strconv.ParseUint(sprecision, 10, 64)
		if err != nil || t.precision == 0 || t.precision > 12 {
			err = errInvalidArgument(sprecision)
			return
		}
	}
	if slimit != "" {
		t.ulimit = true
		if t.limit, err = strconv.ParseUint(slimit, 10, 64); err != nil || t.limit == 0 {
			err = errInvalidArgument(slimit)
			return
		}
	}
	if ssparse != "" {
		t.usparse = true
		var sparse uint64
		if sparse, err = strconv.ParseUint(ssparse, 10, 8); err != nil || sparse == 0 || sparse > 8 {
			err = errInvalidArgument(ssparse)
			return
		}
		t.sparse = uint8(sparse)
		t.limit = math.MaxUint64
	}
	vsout = vs
	tout = t
	return
}
// detectExprToken reports whether the tokens that follow WHERE form the
// single-expression variant ("WHERE expr") rather than the field-filter
// variant ("WHERE name min max").
//
// The decision is made by peeking at the second token:
//   - no tokens at all: false;
//   - no second token, or an empty one: true;
//   - second token starts with an ASCII letter: true, since it looks like the
//     next keyword rather than a bound, unless it is "inf" in any letter
//     case, which is a valid bound;
//   - anything else (digits, sign, "(", operators): false.
func detectExprToken(vs []string) bool {
	if len(vs) == 0 {
		return false
	}
	// An empty second token is treated as "expression" whatever follows it.
	// Checking this before indexing also guards the byte access below, which
	// would otherwise panic for input such as ["name", "", "10"].
	if len(vs) == 1 || len(vs[1]) == 0 {
		return true
	}
	next := vs[1]
	first := next[0]
	startsWithLetter := (first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z')
	if !startsWithLetter {
		return false
	}
	return strings.ToLower(next) != "inf"
}
// parentStack is a LIFO stack of area expressions, used while parsing to
// track the currently open parenthesised groups.
type parentStack []*areaExpression

// isEmpty reports whether the stack holds no expressions.
func (ps *parentStack) isEmpty() bool {
	stack := *ps
	return len(stack) == 0
}

// push places e on top of the stack.
func (ps *parentStack) push(e *areaExpression) {
	stack := *ps
	*ps = append(stack, e)
}

// pop removes and returns the top expression. When the stack is already
// empty it returns nil and empty=true and leaves the stack untouched.
func (ps *parentStack) pop() (e *areaExpression, empty bool) {
	stack := *ps
	top := len(stack) - 1
	if top < 0 {
		return nil, true
	}
	e = stack[top]
	*ps = stack[:top]
	return e, false
}
// parseAreaExpression consumes a run of area tokens from vsin and builds an
// areaExpression tree from them. Recognized tokens are the parenthesis, NOT,
// AND and OR tokens plus the area keywords handled by parseArea (point,
// circle, object, bounds, hash, quadkey, tile, get, sector). Parsing stops at
// the first token that is none of these, or at an empty token or the end of
// the input.
//
// doClip is passed through unchanged to parseArea.
//
// It returns the unconsumed tokens in vsout and the expression in ae. err is
// non-nil when an operator or closing parenthesis appears where an operand is
// required, when parentheses are unbalanced, when parseArea fails, or when no
// usable expression was produced.
func (s *Server) parseAreaExpression(vsin []string, doClip bool) (vsout []string, ae *areaExpression, err error) {
// ps tracks the parenthesised groups that are currently open.
ps := &parentStack{}
vsout = vsin[:]
// negate is a pending NOT waiting for the next object or group.
// needObj is true right after NOT/AND/OR, when an operand must come next.
var negate, needObj bool
loop:
for {
nvs, wtok, ok := tokenval(vsout)
if !ok || len(wtok) == 0 {
break
}
// NOTE(review): the token constants are compared against the lower-cased
// token, so they are presumably lower-case themselves - confirm.
switch strings.ToLower(wtok) {
case tokenLParen:
// Open a new group; it absorbs any pending NOT, becomes a child of the
// current expression (if there is one) and then the current expression.
newExpr := &areaExpression{negate: negate, op: NOOP}
negate = false
needObj = false
if ae != nil {
ae.children = append(ae.children, newExpr)
}
ae = newExpr
ps.push(ae)
vsout = nvs
case tokenRParen:
// A group may not end right after an operator, and must have been opened.
if needObj {
err = errInvalidArgument(tokenRParen)
return
}
parent, empty := ps.pop()
if empty {
err = errInvalidArgument(tokenRParen)
return
}
ae = parent
vsout = nvs
case tokenNOT:
// Toggle rather than set, so a doubled NOT cancels out.
negate = !negate
needObj = true
vsout = nvs
case tokenAND:
if needObj {
err = errInvalidArgument(tokenAND)
return
}
needObj = true
if ae == nil {
// AND cannot start an expression.
err = errInvalidArgument(tokenAND)
return
} else if ae.obj == nil {
// Current expression is an operator/group node.
switch ae.op {
case OR:
// Inside an OR node: replace its last child with a new AND node that
// wraps that child.
numChildren := len(ae.children)
if numChildren < 2 {
err = errInvalidNumberOfArguments
return
}
ae.children = append(
ae.children[:numChildren-1],
&areaExpression{
op: AND,
children: []*areaExpression{ae.children[numChildren-1]}})
case NOOP:
// Group with no operator yet: it becomes an AND node.
ae.op = AND
}
} else {
// Current expression is a single object: wrap it in a new AND node.
ae = &areaExpression{op: AND, children: []*areaExpression{ae}}
}
vsout = nvs
case tokenOR:
if needObj {
err = errInvalidArgument(tokenOR)
return
}
needObj = true
if ae == nil {
// OR cannot start an expression.
err = errInvalidArgument(tokenOR)
return
} else if ae.obj == nil {
// Current expression is an operator/group node.
switch ae.op {
case AND:
// Inside an AND node: the whole AND node becomes the first child of a
// new OR node.
if len(ae.children) < 2 {
err = errInvalidNumberOfArguments
return
}
ae = &areaExpression{op: OR, children: []*areaExpression{ae}}
case NOOP:
// Group with no operator yet: it becomes an OR node.
ae.op = OR
}
} else {
// Current expression is a single object: wrap it in a new OR node.
ae = &areaExpression{op: OR, children: []*areaExpression{ae}}
}
vsout = nvs
case "point", "circle", "object", "bounds", "hash", "quadkey", "tile", "get", "sector":
// An area object: parseArea receives the tokens starting at the keyword
// itself (vsout, not nvs) and returns what is left after the object.
parsedVs, parsedObj, areaErr := s.parseArea(vsout, doClip)
if areaErr != nil {
err = areaErr
return
}
newExpr := &areaExpression{negate: negate, obj: parsedObj, op: NOOP}
negate = false
needObj = false
if ae == nil {
ae = newExpr
} else {
ae.children = append(ae.children, newExpr)
}
vsout = parsedVs
default:
// Anything else ends the area expression; the token stays in vsout.
break loop
}
}
// Reject unclosed groups, a dangling operator, a missing expression, and an
// operator/group node that ended up with no children.
if !ps.isEmpty() || needObj || ae == nil || (ae.obj == nil && len(ae.children) == 0) {
err = errInvalidNumberOfArguments
}
return
}