diff --git a/internal/deadline/deadline.go b/internal/deadline/deadline.go index 697492a3..b8256ac4 100644 --- a/internal/deadline/deadline.go +++ b/internal/deadline/deadline.go @@ -13,16 +13,6 @@ func New(deadline time.Time) *Deadline { return &Deadline{unixNano: deadline.UnixNano()} } -// Empty deadline does nothing, just a place holder for future updates -func Empty() *Deadline { - return &Deadline{} -} - -// Update the deadline from a given time object -func (deadline *Deadline) Update(newDeadline time.Time) { - deadline.unixNano = newDeadline.UnixNano() -} - // Check the deadline and panic when reached //go:noinline func (deadline *Deadline) Check() { @@ -41,6 +31,6 @@ func (deadline *Deadline) Hit() bool { } // GetDeadlineTime returns the time object for the deadline, and an "empty" boolean -func (deadline *Deadline) GetDeadlineTime() (time.Time, bool) { - return time.Unix(0, deadline.unixNano), deadline.unixNano == 0 +func (deadline *Deadline) GetDeadlineTime() (time.Time) { + return time.Unix(0, deadline.unixNano) } diff --git a/internal/server/client.go b/internal/server/client.go index 15e49e41..ab19b2c2 100644 --- a/internal/server/client.go +++ b/internal/server/client.go @@ -33,8 +33,6 @@ type Client struct { name string // optional defined name opened time.Time // when the client was created/opened, unix nano last time.Time // last client request/response, unix nano - - timeout time.Duration // command timeout } // Write ... diff --git a/internal/server/scan.go b/internal/server/scan.go index bded9749..064fca6b 100644 --- a/internal/server/scan.go +++ b/internal/server/scan.go @@ -55,9 +55,6 @@ func (c *Server) cmdScan(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() - if s.timeout != 0 { - msg.Deadline.Update(start.Add(s.timeout)) - } if sw.col != nil { if sw.output == outputCount && len(sw.wheres) == 0 && len(sw.whereins) == 0 && sw.globEverything == true { diff --git a/internal/server/scripts.go b/internal/server/scripts.go index 9f87a1ac..6bc2d6fa 100644 --- a/internal/server/scripts.go +++ b/internal/server/scripts.go @@ -32,6 +32,7 @@ var errNotLeader = errors.New("not the leader") var errReadOnly = errors.New("read only") var errCatchingUp = errors.New("catching up to leader") var errNoLuasAvailable = errors.New("no interpreters available") +var errTimeout = errors.New("timeout") // Go-routine-safe pool of read-to-go lua states type lStatePool struct { @@ -392,12 +393,14 @@ func (c *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value, if err != nil { return } - deadline, empty := msg.Deadline.GetDeadlineTime() - if !empty { - ctx, cancel := context.WithDeadline(context.Background(), deadline) + luaDeadline := lua.LNil + if msg.Deadline != nil { + dlTime := msg.Deadline.GetDeadlineTime() + ctx, cancel := context.WithDeadline(context.Background(), dlTime) defer cancel() luaState.SetContext(ctx) defer luaState.RemoveContext() + luaDeadline = lua.LNumber(float64(dlTime.UnixNano()) / 1e9) } defer c.luapool.Put(luaState) @@ -430,6 +433,7 @@ func (c *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value, luaState, map[string]lua.LValue{ "KEYS": keysTbl, "ARGV": argsTbl, + "DEADLINE": luaDeadline, "EVAL_CMD": lua.LString(msg.Command()), }) @@ -459,6 +463,7 @@ func (c *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value, luaState, map[string]lua.LValue{ "KEYS": lua.LNil, "ARGV": lua.LNil, + "DEADLINE": lua.LNil, "EVAL_CMD": lua.LNil, }) if err := luaState.PCall(0, 1, nil); err != nil { @@ -643,6 +648,13 @@ func (c *Server) luaTile38Call(evalcmd string, cmd string, args ...string) (resp msg := &Message{} msg.OutputType = RESP msg.Args = append([]string{cmd}, args...) + + if msg.Command() == "timeout" { + if err := rewriteTimeoutMsg(msg); err != nil { + return resp.NullValue(), err + } + } + switch msg.Command() { case "ping", "echo", "auth", "massinsert", "shutdown", "gc", "sethook", "pdelhook", "delhook", @@ -690,7 +702,28 @@ func (c *Server) luaTile38AtomicRW(msg *Message) (resp.Value, error) { } } - res, d, err := c.commandInScript(msg) + res, d, err := func() (res resp.Value, d commandDetails, err error) { + if msg.Deadline != nil { + if write { + res = NOMessage + err = errTimeoutOnCmd(msg.Command()) + return + } + defer func() { + if msg.Deadline.Hit() { + v := recover() + if v != nil { + if s, ok := v.(string); !ok || s != "deadline" { + panic(v) + } + } + res = NOMessage + err = errTimeout + } + }() + } + return c.commandInScript(msg) + }() if err != nil { return resp.NullValue(), err } @@ -722,7 +755,23 @@ func (c *Server) luaTile38AtomicRO(msg *Message) (resp.Value, error) { } } - res, _, err := c.commandInScript(msg) + res, _, err := func() (res resp.Value, d commandDetails, err error) { + if msg.Deadline != nil { + defer func() { + if msg.Deadline.Hit() { + v := recover() + if v != nil { + if s, ok := v.(string); !ok || s != "deadline" { + panic(v) + } + } + res = NOMessage + err = errTimeout + } + }() + } + return c.commandInScript(msg) + }() if err != nil { return resp.NullValue(), err } @@ -759,7 +808,28 @@ func (c *Server) luaTile38NonAtomic(msg *Message) (resp.Value, error) { } } - res, d, err := c.commandInScript(msg) + res, d, err := func() (res resp.Value, d commandDetails, err error) { + if msg.Deadline != nil { + if write { + res = NOMessage + err = errTimeoutOnCmd(msg.Command()) + return + } + defer func() { + if msg.Deadline.Hit() { + v := recover() + if v != nil { + if s, ok := v.(string); !ok || s != "deadline" { + panic(v) + } + } + res = NOMessage + err = errTimeout + } + }() + } + return c.commandInScript(msg) + }() if err != nil { return resp.NullValue(), err } diff --git a/internal/server/search.go b/internal/server/search.go index d4083a74..1b7f9af0 100644 --- a/internal/server/search.go +++ b/internal/server/search.go @@ -370,9 +370,6 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() - if s.timeout != 0 { - msg.Deadline.Update(start.Add(s.timeout)) - } if sw.col != nil { iter := func(id string, o geojson.Object, fields []float64, dist float64) bool { meters := 0.0 @@ -483,9 +480,6 @@ func (server *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp. wr.WriteString(`{"ok":true`) } sw.writeHead() - if s.timeout != 0 { - msg.Deadline.Update(start.Add(s.timeout)) - } if sw.col != nil { if cmd == "within" { sw.col.Within(s.obj, s.sparse, sw, msg.Deadline, func( @@ -576,9 +570,6 @@ func (server *Server) cmdSearch(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() - if s.timeout != 0 { - msg.Deadline.Update(start.Add(s.timeout)) - } if sw.col != nil { if sw.output == outputCount && len(sw.wheres) == 0 && sw.globEverything == true { count := sw.col.Count() - int(s.cursor) diff --git a/internal/server/server.go b/internal/server/server.go index ec61f932..2091aecf 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -39,6 +39,9 @@ import ( ) var errOOM = errors.New("OOM command not allowed when used memory > 'maxmemory'") +func errTimeoutOnCmd(cmd string) error { + return fmt.Errorf("timeout not supported for '%s'", cmd) +} const ( goingLive = "going live" @@ -839,6 +842,26 @@ func isReservedFieldName(field string) bool { return false } +func rewriteTimeoutMsg(msg *Message) (err error) { + vs := msg.Args[1:] + var valStr string + var ok bool + if vs, valStr, ok = tokenval(vs); !ok || valStr == "" || len(vs) == 0 { + err = errInvalidNumberOfArguments + return + } + timeoutSec, _err := strconv.ParseFloat(valStr, 64) + if _err != nil || timeoutSec < 0 { + err = errInvalidArgument(valStr) + return + } + msg.Args = vs[:] + msg._command = "" + msg.Deadline = deadline.New( + time.Now().Add(time.Duration(timeoutSec * float64(time.Second)))) + return +} + func (server *Server) handleInputCommand(client *Client, msg *Message) error { start := time.Now() serializeOutput := func(res resp.Value) (string, error) { @@ -923,6 +946,12 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { return nil } + if msg.Command() == "timeout" { + if err := rewriteTimeoutMsg(msg); err != nil { + return writeErr(err.Error()) + } + } + var write bool if !client.authd || msg.Command() == "auth" { @@ -997,7 +1026,7 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { // does not write to aof, but requires a write lock. server.mu.Lock() defer server.mu.Unlock() - case "output", "timeout": + case "output": // this is local connection operation. Locks not needed. case "echo": case "massinsert": @@ -1024,13 +1053,11 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { // No locking for pubsub } res, d, err := func() (res resp.Value, d commandDetails, err error) { - if !write { - if client.timeout == 0 { - // the command itself might have a timeout, - // which will be used to update this trivial deadline. - msg.Deadline = deadline.Empty() - } else { - msg.Deadline = deadline.New(start.Add(client.timeout)) + if msg.Deadline != nil { + if write { + res = NOMessage + err = errTimeoutOnCmd(msg.Command()) + return } defer func() { if msg.Deadline.Hit() { @@ -1203,8 +1230,6 @@ func (server *Server) command(msg *Message, client *Client) ( res, err = server.cmdKeys(msg) case "output": res, err = server.cmdOutput(msg) - case "timeout": - res, err = server.cmdTimeout(msg, client) case "aof": res, err = server.cmdAOF(msg) case "aofmd5": diff --git a/internal/server/timeout.go b/internal/server/timeout.go deleted file mode 100644 index 4db38ff6..00000000 --- a/internal/server/timeout.go +++ /dev/null @@ -1,38 +0,0 @@ -package server - -import ( - "strconv" - "time" - - "github.com/tidwall/resp" -) - -func (c *Server) cmdTimeout(msg *Message, client *Client) (res resp.Value, err error) { - start := time.Now() - vs := msg.Args[1:] - var arg string - var ok bool - - if len(vs) != 0 { - if _, arg, ok = tokenval(vs); !ok || arg == "" { - return NOMessage, errInvalidNumberOfArguments - } - timeout, err := strconv.ParseFloat(arg, 64) - if err != nil || timeout < 0 { - return NOMessage, errInvalidArgument(arg) - } - client.timeout = time.Duration(timeout * float64(time.Second)) - return OKMessage(msg, start), nil - } - // return the timeout - switch msg.OutputType { - default: - return NOMessage, nil - case JSON: - return resp.StringValue(`{"ok":true` + - `,"seconds":` + strconv.FormatFloat(client.timeout.Seconds(), 'f', -1, 64) + - `,"elapsed":` + time.Now().Sub(start).String() + `}`), nil - case RESP: - return resp.FloatValue(client.timeout.Seconds()), nil - } -} diff --git a/internal/server/token.go b/internal/server/token.go index 8b46ef01..71a81044 100644 --- a/internal/server/token.go +++ b/internal/server/token.go @@ -6,7 +6,6 @@ import ( "math" "strconv" "strings" - "time" "github.com/yuin/gopher-lua" ) @@ -248,7 +247,6 @@ type searchScanBaseTokens struct { sparse uint8 desc bool clip bool - timeout time.Duration } func (c *Server) parseSearchScanBaseTokens( @@ -581,20 +579,6 @@ func (c *Server) parseSearchScanBaseTokens( } t.clip = true continue - case "timeout": - vs = nvs - var valStr string - if vs, valStr, ok = tokenval(vs); !ok || valStr == "" { - err = errInvalidNumberOfArguments - return - } - timeout, _err := strconv.ParseFloat(valStr, 64) - if _err != nil || timeout < 0 { - err = errInvalidArgument(valStr) - return - } - t.timeout = time.Duration(timeout * float64(time.Second)) - continue } } break diff --git a/tests/timeout_test.go b/tests/timeout_test.go index 011ad281..aa8d4263 100644 --- a/tests/timeout_test.go +++ b/tests/timeout_test.go @@ -3,6 +3,7 @@ package tests import ( "fmt" "math/rand" + "strings" "testing" "time" @@ -10,12 +11,12 @@ import ( ) func subTestTimeout(t *testing.T, mc *mockServer) { - runStep(t, mc, "session set/unset", timeout_session_set_unset_test) - runStep(t, mc, "session spatial", timeout_session_spatial_test) - runStep(t, mc, "session search", timeout_session_search_test) - runStep(t, mc, "session scripts", timeout_session_scripts_test) - runStep(t, mc, "command spatial", timeout_command_spatial_test) - runStep(t, mc, "command search", timeout_command_search_test) + runStep(t, mc, "spatial", timeout_spatial_test) + runStep(t, mc, "search", timeout_search_test) + runStep(t, mc, "scripts", timeout_scripts_test) + runStep(t, mc, "no writes", timeout_no_writes_test) + runStep(t, mc, "within scripts", timeout_within_scripts_test) + runStep(t, mc, "no writes within scripts", timeout_no_writes_within_scripts_test) } func setup(mc *mockServer, count int, points bool) (err error) { @@ -53,17 +54,7 @@ func setup(mc *mockServer, count int, points bool) (err error) { return } -func timeout_session_set_unset_test(mc *mockServer) (err error) { - return mc.DoBatch([][]interface{}{ - {"TIMEOUT"}, {"0"}, - {"TIMEOUT", "0.25"}, {"OK"}, - {"TIMEOUT"}, {"0.25"}, - {"TIMEOUT", "0"}, {"OK"}, - {"TIMEOUT"}, {"0"}, - }) -} - -func timeout_session_spatial_test(mc *mockServer) (err error) { +func timeout_spatial_test(mc *mockServer) (err error) { err = setup(mc, 10000, true) return mc.DoBatch([][]interface{}{ @@ -71,50 +62,22 @@ func timeout_session_spatial_test(mc *mockServer) (err error) { {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, - {"TIMEOUT", "0.000001"}, {"OK"}, - - {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, - {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, - {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + {"TIMEOUT", "0.000001", "SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, + {"TIMEOUT", "0.000001", "INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + {"TIMEOUT", "0.000001", "WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, }) } -func timeout_command_spatial_test(mc *mockServer) (err error) { - err = setup(mc, 10000, true) - - return mc.DoBatch([][]interface{}{ - {"TIMEOUT", "1"}, {"OK"}, - {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"10000"}, - {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, - {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, - - {"SCAN", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, - {"INTERSECTS", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, - {"WITHIN", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, - }) -} - -func timeout_session_search_test(mc *mockServer) (err error) { +func timeout_search_test(mc *mockServer) (err error) { err = setup(mc, 10000, false) return mc.DoBatch([][]interface{}{ {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"10000"}, - {"TIMEOUT", "0.000001"}, {"OK"}, - {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, + {"TIMEOUT", "0.000001", "SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, }) } -func timeout_command_search_test(mc *mockServer) (err error) { - err = setup(mc, 10000, false) - - return mc.DoBatch([][]interface{}{ - {"TIMEOUT", "1"}, {"OK"}, - {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"10000"}, - {"SEARCH", "mykey", "TIMEOUT", "0.000001", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, - }) -} - -func timeout_session_scripts_test(mc *mockServer) (err error) { +func timeout_scripts_test(mc *mockServer) (err error) { script := ` local clock = os.clock local function sleep(n) @@ -132,16 +95,70 @@ func timeout_session_scripts_test(mc *mockServer) (err error) { {"EVALROSHA", sha, 0}, {nil}, {"EVALNASHA", sha, 0}, {nil}, - {"TIMEOUT", "0.1"}, {"OK"}, + {"TIMEOUT", "0.1", "EVALSHA", sha, 0}, {"ERR timeout"}, + {"TIMEOUT", "0.1", "EVALROSHA", sha, 0}, {"ERR timeout"}, + {"TIMEOUT", "0.1", "EVALNASHA", sha, 0}, {"ERR timeout"}, - {"EVALSHA", sha, 0}, {"ERR timeout"}, - {"EVALROSHA", sha, 0}, {"ERR timeout"}, - {"EVALNASHA", sha, 0}, {"ERR timeout"}, - - {"TIMEOUT", "0.9"}, {"OK"}, - - {"EVALSHA", sha, 0}, {nil}, - {"EVALROSHA", sha, 0}, {nil}, - {"EVALNASHA", sha, 0}, {nil}, + {"TIMEOUT", "0.9", "EVALSHA", sha, 0}, {nil}, + {"TIMEOUT", "0.9", "EVALROSHA", sha, 0}, {nil}, + {"TIMEOUT", "0.9", "EVALNASHA", sha, 0}, {nil}, + }) +} + +func timeout_no_writes_test(mc *mockServer) (err error) { + return mc.DoBatch([][]interface{}{ + {"SET", "mykey", "myid", "STRING", "foo"}, {"OK"}, + {"TIMEOUT", 1, "SET", "mykey", "myid", "STRING", "foo"}, {"ERR timeout not supported for 'set'"}, + }) +} + +func scriptTimeoutErr(v interface{}) (resp, expect interface{}) { + s := fmt.Sprintf("%v", v) + if strings.Contains(s, "ERR timeout") { + return v, v + } + return v, "A lua stack containing 'ERR timeout'" +} + +func timeout_within_scripts_test(mc *mockServer) (err error) { + err = setup(mc, 10000, true) + + script1 := "return tile38.call('timeout', 10, 'SCAN', 'mykey', 'WHERE', 'foo', -1, 2, 'COUNT')" + script2 := "return tile38.call('timeout', 0.000001, 'SCAN', 'mykey', 'WHERE', 'foo', -1, 2, 'COUNT')" + sha1 := "27a364b4e46ef493f6b70371086c286e2d5b5f49" + sha2 := "2da9c05b54abfe870bdc8383a143f9d3aa656192" + + return mc.DoBatch([][]interface{}{ + {"SCRIPT LOAD", script1}, {sha1}, + {"SCRIPT LOAD", script2}, {sha2}, + + {"EVALSHA", sha1, 0}, {"10000"}, + {"EVALROSHA", sha1, 0}, {"10000"}, + {"EVALNASHA", sha1, 0}, {"10000"}, + {"EVALSHA", sha2, 0}, {scriptTimeoutErr}, + {"EVALROSHA", sha2, 0}, {scriptTimeoutErr}, + {"EVALNASHA", sha2, 0}, {scriptTimeoutErr}, + }) +} + +func scriptTimeoutNotSupportedErr(v interface{}) (resp, expect interface{}) { + s := fmt.Sprintf("%v", v) + if strings.Contains(s, "ERR timeout not supported for") { + return v, v + } + return v, "A lua stack containing 'ERR timeout not supported for'" +} + +func timeout_no_writes_within_scripts_test(mc *mockServer) (err error) { + script1 := "return tile38.call('SET', 'mykey', 'myval', 'STRING', 'foo')" + script2 := "return tile38.call('timeout', 10, 'SET', 'mykey', 'myval', 'STRING', 'foo')" + sha1 := "393d0adff113fdda45e3b5aff93c188c30099f48" + sha2 := "5287c158d15eb53d800b7389d82df0d73b004bf1" + + return mc.DoBatch([][]interface{}{ + {"SCRIPT LOAD", script1}, {sha1}, + {"SCRIPT LOAD", script2}, {sha2}, + {"EVALSHA", sha1, 0, "foo"}, {"OK"}, + {"EVALSHA", sha2, 0, "foo"}, {scriptTimeoutNotSupportedErr}, }) }