From 0aecef6a5c7965a33859d49b0f1f3b2f98b6397c Mon Sep 17 00:00:00 2001 From: tidwall Date: Wed, 24 Apr 2019 05:09:41 -0700 Subject: [PATCH 1/5] Added TIMEOUT command --- internal/collection/collection.go | 106 ++++++++++--------------- internal/collection/collection_test.go | 48 +++++------ internal/deadline/deadline.go | 31 ++++++++ internal/server/aofshrink.go | 2 +- internal/server/client.go | 1 + internal/server/crud.go | 4 +- internal/server/fence.go | 6 +- internal/server/scan.go | 2 + internal/server/search.go | 15 ++-- internal/server/server.go | 26 +++++- internal/server/timeout.go | 38 +++++++++ 11 files changed, 175 insertions(+), 104 deletions(-) create mode 100644 internal/deadline/deadline.go create mode 100644 internal/server/timeout.go diff --git a/internal/collection/collection.go b/internal/collection/collection.go index ad107e38..8562c3d4 100644 --- a/internal/collection/collection.go +++ b/internal/collection/collection.go @@ -8,11 +8,12 @@ import ( "github.com/tidwall/geojson" "github.com/tidwall/geojson/geo" "github.com/tidwall/geojson/geometry" + "github.com/tidwall/tile38/internal/deadline" "github.com/tidwall/tinybtree" ) -// yieldStep forces the iterator to yield goroutine every N steps. -const yieldStep = 0xFF +// yieldStep forces the iterator to yield goroutine every 255 steps. +const yieldStep = 255 // Cursor allows for quickly paging through Scan, Within, Intersects, and Nearby type Cursor interface { @@ -320,7 +321,10 @@ func (c *Collection) FieldArr() []string { } // Scan iterates though the collection ids. -func (c *Collection) Scan(desc bool, cursor Cursor, +func (c *Collection) Scan( + desc bool, + cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -335,12 +339,7 @@ func (c *Collection) Scan(desc bool, cursor Cursor, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) iitm := value.(*itemT) keepon = iterator(iitm.id, iitm.obj, c.getFieldValues(iitm.id)) return keepon @@ -354,7 +353,11 @@ func (c *Collection) Scan(desc bool, cursor Cursor, } // ScanRange iterates though the collection starting with specified id. -func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor, +func (c *Collection) ScanRange( + start, end string, + desc bool, + cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -369,12 +372,7 @@ func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if !desc { if key >= end { return false @@ -398,7 +396,10 @@ func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor, } // SearchValues iterates though the collection values. -func (c *Collection) SearchValues(desc bool, cursor Cursor, +func (c *Collection) SearchValues( + desc bool, + cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -413,12 +414,7 @@ func (c *Collection) SearchValues(desc bool, cursor Cursor, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) iitm := item.(*itemT) keepon = iterator(iitm.id, iitm.obj, c.getFieldValues(iitm.id)) return keepon @@ -434,6 +430,7 @@ func (c *Collection) SearchValues(desc bool, cursor Cursor, // SearchValuesRange iterates though the collection values. func (c *Collection) SearchValuesRange(start, end string, desc bool, cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -448,12 +445,7 @@ func (c *Collection) SearchValuesRange(start, end string, desc bool, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) iitm := item.(*itemT) keepon = iterator(iitm.id, iitm.obj, c.getFieldValues(iitm.id)) return keepon @@ -471,6 +463,7 @@ func (c *Collection) SearchValuesRange(start, end string, desc bool, // ScanGreaterOrEqual iterates though the collection starting with specified id. func (c *Collection) ScanGreaterOrEqual(id string, desc bool, cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -485,12 +478,7 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) iitm := value.(*itemT) keepon = iterator(iitm.id, iitm.obj, c.getFieldValues(iitm.id)) return keepon @@ -594,6 +582,7 @@ func (c *Collection) Within( obj geojson.Object, sparse uint8, cursor Cursor, + deadline *deadline.Deadline, iter func(id string, obj geojson.Object, fields []float64) bool, ) bool { var count uint64 @@ -611,12 +600,7 @@ func (c *Collection) Within( if count <= offset { return false, true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if match = o.Within(obj); match { ok = iter(id, o, fields) } @@ -630,12 +614,7 @@ func (c *Collection) Within( if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if o.Within(obj) { return iter(id, o, fields) } @@ -650,6 +629,7 @@ func (c *Collection) Intersects( obj geojson.Object, sparse uint8, cursor Cursor, + deadline *deadline.Deadline, iter func(id string, obj geojson.Object, fields []float64) bool, ) bool { var count uint64 @@ -667,12 +647,7 @@ func (c *Collection) Intersects( if count <= offset { return false, true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if match = o.Intersects(obj); match { ok = iter(id, o, fields) } @@ -686,12 +661,7 @@ func (c *Collection) Intersects( if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if o.Intersects(obj) { return iter(id, o, fields) } @@ -704,6 +674,7 @@ func (c *Collection) Intersects( func (c *Collection) Nearby( target geojson.Object, cursor Cursor, + deadline *deadline.Deadline, iter func(id string, obj geojson.Object, fields []float64) bool, ) bool { // First look to see if there's at least one candidate in the circle's @@ -746,12 +717,7 @@ func (c *Collection) Nearby( if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) item := itemv.(*itemT) alive = iter(item.id, item.obj, c.getFieldValues(item.id)) return alive @@ -759,3 +725,13 @@ func (c *Collection) Nearby( ) return alive } + +func nextStep(step uint64, cursor Cursor, deadline *deadline.Deadline) { + if step&yieldStep == yieldStep { + runtime.Gosched() + deadline.Check() + } + if cursor != nil { + cursor.Step(1) + } +} diff --git a/internal/collection/collection_test.go b/internal/collection/collection_test.go index 21325676..edf6fd81 100644 --- a/internal/collection/collection_test.go +++ b/internal/collection/collection_test.go @@ -230,7 +230,7 @@ func TestCollectionScan(t *testing.T) { } var n int var prevID string - c.Scan(false, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.Scan(false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id > prevID) } @@ -241,7 +241,7 @@ func TestCollectionScan(t *testing.T) { }) expect(t, n == c.Count()) n = 0 - c.Scan(true, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.Scan(true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id < prevID) } @@ -253,7 +253,7 @@ func TestCollectionScan(t *testing.T) { expect(t, n == c.Count()) n = 0 - c.ScanRange("0060", "0070", false, nil, + c.ScanRange("0060", "0070", false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id > prevID) @@ -266,7 +266,7 @@ func TestCollectionScan(t *testing.T) { expect(t, n == 10) n = 0 - c.ScanRange("0070", "0060", true, nil, + c.ScanRange("0070", "0060", true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id < prevID) @@ -279,7 +279,7 @@ func TestCollectionScan(t *testing.T) { expect(t, n == 10) n = 0 - c.ScanGreaterOrEqual("0070", true, nil, + c.ScanGreaterOrEqual("0070", true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id < prevID) @@ -292,7 +292,7 @@ func TestCollectionScan(t *testing.T) { expect(t, n == 71) n = 0 - c.ScanGreaterOrEqual("0070", false, nil, + c.ScanGreaterOrEqual("0070", false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id > prevID) @@ -317,7 +317,7 @@ func TestCollectionSearch(t *testing.T) { } var n int var prevValue string - c.SearchValues(false, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.SearchValues(false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, obj.String() > prevValue) } @@ -328,7 +328,7 @@ func TestCollectionSearch(t *testing.T) { }) expect(t, n == c.Count()) n = 0 - c.SearchValues(true, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.SearchValues(true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, obj.String() < prevValue) } @@ -340,7 +340,7 @@ func TestCollectionSearch(t *testing.T) { expect(t, n == c.Count()) n = 0 - c.SearchValuesRange("0060", "0070", false, nil, + c.SearchValuesRange("0060", "0070", false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, obj.String() > prevValue) @@ -353,7 +353,7 @@ func TestCollectionSearch(t *testing.T) { expect(t, n == 10) n = 0 - c.SearchValuesRange("0070", "0060", true, nil, + c.SearchValuesRange("0070", "0060", true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, obj.String() < prevValue) @@ -436,7 +436,7 @@ func TestSpatialSearch(t *testing.T) { var n int n = 0 - c.Within(q1, 0, nil, + c.Within(q1, 0, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -445,7 +445,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 3) n = 0 - c.Within(q2, 0, nil, + c.Within(q2, 0, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -454,7 +454,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 7) n = 0 - c.Within(q3, 0, nil, + c.Within(q3, 0, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -463,7 +463,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 4) n = 0 - c.Intersects(q1, 0, nil, + c.Intersects(q1, 0, nil, nil, func(_ string, _ geojson.Object, _ []float64) bool { n++ return true @@ -472,7 +472,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 4) n = 0 - c.Intersects(q2, 0, nil, + c.Intersects(q2, 0, nil, nil, func(_ string, _ geojson.Object, _ []float64) bool { n++ return true @@ -481,7 +481,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 7) n = 0 - c.Intersects(q3, 0, nil, + c.Intersects(q3, 0, nil, nil, func(_ string, _ geojson.Object, _ []float64) bool { n++ return true @@ -490,7 +490,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 5) n = 0 - c.Intersects(q3, 0, nil, + c.Intersects(q3, 0, nil, nil, func(_ string, _ geojson.Object, _ []float64) bool { n++ return n <= 1 @@ -502,7 +502,7 @@ func TestSpatialSearch(t *testing.T) { exitems := []geojson.Object{ r2, p1, p4, r1, p3, r3, p2, } - c.Nearby(q4, nil, + c.Nearby(q4, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { items = append(items, obj) return true @@ -528,7 +528,7 @@ func TestCollectionSparse(t *testing.T) { } var n int n = 0 - c.Within(rect, 1, nil, + c.Within(rect, 1, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -537,7 +537,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 4) n = 0 - c.Within(rect, 2, nil, + c.Within(rect, 2, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -546,7 +546,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 16) n = 0 - c.Within(rect, 3, nil, + c.Within(rect, 3, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -555,7 +555,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 64) n = 0 - c.Within(rect, 3, nil, + c.Within(rect, 3, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return n <= 30 @@ -564,7 +564,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 31) n = 0 - c.Intersects(rect, 3, nil, + c.Intersects(rect, 3, nil, nil, func(id string, _ geojson.Object, _ []float64) bool { n++ return true @@ -573,7 +573,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 64) n = 0 - c.Intersects(rect, 3, nil, + c.Intersects(rect, 3, nil, nil, func(id string, _ geojson.Object, _ []float64) bool { n++ return n <= 30 diff --git a/internal/deadline/deadline.go b/internal/deadline/deadline.go new file mode 100644 index 00000000..7acf2577 --- /dev/null +++ b/internal/deadline/deadline.go @@ -0,0 +1,31 @@ +package deadline + +import "time" + +// Deadline allows for commands to expire when they run too long +type Deadline struct { + unixNano int64 + hit bool +} + +// New returns a new deadline object +func New(deadline time.Time) *Deadline { + return &Deadline{unixNano: deadline.UnixNano()} +} + +// Check the deadline and panic when reached +//go:noinline +func (deadline *Deadline) Check() { + if deadline == nil || deadline.unixNano == 0 { + return + } + if !deadline.hit && time.Now().UnixNano() > deadline.unixNano { + deadline.hit = true + panic("deadline") + } +} + +// Hit returns true if the deadline has been hit +func (deadline *Deadline) Hit() bool { + return deadline.hit +} diff --git a/internal/server/aofshrink.go b/internal/server/aofshrink.go index 095cad79..8e118d6c 100644 --- a/internal/server/aofshrink.go +++ b/internal/server/aofshrink.go @@ -96,7 +96,7 @@ func (server *Server) aofshrink() { var exm = server.expires[keys[0]] // the expiration map var now = time.Now() // used for expiration var count = 0 // the object count - col.ScanGreaterOrEqual(nextid, false, nil, + col.ScanGreaterOrEqual(nextid, false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if count == maxids { // we reached the max number of ids for one batch diff --git a/internal/server/client.go b/internal/server/client.go index dbfaa737..15e49e41 100644 --- a/internal/server/client.go +++ b/internal/server/client.go @@ -34,6 +34,7 @@ type Client struct { opened time.Time // when the client was created/opened, unix nano last time.Time // last client request/response, unix nano + timeout time.Duration // command timeout } // Write ... diff --git a/internal/server/crud.go b/internal/server/crud.go index 664a9ed3..4ddf4b5a 100644 --- a/internal/server/crud.go +++ b/internal/server/crud.go @@ -373,9 +373,9 @@ func (server *Server) cmdPdel(msg *Message) (res resp.Value, d commandDetails, e if col != nil { g := glob.Parse(d.pattern, false) if g.Limits[0] == "" && g.Limits[1] == "" { - col.Scan(false, nil, iter) + col.Scan(false, nil, msg.Deadline, iter) } else { - col.ScanRange(g.Limits[0], g.Limits[1], false, nil, iter) + col.ScanRange(g.Limits[0], g.Limits[1], false, nil, msg.Deadline, iter) } var atLeastOneNotDeleted bool for i, dc := range d.children { diff --git a/internal/server/fence.go b/internal/server/fence.go index c1d65460..0817ee26 100644 --- a/internal/server/fence.go +++ b/internal/server/fence.go @@ -300,10 +300,10 @@ func extendRoamMessage( } g := glob.Parse(pattern, false) if g.Limits[0] == "" && g.Limits[1] == "" { - col.Scan(false, nil, iterator) + col.Scan(false, nil, nil, iterator) } else { col.ScanRange(g.Limits[0], g.Limits[1], - false, nil, iterator) + false, nil, nil, iterator) } } nmsg = append(nmsg, ']') @@ -370,7 +370,7 @@ func fenceMatchRoam( Min: geometry.Point{X: minLon, Y: minLat}, Max: geometry.Point{X: maxLon, Y: maxLat}, } - col.Intersects(geojson.NewRect(rect), 0, nil, func( + col.Intersects(geojson.NewRect(rect), 0, nil, nil, func( id string, obj2 geojson.Object, fields []float64, ) bool { if c.hasExpired(fence.roam.key, id) { diff --git a/internal/server/scan.go b/internal/server/scan.go index c19894b2..064fca6b 100644 --- a/internal/server/scan.go +++ b/internal/server/scan.go @@ -67,6 +67,7 @@ func (c *Server) cmdScan(msg *Message) (res resp.Value, err error) { g := glob.Parse(sw.globPattern, s.desc) if g.Limits[0] == "" && g.Limits[1] == "" { sw.col.Scan(s.desc, sw, + msg.Deadline, func(id string, o geojson.Object, fields []float64) bool { return sw.writeObject(ScanWriterParams{ id: id, @@ -77,6 +78,7 @@ func (c *Server) cmdScan(msg *Message) (res resp.Value, err error) { ) } else { sw.col.ScanRange(g.Limits[0], g.Limits[1], s.desc, sw, + msg.Deadline, func(id string, o geojson.Object, fields []float64) bool { return sw.writeObject(ScanWriterParams{ id: id, diff --git a/internal/server/search.go b/internal/server/search.go index 9cc12615..1b7f9af0 100644 --- a/internal/server/search.go +++ b/internal/server/search.go @@ -14,6 +14,7 @@ import ( "github.com/tidwall/geojson/geometry" "github.com/tidwall/resp" "github.com/tidwall/tile38/internal/bing" + "github.com/tidwall/tile38/internal/deadline" "github.com/tidwall/tile38/internal/glob" ) @@ -385,7 +386,7 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) { skipTesting: true, }) } - server.nearestNeighbors(&s, sw, s.obj.(*geojson.Circle), iter) + server.nearestNeighbors(&s, sw, msg.Deadline, s.obj.(*geojson.Circle), iter) } sw.writeFoot() if msg.OutputType == JSON { @@ -403,13 +404,14 @@ type iterItem struct { } func (server *Server) nearestNeighbors( - s *liveFenceSwitches, sw *scanWriter, target *geojson.Circle, + s *liveFenceSwitches, sw *scanWriter, dl *deadline.Deadline, + target *geojson.Circle, iter func(id string, o geojson.Object, fields []float64, dist float64, ) bool) { maxDist := target.Haversine() limit := int(sw.limit) var items []iterItem - sw.col.Nearby(target, sw, func(id string, o geojson.Object, fields []float64) bool { + sw.col.Nearby(target, sw, dl, func(id string, o geojson.Object, fields []float64) bool { if server.hasExpired(s.key, id) { return true } @@ -480,7 +482,7 @@ func (server *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp. sw.writeHead() if sw.col != nil { if cmd == "within" { - sw.col.Within(s.obj, s.sparse, sw, func( + sw.col.Within(s.obj, s.sparse, sw, msg.Deadline, func( id string, o geojson.Object, fields []float64, ) bool { if server.hasExpired(s.key, id) { @@ -494,7 +496,7 @@ func (server *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp. }) }) } else if cmd == "intersects" { - sw.col.Intersects(s.obj, s.sparse, sw, func( + sw.col.Intersects(s.obj, s.sparse, sw, msg.Deadline, func( id string, o geojson.Object, fields []float64, @@ -578,7 +580,7 @@ func (server *Server) cmdSearch(msg *Message) (res resp.Value, err error) { } else { g := glob.Parse(sw.globPattern, s.desc) if g.Limits[0] == "" && g.Limits[1] == "" { - sw.col.SearchValues(s.desc, sw, + sw.col.SearchValues(s.desc, sw, msg.Deadline, func(id string, o geojson.Object, fields []float64) bool { return sw.writeObject(ScanWriterParams{ id: id, @@ -593,6 +595,7 @@ func (server *Server) cmdSearch(msg *Message) (res resp.Value, err error) { // globSingle is only for ID matches, not values. sw.globSingle = false sw.col.SearchValuesRange(g.Limits[0], g.Limits[1], s.desc, sw, + msg.Deadline, func(id string, o geojson.Object, fields []float64) bool { return sw.writeObject(ScanWriterParams{ id: id, diff --git a/internal/server/server.go b/internal/server/server.go index 16c9b25e..fd60b582 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -31,6 +31,7 @@ import ( "github.com/tidwall/resp" "github.com/tidwall/tile38/core" "github.com/tidwall/tile38/internal/collection" + "github.com/tidwall/tile38/internal/deadline" "github.com/tidwall/tile38/internal/endpoint" "github.com/tidwall/tile38/internal/expire" "github.com/tidwall/tile38/internal/log" @@ -996,7 +997,7 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { // does not write to aof, but requires a write lock. server.mu.Lock() defer server.mu.Unlock() - case "output": + case "output", "timeout": // this is local connection operation. Locks not needed. case "echo": case "massinsert": @@ -1022,8 +1023,24 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { case "subscribe", "psubscribe", "publish": // No locking for pubsub } - - res, d, err := server.command(msg, client) + res, d, err := func() (res resp.Value, d commandDetails, err error) { + if client.timeout != 0 && !write { + msg.Deadline = deadline.New(start.Add(client.timeout)) + defer func() { + if msg.Deadline.Hit() { + v := recover() + if v != nil { + if s, ok := v.(string); !ok || s != "deadline" { + panic(v) + } + } + res = NOMessage + err = writeErr("timeout") + } + }() + } + return server.command(msg, client) + }() if res.Type() == resp.Error { return writeErr(res.String()) } @@ -1180,6 +1197,8 @@ func (server *Server) command(msg *Message, client *Client) ( res, err = server.cmdKeys(msg) case "output": res, err = server.cmdOutput(msg) + case "timeout": + res, err = server.cmdTimeout(msg, client) case "aof": res, err = server.cmdAOF(msg) case "aofmd5": @@ -1323,6 +1342,7 @@ type Message struct { ConnType Type OutputType Type Auth string + Deadline *deadline.Deadline } // Command returns the first argument as a lowercase string diff --git a/internal/server/timeout.go b/internal/server/timeout.go new file mode 100644 index 00000000..4db38ff6 --- /dev/null +++ b/internal/server/timeout.go @@ -0,0 +1,38 @@ +package server + +import ( + "strconv" + "time" + + "github.com/tidwall/resp" +) + +func (c *Server) cmdTimeout(msg *Message, client *Client) (res resp.Value, err error) { + start := time.Now() + vs := msg.Args[1:] + var arg string + var ok bool + + if len(vs) != 0 { + if _, arg, ok = tokenval(vs); !ok || arg == "" { + return NOMessage, errInvalidNumberOfArguments + } + timeout, err := strconv.ParseFloat(arg, 64) + if err != nil || timeout < 0 { + return NOMessage, errInvalidArgument(arg) + } + client.timeout = time.Duration(timeout * float64(time.Second)) + return OKMessage(msg, start), nil + } + // return the timeout + switch msg.OutputType { + default: + return NOMessage, nil + case JSON: + return resp.StringValue(`{"ok":true` + + `,"seconds":` + strconv.FormatFloat(client.timeout.Seconds(), 'f', -1, 64) + + `,"elapsed":` + time.Now().Sub(start).String() + `}`), nil + case RESP: + return resp.FloatValue(client.timeout.Seconds()), nil + } +} From e514a0287fb7e4c6ec9dc8d7c24550151039842a Mon Sep 17 00:00:00 2001 From: Alex Roitman Date: Wed, 24 Apr 2019 12:02:39 -0700 Subject: [PATCH 2/5] Add timeout subcommand to scan/search commands. Use per-query timeout for those commands, if it was given. --- internal/deadline/deadline.go | 5 +++++ internal/server/scan.go | 3 +++ internal/server/search.go | 9 +++++++++ internal/server/server.go | 2 +- internal/server/token.go | 16 ++++++++++++++++ 5 files changed, 34 insertions(+), 1 deletion(-) diff --git a/internal/deadline/deadline.go b/internal/deadline/deadline.go index 7acf2577..cb22525e 100644 --- a/internal/deadline/deadline.go +++ b/internal/deadline/deadline.go @@ -13,6 +13,11 @@ func New(deadline time.Time) *Deadline { return &Deadline{unixNano: deadline.UnixNano()} } +// Update the deadline from a given time object +func (deadline *Deadline) Update(newDeadline time.Time) { + deadline.unixNano = newDeadline.UnixNano() +} + // Check the deadline and panic when reached //go:noinline func (deadline *Deadline) Check() { diff --git a/internal/server/scan.go b/internal/server/scan.go index 064fca6b..bded9749 100644 --- a/internal/server/scan.go +++ b/internal/server/scan.go @@ -55,6 +55,9 @@ func (c *Server) cmdScan(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() + if s.timeout != 0 { + msg.Deadline.Update(start.Add(s.timeout)) + } if sw.col != nil { if sw.output == outputCount && len(sw.wheres) == 0 && len(sw.whereins) == 0 && sw.globEverything == true { diff --git a/internal/server/search.go b/internal/server/search.go index 1b7f9af0..d4083a74 100644 --- a/internal/server/search.go +++ b/internal/server/search.go @@ -370,6 +370,9 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() + if s.timeout != 0 { + msg.Deadline.Update(start.Add(s.timeout)) + } if sw.col != nil { iter := func(id string, o geojson.Object, fields []float64, dist float64) bool { meters := 0.0 @@ -480,6 +483,9 @@ func (server *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp. wr.WriteString(`{"ok":true`) } sw.writeHead() + if s.timeout != 0 { + msg.Deadline.Update(start.Add(s.timeout)) + } if sw.col != nil { if cmd == "within" { sw.col.Within(s.obj, s.sparse, sw, msg.Deadline, func( @@ -570,6 +576,9 @@ func (server *Server) cmdSearch(msg *Message) (res resp.Value, err error) { wr.WriteString(`{"ok":true`) } sw.writeHead() + if s.timeout != 0 { + msg.Deadline.Update(start.Add(s.timeout)) + } if sw.col != nil { if sw.output == outputCount && len(sw.wheres) == 0 && sw.globEverything == true { count := sw.col.Count() - int(s.cursor) diff --git a/internal/server/server.go b/internal/server/server.go index fd60b582..bea3564b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1024,7 +1024,7 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { // No locking for pubsub } res, d, err := func() (res resp.Value, d commandDetails, err error) { - if client.timeout != 0 && !write { + if !write { msg.Deadline = deadline.New(start.Add(client.timeout)) defer func() { if msg.Deadline.Hit() { diff --git a/internal/server/token.go b/internal/server/token.go index 71a81044..8b46ef01 100644 --- a/internal/server/token.go +++ b/internal/server/token.go @@ -6,6 +6,7 @@ import ( "math" "strconv" "strings" + "time" "github.com/yuin/gopher-lua" ) @@ -247,6 +248,7 @@ type searchScanBaseTokens struct { sparse uint8 desc bool clip bool + timeout time.Duration } func (c *Server) parseSearchScanBaseTokens( @@ -579,6 +581,20 @@ func (c *Server) parseSearchScanBaseTokens( } t.clip = true continue + case "timeout": + vs = nvs + var valStr string + if vs, valStr, ok = tokenval(vs); !ok || valStr == "" { + err = errInvalidNumberOfArguments + return + } + timeout, _err := strconv.ParseFloat(valStr, 64) + if _err != nil || timeout < 0 { + err = errInvalidArgument(valStr) + return + } + t.timeout = time.Duration(timeout * float64(time.Second)) + continue } } break From 7177a6468f6ae82b3bd6bad66352516dae26fccd Mon Sep 17 00:00:00 2001 From: Alex Roitman Date: Wed, 24 Apr 2019 12:08:07 -0700 Subject: [PATCH 3/5] Add TIMEOUT description to the commands.json --- core/commands.json | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/core/commands.json b/core/commands.json index 93adc1bb..bbc1409a 100644 --- a/core/commands.json +++ b/core/commands.json @@ -359,6 +359,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "MATCH", "name": "pattern", @@ -450,6 +456,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "MATCH", "name": "pattern", @@ -559,6 +571,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "SPARSE", "name": "spread", @@ -725,6 +743,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "SPARSE", "name": "spread", @@ -946,6 +970,12 @@ "type": "integer", "optional": true }, + { + "command": "TIMEOUT", + "name": "seconds", + "type": "double", + "optional": true + }, { "command": "SPARSE", "name": "spread", From 31525487c3cac6986ea7850579a96f2c02128ae8 Mon Sep 17 00:00:00 2001 From: Alex Roitman Date: Wed, 24 Apr 2019 13:20:57 -0700 Subject: [PATCH 4/5] Add timeout tests. Fix a bug. --- internal/deadline/deadline.go | 4 ++ internal/server/server.go | 8 ++- tests/tests_test.go | 1 + tests/timeout_test.go | 114 ++++++++++++++++++++++++++++++++++ 4 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 tests/timeout_test.go diff --git a/internal/deadline/deadline.go b/internal/deadline/deadline.go index cb22525e..6be76543 100644 --- a/internal/deadline/deadline.go +++ b/internal/deadline/deadline.go @@ -13,6 +13,10 @@ func New(deadline time.Time) *Deadline { return &Deadline{unixNano: deadline.UnixNano()} } +func Empty() *Deadline { + return &Deadline{} +} + // Update the deadline from a given time object func (deadline *Deadline) Update(newDeadline time.Time) { deadline.unixNano = newDeadline.UnixNano() diff --git a/internal/server/server.go b/internal/server/server.go index bea3564b..ec61f932 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1025,7 +1025,13 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { } res, d, err := func() (res resp.Value, d commandDetails, err error) { if !write { - msg.Deadline = deadline.New(start.Add(client.timeout)) + if client.timeout == 0 { + // the command itself might have a timeout, + // which will be used to update this trivial deadline. + msg.Deadline = deadline.Empty() + } else { + msg.Deadline = deadline.New(start.Add(client.timeout)) + } defer func() { if msg.Deadline.Hit() { v := recover() diff --git a/tests/tests_test.go b/tests/tests_test.go index 6624fa64..34953ddf 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -47,6 +47,7 @@ func TestAll(t *testing.T) { runSubTest(t, "scripts", mc, subTestScripts) runSubTest(t, "info", mc, subTestInfo) runSubTest(t, "client", mc, subTestClient) + runSubTest(t, "timeouts", mc, subTestTimeout) } func runSubTest(t *testing.T, name string, mc *mockServer, test func(t *testing.T, mc *mockServer)) { diff --git a/tests/timeout_test.go b/tests/timeout_test.go new file mode 100644 index 00000000..7178c0b1 --- /dev/null +++ b/tests/timeout_test.go @@ -0,0 +1,114 @@ +package tests + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/gomodule/redigo/redis" +) + +func subTestTimeout(t *testing.T, mc *mockServer) { + runStep(t, mc, "session set/unset", timeout_session_set_unset_test) + runStep(t, mc, "session spatial", timeout_session_spatial_test) + runStep(t, mc, "session search", timeout_session_search_test) + runStep(t, mc, "command spatial", timeout_command_spatial_test) + runStep(t, mc, "command search", timeout_command_search_test) +} + +func setup(mc *mockServer, count int, points bool) (err error) { + rand.Seed(time.Now().UnixNano()) + + // add a bunch of points + for i := 0; i < count; i++ { + val := fmt.Sprintf("val:%d", i) + var resp string + var lat, lon, fval float64 + fval = rand.Float64() + if points { + lat = rand.Float64()*180 - 90 + lon = rand.Float64()*360 - 180 + resp, err = redis.String(mc.conn.Do("SET", + "mykey", val, + "FIELD", "foo", fval, + "POINT", lat, lon)) + } else { + resp, err = redis.String(mc.conn.Do("SET", + "mykey", val, + "FIELD", "foo", fval, + "STRING", val)) + } + if err != nil { + return + } + if resp != "OK" { + err = fmt.Errorf("expected 'OK', got '%s'", resp) + return + } + time.Sleep(time.Nanosecond) + } + time.Sleep(time.Second * 3) + return +} + +func timeout_session_set_unset_test(mc *mockServer) (err error) { + return mc.DoBatch([][]interface{}{ + {"TIMEOUT"}, {"0"}, + {"TIMEOUT", "0.25"}, {"OK"}, + {"TIMEOUT"}, {"0.25"}, + {"TIMEOUT", "0"}, {"OK"}, + {"TIMEOUT"}, {"0"}, + }) +} + +func timeout_session_spatial_test(mc *mockServer) (err error) { + err = setup(mc, 10000, true) + + return mc.DoBatch([][]interface{}{ + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"10000"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + + {"TIMEOUT", "0.000001"}, {"OK"}, + + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + }) +} + +func timeout_command_spatial_test(mc *mockServer) (err error) { + err = setup(mc, 10000, true) + + return mc.DoBatch([][]interface{}{ + {"TIMEOUT", "1"}, {"OK"}, + {"SCAN", "mykey", "WHERE", "foo", -1, 2, "COUNT"}, {"10000"}, + {"INTERSECTS", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + {"WITHIN", "mykey", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"10000"}, + + {"SCAN", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT"}, {"ERR timeout"}, + {"INTERSECTS", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + {"WITHIN", "mykey", "TIMEOUT", "0.000001", "WHERE", "foo", -1, 2, "COUNT", "BOUNDS", -90, -180, 90, 180}, {"ERR timeout"}, + }) +} + +func timeout_session_search_test(mc *mockServer) (err error) { + err = setup(mc, 10000, false) + + return mc.DoBatch([][]interface{}{ + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"10000"}, + {"TIMEOUT", "0.000001"}, {"OK"}, + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, + }) +} + +func timeout_command_search_test(mc *mockServer) (err error) { + err = setup(mc, 10000, false) + + return mc.DoBatch([][]interface{}{ + {"TIMEOUT", "1"}, {"OK"}, + {"SEARCH", "mykey", "MATCH", "val:*", "COUNT"}, {"10000"}, + {"SEARCH", "mykey", "TIMEOUT", "0.000001", "MATCH", "val:*", "COUNT"}, {"ERR timeout"}, + }) +} From d3350c033f7c72b6407134c3ccf69e904187dc91 Mon Sep 17 00:00:00 2001 From: Alex Roitman Date: Wed, 24 Apr 2019 13:26:25 -0700 Subject: [PATCH 5/5] Add session timeout description to commands.json --- core/commands.json | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/commands.json b/core/commands.json index bbc1409a..e4bc7f96 100644 --- a/core/commands.json +++ b/core/commands.json @@ -1329,6 +1329,17 @@ ], "group": "connection" }, + "TIMEOUT": { + "summary": "Gets or sets the query timeout for the current connection.", + "arguments": [ + { + "name": "seconds", + "optional": true, + "type": "double" + } + ], + "group": "connection" + }, "SETHOOK": { "summary": "Creates a webhook which points to geofenced search", "arguments": [