From 0aecef6a5c7965a33859d49b0f1f3b2f98b6397c Mon Sep 17 00:00:00 2001 From: tidwall Date: Wed, 24 Apr 2019 05:09:41 -0700 Subject: [PATCH] Added TIMEOUT command --- internal/collection/collection.go | 106 ++++++++++--------------- internal/collection/collection_test.go | 48 +++++------ internal/deadline/deadline.go | 31 ++++++++ internal/server/aofshrink.go | 2 +- internal/server/client.go | 1 + internal/server/crud.go | 4 +- internal/server/fence.go | 6 +- internal/server/scan.go | 2 + internal/server/search.go | 15 ++-- internal/server/server.go | 26 +++++- internal/server/timeout.go | 38 +++++++++ 11 files changed, 175 insertions(+), 104 deletions(-) create mode 100644 internal/deadline/deadline.go create mode 100644 internal/server/timeout.go diff --git a/internal/collection/collection.go b/internal/collection/collection.go index ad107e38..8562c3d4 100644 --- a/internal/collection/collection.go +++ b/internal/collection/collection.go @@ -8,11 +8,12 @@ import ( "github.com/tidwall/geojson" "github.com/tidwall/geojson/geo" "github.com/tidwall/geojson/geometry" + "github.com/tidwall/tile38/internal/deadline" "github.com/tidwall/tinybtree" ) -// yieldStep forces the iterator to yield goroutine every N steps. -const yieldStep = 0xFF +// yieldStep forces the iterator to yield goroutine every 255 steps. +const yieldStep = 255 // Cursor allows for quickly paging through Scan, Within, Intersects, and Nearby type Cursor interface { @@ -320,7 +321,10 @@ func (c *Collection) FieldArr() []string { } // Scan iterates though the collection ids. -func (c *Collection) Scan(desc bool, cursor Cursor, +func (c *Collection) Scan( + desc bool, + cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -335,12 +339,7 @@ func (c *Collection) Scan(desc bool, cursor Cursor, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) iitm := value.(*itemT) keepon = iterator(iitm.id, iitm.obj, c.getFieldValues(iitm.id)) return keepon @@ -354,7 +353,11 @@ func (c *Collection) Scan(desc bool, cursor Cursor, } // ScanRange iterates though the collection starting with specified id. -func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor, +func (c *Collection) ScanRange( + start, end string, + desc bool, + cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -369,12 +372,7 @@ func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if !desc { if key >= end { return false @@ -398,7 +396,10 @@ func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor, } // SearchValues iterates though the collection values. -func (c *Collection) SearchValues(desc bool, cursor Cursor, +func (c *Collection) SearchValues( + desc bool, + cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -413,12 +414,7 @@ func (c *Collection) SearchValues(desc bool, cursor Cursor, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) iitm := item.(*itemT) keepon = iterator(iitm.id, iitm.obj, c.getFieldValues(iitm.id)) return keepon @@ -434,6 +430,7 @@ func (c *Collection) SearchValues(desc bool, cursor Cursor, // SearchValuesRange iterates though the collection values. func (c *Collection) SearchValuesRange(start, end string, desc bool, cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -448,12 +445,7 @@ func (c *Collection) SearchValuesRange(start, end string, desc bool, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) iitm := item.(*itemT) keepon = iterator(iitm.id, iitm.obj, c.getFieldValues(iitm.id)) return keepon @@ -471,6 +463,7 @@ func (c *Collection) SearchValuesRange(start, end string, desc bool, // ScanGreaterOrEqual iterates though the collection starting with specified id. func (c *Collection) ScanGreaterOrEqual(id string, desc bool, cursor Cursor, + deadline *deadline.Deadline, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true @@ -485,12 +478,7 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool, if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) iitm := value.(*itemT) keepon = iterator(iitm.id, iitm.obj, c.getFieldValues(iitm.id)) return keepon @@ -594,6 +582,7 @@ func (c *Collection) Within( obj geojson.Object, sparse uint8, cursor Cursor, + deadline *deadline.Deadline, iter func(id string, obj geojson.Object, fields []float64) bool, ) bool { var count uint64 @@ -611,12 +600,7 @@ func (c *Collection) Within( if count <= offset { return false, true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if match = o.Within(obj); match { ok = iter(id, o, fields) } @@ -630,12 +614,7 @@ func (c *Collection) Within( if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if o.Within(obj) { return iter(id, o, fields) } @@ -650,6 +629,7 @@ func (c *Collection) Intersects( obj geojson.Object, sparse uint8, cursor Cursor, + deadline *deadline.Deadline, iter func(id string, obj geojson.Object, fields []float64) bool, ) bool { var count uint64 @@ -667,12 +647,7 @@ func (c *Collection) Intersects( if count <= offset { return false, true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if match = o.Intersects(obj); match { ok = iter(id, o, fields) } @@ -686,12 +661,7 @@ func (c *Collection) Intersects( if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) if o.Intersects(obj) { return iter(id, o, fields) } @@ -704,6 +674,7 @@ func (c *Collection) Intersects( func (c *Collection) Nearby( target geojson.Object, cursor Cursor, + deadline *deadline.Deadline, iter func(id string, obj geojson.Object, fields []float64) bool, ) bool { // First look to see if there's at least one candidate in the circle's @@ -746,12 +717,7 @@ func (c *Collection) Nearby( if count <= offset { return true } - if count&yieldStep == yieldStep { - runtime.Gosched() - } - if cursor != nil { - cursor.Step(1) - } + nextStep(count, cursor, deadline) item := itemv.(*itemT) alive = iter(item.id, item.obj, c.getFieldValues(item.id)) return alive @@ -759,3 +725,13 @@ func (c *Collection) Nearby( ) return alive } + +func nextStep(step uint64, cursor Cursor, deadline *deadline.Deadline) { + if step&yieldStep == yieldStep { + runtime.Gosched() + deadline.Check() + } + if cursor != nil { + cursor.Step(1) + } +} diff --git a/internal/collection/collection_test.go b/internal/collection/collection_test.go index 21325676..edf6fd81 100644 --- a/internal/collection/collection_test.go +++ b/internal/collection/collection_test.go @@ -230,7 +230,7 @@ func TestCollectionScan(t *testing.T) { } var n int var prevID string - c.Scan(false, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.Scan(false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id > prevID) } @@ -241,7 +241,7 @@ func TestCollectionScan(t *testing.T) { }) expect(t, n == c.Count()) n = 0 - c.Scan(true, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.Scan(true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id < prevID) } @@ -253,7 +253,7 @@ func TestCollectionScan(t *testing.T) { expect(t, n == c.Count()) n = 0 - c.ScanRange("0060", "0070", false, nil, + c.ScanRange("0060", "0070", false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id > prevID) @@ -266,7 +266,7 @@ func TestCollectionScan(t *testing.T) { expect(t, n == 10) n = 0 - c.ScanRange("0070", "0060", true, nil, + c.ScanRange("0070", "0060", true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id < prevID) @@ -279,7 +279,7 @@ func TestCollectionScan(t *testing.T) { expect(t, n == 10) n = 0 - c.ScanGreaterOrEqual("0070", true, nil, + c.ScanGreaterOrEqual("0070", true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id < prevID) @@ -292,7 +292,7 @@ func TestCollectionScan(t *testing.T) { expect(t, n == 71) n = 0 - c.ScanGreaterOrEqual("0070", false, nil, + c.ScanGreaterOrEqual("0070", false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, id > prevID) @@ -317,7 +317,7 @@ func TestCollectionSearch(t *testing.T) { } var n int var prevValue string - c.SearchValues(false, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.SearchValues(false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, obj.String() > prevValue) } @@ -328,7 +328,7 @@ func TestCollectionSearch(t *testing.T) { }) expect(t, n == c.Count()) n = 0 - c.SearchValues(true, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.SearchValues(true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, obj.String() < prevValue) } @@ -340,7 +340,7 @@ func TestCollectionSearch(t *testing.T) { expect(t, n == c.Count()) n = 0 - c.SearchValuesRange("0060", "0070", false, nil, + c.SearchValuesRange("0060", "0070", false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, obj.String() > prevValue) @@ -353,7 +353,7 @@ func TestCollectionSearch(t *testing.T) { expect(t, n == 10) n = 0 - c.SearchValuesRange("0070", "0060", true, nil, + c.SearchValuesRange("0070", "0060", true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if n > 0 { expect(t, obj.String() < prevValue) @@ -436,7 +436,7 @@ func TestSpatialSearch(t *testing.T) { var n int n = 0 - c.Within(q1, 0, nil, + c.Within(q1, 0, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -445,7 +445,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 3) n = 0 - c.Within(q2, 0, nil, + c.Within(q2, 0, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -454,7 +454,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 7) n = 0 - c.Within(q3, 0, nil, + c.Within(q3, 0, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -463,7 +463,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 4) n = 0 - c.Intersects(q1, 0, nil, + c.Intersects(q1, 0, nil, nil, func(_ string, _ geojson.Object, _ []float64) bool { n++ return true @@ -472,7 +472,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 4) n = 0 - c.Intersects(q2, 0, nil, + c.Intersects(q2, 0, nil, nil, func(_ string, _ geojson.Object, _ []float64) bool { n++ return true @@ -481,7 +481,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 7) n = 0 - c.Intersects(q3, 0, nil, + c.Intersects(q3, 0, nil, nil, func(_ string, _ geojson.Object, _ []float64) bool { n++ return true @@ -490,7 +490,7 @@ func TestSpatialSearch(t *testing.T) { expect(t, n == 5) n = 0 - c.Intersects(q3, 0, nil, + c.Intersects(q3, 0, nil, nil, func(_ string, _ geojson.Object, _ []float64) bool { n++ return n <= 1 @@ -502,7 +502,7 @@ func TestSpatialSearch(t *testing.T) { exitems := []geojson.Object{ r2, p1, p4, r1, p3, r3, p2, } - c.Nearby(q4, nil, + c.Nearby(q4, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { items = append(items, obj) return true @@ -528,7 +528,7 @@ func TestCollectionSparse(t *testing.T) { } var n int n = 0 - c.Within(rect, 1, nil, + c.Within(rect, 1, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -537,7 +537,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 4) n = 0 - c.Within(rect, 2, nil, + c.Within(rect, 2, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -546,7 +546,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 16) n = 0 - c.Within(rect, 3, nil, + c.Within(rect, 3, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return true @@ -555,7 +555,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 64) n = 0 - c.Within(rect, 3, nil, + c.Within(rect, 3, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { n++ return n <= 30 @@ -564,7 +564,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 31) n = 0 - c.Intersects(rect, 3, nil, + c.Intersects(rect, 3, nil, nil, func(id string, _ geojson.Object, _ []float64) bool { n++ return true @@ -573,7 +573,7 @@ func TestCollectionSparse(t *testing.T) { expect(t, n == 64) n = 0 - c.Intersects(rect, 3, nil, + c.Intersects(rect, 3, nil, nil, func(id string, _ geojson.Object, _ []float64) bool { n++ return n <= 30 diff --git a/internal/deadline/deadline.go b/internal/deadline/deadline.go new file mode 100644 index 00000000..7acf2577 --- /dev/null +++ b/internal/deadline/deadline.go @@ -0,0 +1,31 @@ +package deadline + +import "time" + +// Deadline allows for commands to expire when they run too long +type Deadline struct { + unixNano int64 + hit bool +} + +// New returns a new deadline object +func New(deadline time.Time) *Deadline { + return &Deadline{unixNano: deadline.UnixNano()} +} + +// Check the deadline and panic when reached +//go:noinline +func (deadline *Deadline) Check() { + if deadline == nil || deadline.unixNano == 0 { + return + } + if !deadline.hit && time.Now().UnixNano() > deadline.unixNano { + deadline.hit = true + panic("deadline") + } +} + +// Hit returns true if the deadline has been hit +func (deadline *Deadline) Hit() bool { + return deadline.hit +} diff --git a/internal/server/aofshrink.go b/internal/server/aofshrink.go index 095cad79..8e118d6c 100644 --- a/internal/server/aofshrink.go +++ b/internal/server/aofshrink.go @@ -96,7 +96,7 @@ func (server *Server) aofshrink() { var exm = server.expires[keys[0]] // the expiration map var now = time.Now() // used for expiration var count = 0 // the object count - col.ScanGreaterOrEqual(nextid, false, nil, + col.ScanGreaterOrEqual(nextid, false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { if count == maxids { // we reached the max number of ids for one batch diff --git a/internal/server/client.go b/internal/server/client.go index dbfaa737..15e49e41 100644 --- a/internal/server/client.go +++ b/internal/server/client.go @@ -34,6 +34,7 @@ type Client struct { opened time.Time // when the client was created/opened, unix nano last time.Time // last client request/response, unix nano + timeout time.Duration // command timeout } // Write ... diff --git a/internal/server/crud.go b/internal/server/crud.go index 664a9ed3..4ddf4b5a 100644 --- a/internal/server/crud.go +++ b/internal/server/crud.go @@ -373,9 +373,9 @@ func (server *Server) cmdPdel(msg *Message) (res resp.Value, d commandDetails, e if col != nil { g := glob.Parse(d.pattern, false) if g.Limits[0] == "" && g.Limits[1] == "" { - col.Scan(false, nil, iter) + col.Scan(false, nil, msg.Deadline, iter) } else { - col.ScanRange(g.Limits[0], g.Limits[1], false, nil, iter) + col.ScanRange(g.Limits[0], g.Limits[1], false, nil, msg.Deadline, iter) } var atLeastOneNotDeleted bool for i, dc := range d.children { diff --git a/internal/server/fence.go b/internal/server/fence.go index c1d65460..0817ee26 100644 --- a/internal/server/fence.go +++ b/internal/server/fence.go @@ -300,10 +300,10 @@ func extendRoamMessage( } g := glob.Parse(pattern, false) if g.Limits[0] == "" && g.Limits[1] == "" { - col.Scan(false, nil, iterator) + col.Scan(false, nil, nil, iterator) } else { col.ScanRange(g.Limits[0], g.Limits[1], - false, nil, iterator) + false, nil, nil, iterator) } } nmsg = append(nmsg, ']') @@ -370,7 +370,7 @@ func fenceMatchRoam( Min: geometry.Point{X: minLon, Y: minLat}, Max: geometry.Point{X: maxLon, Y: maxLat}, } - col.Intersects(geojson.NewRect(rect), 0, nil, func( + col.Intersects(geojson.NewRect(rect), 0, nil, nil, func( id string, obj2 geojson.Object, fields []float64, ) bool { if c.hasExpired(fence.roam.key, id) { diff --git a/internal/server/scan.go b/internal/server/scan.go index c19894b2..064fca6b 100644 --- a/internal/server/scan.go +++ b/internal/server/scan.go @@ -67,6 +67,7 @@ func (c *Server) cmdScan(msg *Message) (res resp.Value, err error) { g := glob.Parse(sw.globPattern, s.desc) if g.Limits[0] == "" && g.Limits[1] == "" { sw.col.Scan(s.desc, sw, + msg.Deadline, func(id string, o geojson.Object, fields []float64) bool { return sw.writeObject(ScanWriterParams{ id: id, @@ -77,6 +78,7 @@ func (c *Server) cmdScan(msg *Message) (res resp.Value, err error) { ) } else { sw.col.ScanRange(g.Limits[0], g.Limits[1], s.desc, sw, + msg.Deadline, func(id string, o geojson.Object, fields []float64) bool { return sw.writeObject(ScanWriterParams{ id: id, diff --git a/internal/server/search.go b/internal/server/search.go index 9cc12615..1b7f9af0 100644 --- a/internal/server/search.go +++ b/internal/server/search.go @@ -14,6 +14,7 @@ import ( "github.com/tidwall/geojson/geometry" "github.com/tidwall/resp" "github.com/tidwall/tile38/internal/bing" + "github.com/tidwall/tile38/internal/deadline" "github.com/tidwall/tile38/internal/glob" ) @@ -385,7 +386,7 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) { skipTesting: true, }) } - server.nearestNeighbors(&s, sw, s.obj.(*geojson.Circle), iter) + server.nearestNeighbors(&s, sw, msg.Deadline, s.obj.(*geojson.Circle), iter) } sw.writeFoot() if msg.OutputType == JSON { @@ -403,13 +404,14 @@ type iterItem struct { } func (server *Server) nearestNeighbors( - s *liveFenceSwitches, sw *scanWriter, target *geojson.Circle, + s *liveFenceSwitches, sw *scanWriter, dl *deadline.Deadline, + target *geojson.Circle, iter func(id string, o geojson.Object, fields []float64, dist float64, ) bool) { maxDist := target.Haversine() limit := int(sw.limit) var items []iterItem - sw.col.Nearby(target, sw, func(id string, o geojson.Object, fields []float64) bool { + sw.col.Nearby(target, sw, dl, func(id string, o geojson.Object, fields []float64) bool { if server.hasExpired(s.key, id) { return true } @@ -480,7 +482,7 @@ func (server *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp. sw.writeHead() if sw.col != nil { if cmd == "within" { - sw.col.Within(s.obj, s.sparse, sw, func( + sw.col.Within(s.obj, s.sparse, sw, msg.Deadline, func( id string, o geojson.Object, fields []float64, ) bool { if server.hasExpired(s.key, id) { @@ -494,7 +496,7 @@ func (server *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp. }) }) } else if cmd == "intersects" { - sw.col.Intersects(s.obj, s.sparse, sw, func( + sw.col.Intersects(s.obj, s.sparse, sw, msg.Deadline, func( id string, o geojson.Object, fields []float64, @@ -578,7 +580,7 @@ func (server *Server) cmdSearch(msg *Message) (res resp.Value, err error) { } else { g := glob.Parse(sw.globPattern, s.desc) if g.Limits[0] == "" && g.Limits[1] == "" { - sw.col.SearchValues(s.desc, sw, + sw.col.SearchValues(s.desc, sw, msg.Deadline, func(id string, o geojson.Object, fields []float64) bool { return sw.writeObject(ScanWriterParams{ id: id, @@ -593,6 +595,7 @@ func (server *Server) cmdSearch(msg *Message) (res resp.Value, err error) { // globSingle is only for ID matches, not values. sw.globSingle = false sw.col.SearchValuesRange(g.Limits[0], g.Limits[1], s.desc, sw, + msg.Deadline, func(id string, o geojson.Object, fields []float64) bool { return sw.writeObject(ScanWriterParams{ id: id, diff --git a/internal/server/server.go b/internal/server/server.go index 16c9b25e..fd60b582 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -31,6 +31,7 @@ import ( "github.com/tidwall/resp" "github.com/tidwall/tile38/core" "github.com/tidwall/tile38/internal/collection" + "github.com/tidwall/tile38/internal/deadline" "github.com/tidwall/tile38/internal/endpoint" "github.com/tidwall/tile38/internal/expire" "github.com/tidwall/tile38/internal/log" @@ -996,7 +997,7 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { // does not write to aof, but requires a write lock. server.mu.Lock() defer server.mu.Unlock() - case "output": + case "output", "timeout": // this is local connection operation. Locks not needed. case "echo": case "massinsert": @@ -1022,8 +1023,24 @@ func (server *Server) handleInputCommand(client *Client, msg *Message) error { case "subscribe", "psubscribe", "publish": // No locking for pubsub } - - res, d, err := server.command(msg, client) + res, d, err := func() (res resp.Value, d commandDetails, err error) { + if client.timeout != 0 && !write { + msg.Deadline = deadline.New(start.Add(client.timeout)) + defer func() { + if msg.Deadline.Hit() { + v := recover() + if v != nil { + if s, ok := v.(string); !ok || s != "deadline" { + panic(v) + } + } + res = NOMessage + err = writeErr("timeout") + } + }() + } + return server.command(msg, client) + }() if res.Type() == resp.Error { return writeErr(res.String()) } @@ -1180,6 +1197,8 @@ func (server *Server) command(msg *Message, client *Client) ( res, err = server.cmdKeys(msg) case "output": res, err = server.cmdOutput(msg) + case "timeout": + res, err = server.cmdTimeout(msg, client) case "aof": res, err = server.cmdAOF(msg) case "aofmd5": @@ -1323,6 +1342,7 @@ type Message struct { ConnType Type OutputType Type Auth string + Deadline *deadline.Deadline } // Command returns the first argument as a lowercase string diff --git a/internal/server/timeout.go b/internal/server/timeout.go new file mode 100644 index 00000000..4db38ff6 --- /dev/null +++ b/internal/server/timeout.go @@ -0,0 +1,38 @@ +package server + +import ( + "strconv" + "time" + + "github.com/tidwall/resp" +) + +func (c *Server) cmdTimeout(msg *Message, client *Client) (res resp.Value, err error) { + start := time.Now() + vs := msg.Args[1:] + var arg string + var ok bool + + if len(vs) != 0 { + if _, arg, ok = tokenval(vs); !ok || arg == "" { + return NOMessage, errInvalidNumberOfArguments + } + timeout, err := strconv.ParseFloat(arg, 64) + if err != nil || timeout < 0 { + return NOMessage, errInvalidArgument(arg) + } + client.timeout = time.Duration(timeout * float64(time.Second)) + return OKMessage(msg, start), nil + } + // return the timeout + switch msg.OutputType { + default: + return NOMessage, nil + case JSON: + return resp.StringValue(`{"ok":true` + + `,"seconds":` + strconv.FormatFloat(client.timeout.Seconds(), 'f', -1, 64) + + `,"elapsed":` + time.Now().Sub(start).String() + `}`), nil + case RESP: + return resp.FloatValue(client.timeout.Seconds()), nil + } +}