Add wherein command

This commit is contained in:
Alex Roitman 2017-08-23 13:13:12 -07:00
parent 9cdbea19c1
commit c8ed7caa2e
6 changed files with 74 additions and 7 deletions

View File

@ -133,7 +133,7 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res string, d commandDetai
hook.cond = sync.NewCond(&hook.mu) hook.cond = sync.NewCond(&hook.mu)
var wr bytes.Buffer var wr bytes.Buffer
hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
if err != nil { if err != nil {
return "", d, err return "", d, err
} }

View File

@ -87,7 +87,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWrit
lb.key = s.key lb.key = s.key
lb.fence = &s lb.fence = &s
c.mu.RLock() c.mu.RLock()
sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
c.mu.RUnlock() c.mu.RUnlock()
} }
// everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS // everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS

View File

@ -30,7 +30,7 @@ func (c *Controller) cmdScan(msg *server.Message) (res string, err error) {
if err != nil { if err != nil {
return "", err return "", err
} }
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -39,6 +39,7 @@ type scanWriter struct {
fvals []float64 fvals []float64
output outputT output outputT
wheres []whereT wheres []whereT
whereins []whereinT
numberItems uint64 numberItems uint64
nofields bool nofields bool
cursor uint64 cursor uint64
@ -66,7 +67,7 @@ type ScanWriterParams struct {
func (c *Controller) newScanWriter( func (c *Controller) newScanWriter(
wr *bytes.Buffer, msg *server.Message, key string, output outputT, wr *bytes.Buffer, msg *server.Message, key string, output outputT,
precision uint64, globPattern string, matchValues bool, precision uint64, globPattern string, matchValues bool,
cursor, limit uint64, wheres []whereT, nofields bool, cursor, limit uint64, wheres []whereT, whereins []whereinT, nofields bool,
) ( ) (
*scanWriter, error, *scanWriter, error,
) { ) {
@ -89,6 +90,7 @@ func (c *Controller) newScanWriter(
cursor: cursor, cursor: cursor,
limit: limit, limit: limit,
wheres: wheres, wheres: wheres,
whereins: whereins,
output: output, output: output,
nofields: nofields, nofields: nofields,
precision: precision, precision: precision,
@ -215,6 +217,18 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
return sw.fvals, false return sw.fvals, false
} }
} }
for _, wherein := range sw.whereins {
var value float64
idx, ok := sw.fmap[wherein.field]
if ok {
if len(fields) > idx {
value = fields[idx]
}
}
if !wherein.match(value) {
return sw.fvals, false
}
}
} else { } else {
for idx := range sw.farr { for idx := range sw.farr {
var value float64 var value float64
@ -242,6 +256,16 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64,
return sw.fvals, false return sw.fvals, false
} }
} }
for _, wherein := range sw.whereins {
var value float64
idx, ok := sw.fmap[wherein.field]
if ok {
value = sw.fvals[idx]
}
if !wherein.match(value) {
return sw.fvals, false
}
}
} }
return sw.fvals, true return sw.fvals, true
} }

View File

@ -296,7 +296,7 @@ func (c *Controller) cmdNearby(msg *server.Message) (res string, err error) {
return "", s return "", s
} }
minZ, maxZ := zMinMaxFromWheres(s.wheres) minZ, maxZ := zMinMaxFromWheres(s.wheres)
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -397,7 +397,7 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
if s.fence { if s.fence {
return "", s return "", s
} }
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -464,7 +464,7 @@ func (c *Controller) cmdSearch(msg *server.Message) (res string, err error) {
if err != nil { if err != nil {
return "", err return "", err
} }
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.cursor, s.limit, s.wheres, s.nofields) sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.cursor, s.limit, s.wheres, s.whereins, s.nofields)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -156,6 +156,16 @@ func zMinMaxFromWheres(wheres []whereT) (minZ, maxZ float64) {
return return
} }
type whereinT struct {
field string
val_map map[float64]struct{}
}
func (wherein whereinT) match(value float64) bool {
_, ok := wherein.val_map[value]
return ok
}
type searchScanBaseTokens struct { type searchScanBaseTokens struct {
key string key string
cursor uint64 cursor uint64
@ -168,6 +178,7 @@ type searchScanBaseTokens struct {
accept map[string]bool accept map[string]bool
glob string glob string
wheres []whereT wheres []whereT
whereins []whereinT
nofields bool nofields bool
ulimit bool ulimit bool
limit uint64 limit uint64
@ -245,6 +256,38 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value,
} }
t.wheres = append(t.wheres, whereT{field, minx, min, maxx, max}) t.wheres = append(t.wheres, whereT{field, minx, min, maxx, max})
continue continue
} else if (wtok[0] == 'W' || wtok[0] == 'w') && strings.ToLower(wtok) == "wherein" {
vs = nvs
var field, nvals_str, val_str string
if vs, field, ok = tokenval(vs); !ok || field == "" {
err = errInvalidNumberOfArguments
return
}
if vs, nvals_str, ok = tokenval(vs); !ok || nvals_str == "" {
err = errInvalidNumberOfArguments
return
}
var i, nvals uint64
if nvals, err = strconv.ParseUint(nvals_str, 10, 64); err != nil {
err = errInvalidArgument(nvals_str)
return
}
val_map := make(map[float64]struct{})
var val float64
var empty struct{}
for i = 0; i < nvals; i++ {
if vs, val_str, ok = tokenval(vs); !ok || val_str == "" {
err = errInvalidNumberOfArguments
return
}
if val, err = strconv.ParseFloat(val_str, 64); err != nil {
err = errInvalidArgument(val_str)
return
}
val_map[val] = empty
}
t.whereins = append(t.whereins, whereinT{field, val_map})
continue
} else if (wtok[0] == 'N' || wtok[0] == 'n') && strings.ToLower(wtok) == "nofields" { } else if (wtok[0] == 'N' || wtok[0] == 'n') && strings.ToLower(wtok) == "nofields" {
vs = nvs vs = nvs
if t.nofields { if t.nofields {