diff --git a/core/commands.json b/core/commands.json index 93adc1bb..e4bc7f96 100644 --- a/core/commands.json +++ b/core/commands.json @@ -359,6 +359,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "MATCH", "name": "pattern", @@ -450,6 +456,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "MATCH", "name": "pattern", @@ -559,6 +571,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "SPARSE", "name": "spread", @@ -725,6 +743,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "SPARSE", "name": "spread", @@ -946,6 +970,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "SPARSE", "name": "spread", @@ -1299,6 +1329,17 @@ ], "group": "connection" }, + "TIMEOUT": { + "summary": "Gets or sets the query timeout for the current connection.", + "arguments": [ + { + "name": "seconds", + "optional": true, + "type": "double" + } + ], + "group": "connection" + }, "SETHOOK": { "summary": "Creates a webhook which points to geofenced search", "arguments": [ diff --git a/internal/deadline/deadline.go b/internal/deadline/deadline.go index 7acf2577..6be76543 100644 --- a/internal/deadline/deadline.go +++ b/internal/deadline/deadline.go @@ -13,6 +13,15 @@ func New(deadline time.Time) *Deadline { return &Deadline{unixNano: deadline.UnixNano()} } +func Empty() *Deadline { + return &Deadline{} +} + +// Update the deadline from a given time object +func (deadline *Deadline) Update(newDeadline time.Time) { + deadline.unixNano = newDeadline.UnixNano() +} + // Check the deadline and panic when reached //go:noinline func (deadline *Deadline) Check() { diff --git a/internal/server/scan.go b/internal/server/scan.go index 064fca6b..bded9749 100644 --- a/internal/server/scan.go +++ b/internal/server/scan.go @@ -55,6 +55,9 @@ func (c *Server) cmdScan(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() + if s.timeout != 0 { + msg.Deadline.Update(start.Add(s.timeout)) + } if sw.col != nil { if sw.output == outputCount && len(sw.wheres) == 0 && len(sw.whereins) == 0 && sw.globEverything == true { diff --git a/internal/server/search.go b/internal/server/search.go index 1b7f9af0..d4083a74 100644 --- a/internal/server/search.go +++ b/internal/server/search.go @@ -370,6 +370,9 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() + if s.timeout != 0 { + msg.Deadline.Update(start.Add(s.timeout)) + } if sw.col != nil { iter := func(id string, o geojson.Object, fields []float64, dist float64) bool { meters := 0.0 @@ -480,6 +483,9 @@ func (server *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp. wr.WriteString(`{"ok":true`) } sw.writeHead() + if s.timeout != 0 { + msg.Deadline.Update(start.Add(s.timeout)) + } if sw.col != nil { if cmd == "within" { sw.col.Within(s.obj, s.sparse, sw, msg.Deadline, func( @@ -570,6 +576,9 @@ func (server *Server) cmdSearch(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() + if s.timeout != 0 { + msg.Deadline.Update(start.Add(s.timeout)) + } if sw.col != nil { if sw.output == outputCount && len(sw.wheres) == 0 && sw.globEverything == true { count := sw.col.Count() - int(s.cursor) diff --git a/internal/server/server.go b/internal/server/server.go index fd60b582..ec61f932 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1024,8 +1024,14 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { // No locking for pubsub } res, d, err := func() (res resp.Value, d commandDetails, err error) { - if client.timeout != 0 && !write { - msg.Deadline = deadline.New(start.Add(client.timeout)) + if !write { + if client.timeout == 0 { + // the command itself might have a timeout, + // which will be used to update this trivial deadline. + msg.Deadline = deadline.Empty() + } else { + msg.Deadline = deadline.New(start.Add(client.timeout)) + } defer func() { if msg.Deadline.Hit() { v := recover() diff --git a/internal/server/token.go b/internal/server/token.go index 71a81044..8b46ef01 100644 --- a/internal/server/token.go +++ b/internal/server/token.go @@ -6,6 +6,7 @@ import ( "math" "strconv" "strings" + "time" "github.com/yuin/gopher-lua" ) @@ -247,6 +248,7 @@ type searchScanBaseTokens struct { sparse uint8 desc bool clip bool + timeout time.Duration } func (c *Server) parseSearchScanBaseTokens( @@ -579,6 +581,20 @@ func (c *Server) parseSearchScanBaseTokens( } t.clip = true continue + case "timeout": + vs = nvs + var valStr string + if vs, valStr, ok = tokenval(vs); !ok || valStr == "" { + err = errInvalidNumberOfArguments + return + } + timeout, _err := strconv.ParseFloat(valStr, 64) + if _err != nil || timeout < 0 { + err = errInvalidArgument(valStr) + return + } + t.timeout = time.Duration(timeout * float64(time.Second)) + continue } } break diff --git a/tests/tests_test.go b/tests/tests_test.go index 6624fa64..34953ddf 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -47,6 +47,7 @@ func TestAll(t *testing.T) { runSubTest(t, "scripts", mc, subTestScripts) runSubTest(t, "info", mc, subTestInfo) runSubTest(t, "client", mc, subTestClient) + runSubTest(t, "timeouts", mc, subTestTimeout) } func runSubTest(t *testing.T, name string, mc *mockServer, test func(t *testing.T, mc *mockServer)) { diff --git a/tests/timeout_test.go b/tests/timeout_test.go new file mode 100644 index 00000000..7178c0b1 --- /dev/null +++ b/tests/timeout_test.go @@ -0,0 +1,114 @@ +package tests + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/gomodule/redigo/redis" +) + +func subTestTimeout(t *testing.T, mc *mockServer) { + runStep(t, mc, "session set/unset", timeout_session_set_unset_test) + runStep(t, mc, "session spatial", timeout_session_spatial_test) + runStep(t, mc, "session search", timeout_session_search_test) + runStep(t, mc, "command spatial", timeout_command_spatial_test) + runStep(t, mc, "command search", timeout_command_search_test) +} + +func setup(mc *mockServer, count int, points bool) (err error) { + rand.Seed(time.Now().UnixNano()) + + // add a bunch of points + for i := 0; i < count; i++ { + val := fmt.Sprintf("val:%d", i) + var resp string + var lat, lon, fval float64 + fval = rand.Float64() + if points { + lat = rand.Float64()*180 - 90 + lon = rand.Float64()*360 - 180 + resp, err = redis.String(mc.conn.Do("SET", + "mykey", val, + "FIELD", "foo", fval, + "POINT", lat, lon)) + } else { + resp, err = redis.String(mc.conn.Do("SET", + "mykey", val, + "FIELD", "foo", fval, + "STRING", val)) + } + if err != nil { + return + } + if resp != "OK" { + err = fmt.Errorf("expected 'OK', got '%s'", resp) + return + } + time.Sleep(time.Nanosecond) + } + time.Sleep(time.Second * 3) + return +} + +func timeout_session_set_unset_test(mc *mockServer) (err error) { + return mc.DoBatch([][]interface{}{ + {"TIMEOUT"}, {"0"}, + {"TIMEOUT", "0.25"}, {"OK"}, + {"TIMEOUT"}, {"0.25"}, + {"TIMEOUT", "0"}, {"OK"}, + {"TIMEOUT"}, {"0"}, + }) +} + +func timeout_session_spatial_test(mc *mockServer) (err error) { + err = setup(mc, 10000, true) + + return mc.DoBatch([][]interface{}{ + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"10000"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + + {"TIMEOUT", "0.000001"}, {"OK"}, + + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + }) +} + +func timeout_command_spatial_test(mc *mockServer) (err error) { + err = setup(mc, 10000, true) + + return mc.DoBatch([][]interface{}{ + {"TIMEOUT", "1"}, {"OK"}, + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"10000"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + + {"SCAN", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, + {"INTERSECTS", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + {"WITHIN", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + }) +} + +func timeout_session_search_test(mc *mockServer) (err error) { + err = setup(mc, 10000, false) + + return mc.DoBatch([][]interface{}{ + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"10000"}, + {"TIMEOUT", "0.000001"}, {"OK"}, + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, + }) +} + +func timeout_command_search_test(mc *mockServer) (err error) { + err = setup(mc, 10000, false) + + return mc.DoBatch([][]interface{}{ + {"TIMEOUT", "1"}, {"OK"}, + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"10000"}, + {"SEARCH", "mykey", "TIMEOUT", "0.000001", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, + }) +}