diff --git a/internal/deadline/deadline.go b/internal/deadline/deadline.go index cb22525e..6be76543 100644 --- a/internal/deadline/deadline.go +++ b/internal/deadline/deadline.go @@ -13,6 +13,10 @@ func New(deadline time.Time) *Deadline { return &Deadline{unixNano: deadline.UnixNano()} } +func Empty() *Deadline { + return &Deadline{} +} + // Update the deadline from a given time object func (deadline *Deadline) Update(newDeadline time.Time) { deadline.unixNano = newDeadline.UnixNano() diff --git a/internal/server/server.go b/internal/server/server.go index bea3564b..ec61f932 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1025,7 +1025,13 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { } res, d, err := func() (res resp.Value, d commandDetails, err error) { if !write { - msg.Deadline = deadline.New(start.Add(client.timeout)) + if client.timeout == 0 { + // the command itself might have a timeout, + // which will be used to update this trivial deadline. + msg.Deadline = deadline.Empty() + } else { + msg.Deadline = deadline.New(start.Add(client.timeout)) + } defer func() { if msg.Deadline.Hit() { v := recover() diff --git a/tests/tests_test.go b/tests/tests_test.go index 6624fa64..34953ddf 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -47,6 +47,7 @@ func TestAll(t *testing.T) { runSubTest(t, "scripts", mc, subTestScripts) runSubTest(t, "info", mc, subTestInfo) runSubTest(t, "client", mc, subTestClient) + runSubTest(t, "timeouts", mc, subTestTimeout) } func runSubTest(t *testing.T, name string, mc *mockServer, test func(t *testing.T, mc *mockServer)) { diff --git a/tests/timeout_test.go b/tests/timeout_test.go new file mode 100644 index 00000000..7178c0b1 --- /dev/null +++ b/tests/timeout_test.go @@ -0,0 +1,114 @@ +package tests + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/gomodule/redigo/redis" +) + +func subTestTimeout(t *testing.T, mc *mockServer) { + runStep(t, mc, "session set/unset", timeout_session_set_unset_test) + runStep(t, mc, "session spatial", timeout_session_spatial_test) + runStep(t, mc, "session search", timeout_session_search_test) + runStep(t, mc, "command spatial", timeout_command_spatial_test) + runStep(t, mc, "command search", timeout_command_search_test) +} + +func setup(mc *mockServer, count int, points bool) (err error) { + rand.Seed(time.Now().UnixNano()) + + // add a bunch of points + for i := 0; i < count; i++ { + val := fmt.Sprintf("val:%d", i) + var resp string + var lat, lon, fval float64 + fval = rand.Float64() + if points { + lat = rand.Float64()*180 - 90 + lon = rand.Float64()*360 - 180 + resp, err = redis.String(mc.conn.Do("SET", + "mykey", val, + "FIELD", "foo", fval, + "POINT", lat, lon)) + } else { + resp, err = redis.String(mc.conn.Do("SET", + "mykey", val, + "FIELD", "foo", fval, + "STRING", val)) + } + if err != nil { + return + } + if resp != "OK" { + err = fmt.Errorf("expected 'OK', got '%s'", resp) + return + } + time.Sleep(time.Nanosecond) + } + time.Sleep(time.Second * 3) + return +} + +func timeout_session_set_unset_test(mc *mockServer) (err error) { + return mc.DoBatch([][]interface{}{ + {"TIMEOUT"}, {"0"}, + {"TIMEOUT", "0.25"}, {"OK"}, + {"TIMEOUT"}, {"0.25"}, + {"TIMEOUT", "0"}, {"OK"}, + {"TIMEOUT"}, {"0"}, + }) +} + +func timeout_session_spatial_test(mc *mockServer) (err error) { + err = setup(mc, 10000, true) + + return mc.DoBatch([][]interface{}{ + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"10000"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + + {"TIMEOUT", "0.000001"}, {"OK"}, + + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + }) +} + +func timeout_command_spatial_test(mc *mockServer) (err error) { + err = setup(mc, 10000, true) + + return mc.DoBatch([][]interface{}{ + {"TIMEOUT", "1"}, {"OK"}, + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"10000"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + + {"SCAN", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, + {"INTERSECTS", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + {"WITHIN", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + }) +} + +func timeout_session_search_test(mc *mockServer) (err error) { + err = setup(mc, 10000, false) + + return mc.DoBatch([][]interface{}{ + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"10000"}, + {"TIMEOUT", "0.000001"}, {"OK"}, + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, + }) +} + +func timeout_command_search_test(mc *mockServer) (err error) { + err = setup(mc, 10000, false) + + return mc.DoBatch([][]interface{}{ + {"TIMEOUT", "1"}, {"OK"}, + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"10000"}, + {"SEARCH", "mykey", "TIMEOUT", "0.000001", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, + }) +}