diff --git a/controller/hooks.go b/controller/hooks.go index 62c79704..ea0a7dfe 100644 --- a/controller/hooks.go +++ b/controller/hooks.go @@ -133,7 +133,7 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res string, d commandDetai hook.cond = sync.NewCond(&hook.mu) var wr bytes.Buffer - hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) + hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) if err != nil { return "", d, err } diff --git a/controller/live.go b/controller/live.go index 3bed9dfb..3c2c088d 100644 --- a/controller/live.go +++ b/controller/live.go @@ -87,7 +87,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWrit lb.key = s.key lb.fence = &s c.mu.RLock() - sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) + sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) c.mu.RUnlock() } // everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS diff --git a/controller/scan.go b/controller/scan.go index 376d2fac..3a7bc2f8 100644 --- a/controller/scan.go +++ b/controller/scan.go @@ -30,7 +30,7 @@ func (c *Controller) cmdScan(msg *server.Message) (res string, err error) { if err != nil { return "", err } - sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) + sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) if err != nil { return "", err } diff --git a/controller/scanner.go b/controller/scanner.go index d6e01f46..c569a206 100644 --- a/controller/scanner.go +++ b/controller/scanner.go @@ -39,6 +39,7 @@ type scanWriter struct { fvals []float64 output outputT wheres []whereT + whereins []whereinT numberItems uint64 nofields bool cursor uint64 @@ -66,7 +67,7 @@ type ScanWriterParams struct { func (c *Controller) newScanWriter( wr *bytes.Buffer, msg *server.Message, key string, output outputT, precision uint64, globPattern string, matchValues bool, - cursor, limit uint64, wheres []whereT, nofields bool, + cursor, limit uint64, wheres []whereT, whereins []whereinT, nofields bool, ) ( *scanWriter, error, ) { @@ -89,6 +90,7 @@ func (c *Controller) newScanWriter( cursor: cursor, limit: limit, wheres: wheres, + whereins: whereins, output: output, nofields: nofields, precision: precision, @@ -215,6 +217,18 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, return sw.fvals, false } } + for _, wherein := range sw.whereins { + var value float64 + idx, ok := sw.fmap[wherein.field] + if ok { + if len(fields) > idx { + value = fields[idx] + } + } + if !wherein.match(value) { + return sw.fvals, false + } + } } else { for idx := range sw.farr { var value float64 @@ -242,6 +256,16 @@ func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) ([]float64, return sw.fvals, false } } + for _, wherein := range sw.whereins { + var value float64 + idx, ok := sw.fmap[wherein.field] + if ok { + value = sw.fvals[idx] + } + if !wherein.match(value) { + return sw.fvals, false + } + } } return sw.fvals, true } diff --git a/controller/search.go b/controller/search.go index 6c5a61af..d3c32398 100644 --- a/controller/search.go +++ b/controller/search.go @@ -296,7 +296,7 @@ func (c *Controller) cmdNearby(msg *server.Message) (res string, err error) { return "", s } minZ, maxZ := zMinMaxFromWheres(s.wheres) - sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) + sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) if err != nil { return "", err } @@ -397,7 +397,7 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res if s.fence { return "", s } - sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields) + sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) if err != nil { return "", err } @@ -464,7 +464,7 @@ func (c *Controller) cmdSearch(msg *server.Message) (res string, err error) { if err != nil { return "", err } - sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.cursor, s.limit, s.wheres, s.nofields) + sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.cursor, s.limit, s.wheres, s.whereins, s.nofields) if err != nil { return "", err } diff --git a/controller/token.go b/controller/token.go index aa11cfc9..907b4830 100644 --- a/controller/token.go +++ b/controller/token.go @@ -156,6 +156,16 @@ func zMinMaxFromWheres(wheres []whereT) (minZ, maxZ float64) { return } +type whereinT struct { + field string + val_map map[float64]struct{} +} + +func (wherein whereinT) match(value float64) bool { + _, ok := wherein.val_map[value] + return ok +} + type searchScanBaseTokens struct { key string cursor uint64 @@ -168,6 +178,7 @@ type searchScanBaseTokens struct { accept map[string]bool glob string wheres []whereT + whereins []whereinT nofields bool ulimit bool limit uint64 @@ -245,6 +256,38 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, } t.wheres = append(t.wheres, whereT{field, minx, min, maxx, max}) continue + } else if (wtok[0] == 'W' || wtok[0] == 'w') && strings.ToLower(wtok) == "wherein" { + vs = nvs + var field, nvals_str, val_str string + if vs, field, ok = tokenval(vs); !ok || field == "" { + err = errInvalidNumberOfArguments + return + } + if vs, nvals_str, ok = tokenval(vs); !ok || nvals_str == "" { + err = errInvalidNumberOfArguments + return + } + var i, nvals uint64 + if nvals, err = strconv.ParseUint(nvals_str, 10, 64); err != nil { + err = errInvalidArgument(nvals_str) + return + } + val_map := make(map[float64]struct{}) + var val float64 + var empty struct{} + for i = 0; i < nvals; i++ { + if vs, val_str, ok = tokenval(vs); !ok || val_str == "" { + err = errInvalidNumberOfArguments + return + } + if val, err = strconv.ParseFloat(val_str, 64); err != nil { + err = errInvalidArgument(val_str) + return + } + val_map[val] = empty + } + t.whereins = append(t.whereins, whereinT{field, val_map}) + continue } else if (wtok[0] == 'N' || wtok[0] == 'n') && strings.ToLower(wtok) == "nofields" { vs = nvs if t.nofields { diff --git a/core/commands.json b/core/commands.json index a2ebede6..a8fd5e17 100644 --- a/core/commands.json +++ b/core/commands.json @@ -344,6 +344,14 @@ "optional": true, "multiple": true }, + { + "command": "WHEREIN", + "name": ["field","count","value"], + "type": ["string","integer","double"], + "optional": true, + "multiple": true, + "variadic": true + }, { "command": "NOFIELDS", "name": [], @@ -411,6 +419,14 @@ "optional": true, "multiple": true }, + { + "command": "WHEREIN", + "name": ["field","count","value"], + "type": ["string","integer","double"], + "optional": true, + "multiple": true, + "variadic": true + }, { "command": "NOFIELDS", "name": [], @@ -496,6 +512,14 @@ "optional": true, "multiple": true }, + { + "command": "WHEREIN", + "name": ["field","count","value"], + "type": ["string","integer","double"], + "optional": true, + "multiple": true, + "variadic": true + }, { "command": "NOFIELDS", "name": [], @@ -632,6 +656,14 @@ "optional": true, "multiple": true }, + { + "command": "WHEREIN", + "name": ["field","count","value"], + "type": ["string","integer","double"], + "optional": true, + "multiple": true, + "variadic": true + }, { "command": "NOFIELDS", "name": [], @@ -812,6 +844,14 @@ "optional": true, "multiple": true }, + { + "command": "WHEREIN", + "name": ["field","count","value"], + "type": ["string","integer","double"], + "optional": true, + "multiple": true, + "variadic": true + }, { "command": "NOFIELDS", "name": [], diff --git a/tests/keys_test.go b/tests/keys_test.go index 49b72ac5..d9d24f5d 100644 --- a/tests/keys_test.go +++ b/tests/keys_test.go @@ -29,6 +29,7 @@ func subTestKeys(t *testing.T, mc *mockServer) { runStep(t, mc, "SET EX", keys_SET_EX_test) runStep(t, mc, "PDEL", keys_PDEL_test) runStep(t, mc, "FIELDS", keys_FIELDS_test) + runStep(t, mc, "WHEREIN", keys_WHEREIN_test) } func keys_BOUNDS_test(mc *mockServer) error { @@ -332,3 +333,16 @@ func keys_PDEL_test(mc *mockServer) error { {"PDEL", "mykey", "*"}, {0}, }) } + +func keys_WHEREIN_test(mc *mockServer) error { + return mc.DoBatch([][]interface{}{ + {"SET", "mykey", "myid_a1", "FIELD", "a", 1, "POINT", 33, -115}, {"OK"}, + {"WITHIN", "mykey", "WHEREIN", "a", 3, 0, 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {`[0 [[myid_a1 {"type":"Point","coordinates":[-115,33]} [a 1]]]]`}, + {"WITHIN", "mykey", "WHEREIN", "a", "a", 0, 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {"ERR invalid argument 'a'"}, + {"WITHIN", "mykey", "WHEREIN", "a", 1, 0, 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {"ERR invalid argument '1'"}, + {"WITHIN", "mykey", "WHEREIN", "a", 3, 0, "a", 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {"ERR invalid argument 'a'"}, + {"SET", "mykey", "myid_a2", "FIELD", "a", 2, "POINT", 32.99, -115}, {"OK"}, + {"SET", "mykey", "myid_a3", "FIELD", "a", 3, "POINT", 33, -115.02}, {"OK"}, + {"WITHIN", "mykey", "WHEREIN", "a", 3, 0, 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {`[0 [[myid_a1 {"type":"Point","coordinates":[-115,33]} [a 1]] [myid_a2 {"type":"Point","coordinates":[-115,32.99]} [a 2]]]]`}, + }) +}