apply LIMIT after WHERE clause, fix #199

This commit is contained in:
Josh Baker 2017-07-24 08:26:10 -07:00
parent 29634f86ba
commit 300635727a
19 changed files with 390 additions and 372 deletions

View File

@ -93,7 +93,7 @@ func (c *Controller) aofshrink() {
var exm = c.expires[keys[0]] // the expiration map var exm = c.expires[keys[0]] // the expiration map
var now = time.Now() // used for expiration var now = time.Now() // used for expiration
var count = 0 // the object count var count = 0 // the object count
col.ScanGreaterOrEqual(nextid, 0, false, col.ScanGreaterOrEqual(nextid, false,
func(id string, obj geojson.Object, fields []float64) bool { func(id string, obj geojson.Object, fields []float64) bool {
if count == maxids { if count == maxids {
// we reached the max number of ids for one batch // we reached the max number of ids for one batch

View File

@ -265,41 +265,33 @@ func (c *Collection) FieldArr() []string {
return arr return arr
} }
// Scan iterates though the collection ids. A cursor can be used for paging. // Scan iterates though the collection ids.
func (c *Collection) Scan(cursor uint64, desc bool, func (c *Collection) Scan(desc bool,
iterator func(id string, obj geojson.Object, fields []float64) bool, iterator func(id string, obj geojson.Object, fields []float64) bool,
) (ncursor uint64) { ) bool {
var i uint64 var keepon = true
var active = true
iter := func(item btree.Item) bool { iter := func(item btree.Item) bool {
if i >= cursor {
iitm := item.(*itemT) iitm := item.(*itemT)
active = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id))
} return keepon
i++
return active
} }
if desc { if desc {
c.items.Descend(iter) c.items.Descend(iter)
} else { } else {
c.items.Ascend(iter) c.items.Ascend(iter)
} }
return i return keepon
} }
// ScanGreaterOrEqual iterates though the collection starting with specified id. A cursor can be used for paging. // ScanGreaterOrEqual iterates though the collection starting with specified id.
func (c *Collection) ScanRange(cursor uint64, start, end string, desc bool, func (c *Collection) ScanRange(start, end string, desc bool,
iterator func(id string, obj geojson.Object, fields []float64) bool, iterator func(id string, obj geojson.Object, fields []float64) bool,
) (ncursor uint64) { ) bool {
var i uint64 var keepon = true
var active = true
iter := func(item btree.Item) bool { iter := func(item btree.Item) bool {
if i >= cursor {
iitm := item.(*itemT) iitm := item.(*itemT)
active = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id))
} return keepon
i++
return active
} }
if desc { if desc {
@ -307,77 +299,65 @@ func (c *Collection) ScanRange(cursor uint64, start, end string, desc bool,
} else { } else {
c.items.AscendRange(&itemT{id: start}, &itemT{id: end}, iter) c.items.AscendRange(&itemT{id: start}, &itemT{id: end}, iter)
} }
return i return keepon
} }
// SearchValues iterates though the collection values. A cursor can be used for paging. // SearchValues iterates though the collection values.
func (c *Collection) SearchValues(cursor uint64, desc bool, func (c *Collection) SearchValues(desc bool,
iterator func(id string, obj geojson.Object, fields []float64) bool, iterator func(id string, obj geojson.Object, fields []float64) bool,
) (ncursor uint64) { ) bool {
var i uint64 var keepon = true
var active = true
iter := func(item btree.Item) bool { iter := func(item btree.Item) bool {
if i >= cursor {
iitm := item.(*itemT) iitm := item.(*itemT)
active = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id))
} return keepon
i++
return active
} }
if desc { if desc {
c.values.Descend(iter) c.values.Descend(iter)
} else { } else {
c.values.Ascend(iter) c.values.Ascend(iter)
} }
return i return keepon
} }
// SearchValuesRange iterates though the collection values. A cursor can be used for paging. // SearchValuesRange iterates though the collection values.
func (c *Collection) SearchValuesRange(cursor uint64, start, end string, desc bool, func (c *Collection) SearchValuesRange(start, end string, desc bool,
iterator func(id string, obj geojson.Object, fields []float64) bool, iterator func(id string, obj geojson.Object, fields []float64) bool,
) (ncursor uint64) { ) bool {
var i uint64 var keepon = true
var active = true
iter := func(item btree.Item) bool { iter := func(item btree.Item) bool {
if i >= cursor {
iitm := item.(*itemT) iitm := item.(*itemT)
active = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id))
} return keepon
i++
return active
} }
if desc { if desc {
c.values.DescendRange(&itemT{object: geojson.String(start)}, &itemT{object: geojson.String(end)}, iter) c.values.DescendRange(&itemT{object: geojson.String(start)}, &itemT{object: geojson.String(end)}, iter)
} else { } else {
c.values.AscendRange(&itemT{object: geojson.String(start)}, &itemT{object: geojson.String(end)}, iter) c.values.AscendRange(&itemT{object: geojson.String(start)}, &itemT{object: geojson.String(end)}, iter)
} }
return i return keepon
} }
// ScanGreaterOrEqual iterates though the collection starting with specified id. A cursor can be used for paging. // ScanGreaterOrEqual iterates though the collection starting with specified id.
func (c *Collection) ScanGreaterOrEqual(id string, cursor uint64, desc bool, func (c *Collection) ScanGreaterOrEqual(id string, desc bool,
iterator func(id string, obj geojson.Object, fields []float64) bool, iterator func(id string, obj geojson.Object, fields []float64) bool,
) (ncursor uint64) { ) bool {
var i uint64 var keepon = true
var active = true
iter := func(item btree.Item) bool { iter := func(item btree.Item) bool {
if i >= cursor {
iitm := item.(*itemT) iitm := item.(*itemT)
active = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id))
} return keepon
i++
return active
} }
if desc { if desc {
c.items.DescendLessOrEqual(&itemT{id: id}, iter) c.items.DescendLessOrEqual(&itemT{id: id}, iter)
} else { } else {
c.items.AscendGreaterOrEqual(&itemT{id: id}, iter) c.items.AscendGreaterOrEqual(&itemT{id: id}, iter)
} }
return i return keepon
} }
func (c *Collection) geoSearch(cursor uint64, bbox geojson.BBox, iterator func(id string, obj geojson.Object, fields []float64) bool) (ncursor uint64) { func (c *Collection) geoSearch(bbox geojson.BBox, iterator func(id string, obj geojson.Object, fields []float64) bool) bool {
return c.index.Search(cursor, bbox.Min.Y, bbox.Min.X, bbox.Max.Y, bbox.Max.X, bbox.Min.Z, bbox.Max.Z, func(item index.Item) bool { return c.index.Search(bbox.Min.Y, bbox.Min.X, bbox.Max.Y, bbox.Max.X, bbox.Min.Z, bbox.Max.Z, func(item index.Item) bool {
var iitm *itemT var iitm *itemT
iitm, ok := item.(*itemT) iitm, ok := item.(*itemT)
if !ok { if !ok {
@ -391,14 +371,15 @@ func (c *Collection) geoSearch(cursor uint64, bbox geojson.BBox, iterator func(i
} }
// Nearby returns all object that are nearby a point. // Nearby returns all object that are nearby a point.
func (c *Collection) Nearby(cursor uint64, sparse uint8, lat, lon, meters, minZ, maxZ float64, iterator func(id string, obj geojson.Object, fields []float64) bool) (ncursor uint64) { func (c *Collection) Nearby(sparse uint8, lat, lon, meters, minZ, maxZ float64, iterator func(id string, obj geojson.Object, fields []float64) bool) bool {
var keepon = true
center := geojson.Position{X: lon, Y: lat, Z: 0} center := geojson.Position{X: lon, Y: lat, Z: 0}
bbox := geojson.BBoxesFromCenter(lat, lon, meters) bbox := geojson.BBoxesFromCenter(lat, lon, meters)
bboxes := bbox.Sparse(sparse) bboxes := bbox.Sparse(sparse)
if sparse > 0 { if sparse > 0 {
for _, bbox := range bboxes { for _, bbox := range bboxes {
bbox.Min.Z, bbox.Max.Z = minZ, maxZ bbox.Min.Z, bbox.Max.Z = minZ, maxZ
c.geoSearch(cursor, bbox, func(id string, obj geojson.Object, fields []float64) bool { keepon = c.geoSearch(bbox, func(id string, obj geojson.Object, fields []float64) bool {
if obj.Nearby(center, meters) { if obj.Nearby(center, meters) {
if iterator(id, obj, fields) { if iterator(id, obj, fields) {
return false return false
@ -406,11 +387,14 @@ func (c *Collection) Nearby(cursor uint64, sparse uint8, lat, lon, meters, minZ,
} }
return true return true
}) })
if !keepon {
break
} }
return 0 }
return keepon
} }
bbox.Min.Z, bbox.Max.Z = minZ, maxZ bbox.Min.Z, bbox.Max.Z = minZ, maxZ
return c.geoSearch(cursor, bbox, func(id string, obj geojson.Object, fields []float64) bool { return c.geoSearch(bbox, func(id string, obj geojson.Object, fields []float64) bool {
if obj.Nearby(center, meters) { if obj.Nearby(center, meters) {
return iterator(id, obj, fields) return iterator(id, obj, fields)
} }
@ -419,7 +403,8 @@ func (c *Collection) Nearby(cursor uint64, sparse uint8, lat, lon, meters, minZ,
} }
// Within returns all object that are fully contained within an object or bounding box. Set obj to nil in order to use the bounding box. // Within returns all object that are fully contained within an object or bounding box. Set obj to nil in order to use the bounding box.
func (c *Collection) Within(cursor uint64, sparse uint8, obj geojson.Object, minLat, minLon, maxLat, maxLon, minZ, maxZ float64, iterator func(id string, obj geojson.Object, fields []float64) bool) (ncursor uint64) { func (c *Collection) Within(sparse uint8, obj geojson.Object, minLat, minLon, maxLat, maxLon, minZ, maxZ float64, iterator func(id string, obj geojson.Object, fields []float64) bool) bool {
var keepon = true
var bbox geojson.BBox var bbox geojson.BBox
if obj != nil { if obj != nil {
bbox = obj.CalculatedBBox() bbox = obj.CalculatedBBox()
@ -436,7 +421,7 @@ func (c *Collection) Within(cursor uint64, sparse uint8, obj geojson.Object, min
if sparse > 0 { if sparse > 0 {
for _, bbox := range bboxes { for _, bbox := range bboxes {
if obj != nil { if obj != nil {
c.geoSearch(cursor, bbox, func(id string, o geojson.Object, fields []float64) bool { keepon = c.geoSearch(bbox, func(id string, o geojson.Object, fields []float64) bool {
if o.Within(obj) { if o.Within(obj) {
if iterator(id, o, fields) { if iterator(id, o, fields) {
return false return false
@ -445,7 +430,8 @@ func (c *Collection) Within(cursor uint64, sparse uint8, obj geojson.Object, min
return true return true
}) })
} }
c.geoSearch(cursor, bbox, func(id string, o geojson.Object, fields []float64) bool { if keepon {
keepon = c.geoSearch(bbox, func(id string, o geojson.Object, fields []float64) bool {
if o.WithinBBox(bbox) { if o.WithinBBox(bbox) {
if iterator(id, o, fields) { if iterator(id, o, fields) {
return false return false
@ -454,17 +440,21 @@ func (c *Collection) Within(cursor uint64, sparse uint8, obj geojson.Object, min
return true return true
}) })
} }
return 0 if !keepon {
break
}
}
return keepon
} }
if obj != nil { if obj != nil {
return c.geoSearch(cursor, bbox, func(id string, o geojson.Object, fields []float64) bool { return c.geoSearch(bbox, func(id string, o geojson.Object, fields []float64) bool {
if o.Within(obj) { if o.Within(obj) {
return iterator(id, o, fields) return iterator(id, o, fields)
} }
return true return true
}) })
} }
return c.geoSearch(cursor, bbox, func(id string, o geojson.Object, fields []float64) bool { return c.geoSearch(bbox, func(id string, o geojson.Object, fields []float64) bool {
if o.WithinBBox(bbox) { if o.WithinBBox(bbox) {
return iterator(id, o, fields) return iterator(id, o, fields)
} }
@ -473,7 +463,8 @@ func (c *Collection) Within(cursor uint64, sparse uint8, obj geojson.Object, min
} }
// Intersects returns all object that are intersect an object or bounding box. Set obj to nil in order to use the bounding box. // Intersects returns all object that are intersect an object or bounding box. Set obj to nil in order to use the bounding box.
func (c *Collection) Intersects(cursor uint64, sparse uint8, obj geojson.Object, minLat, minLon, maxLat, maxLon, minZ, maxZ float64, iterator func(id string, obj geojson.Object, fields []float64) bool) (ncursor uint64) { func (c *Collection) Intersects(sparse uint8, obj geojson.Object, minLat, minLon, maxLat, maxLon, minZ, maxZ float64, iterator func(id string, obj geojson.Object, fields []float64) bool) bool {
var keepon = true
var bbox geojson.BBox var bbox geojson.BBox
if obj != nil { if obj != nil {
bbox = obj.CalculatedBBox() bbox = obj.CalculatedBBox()
@ -501,7 +492,7 @@ func (c *Collection) Intersects(cursor uint64, sparse uint8, obj geojson.Object,
} }
for _, bbox := range bboxes { for _, bbox := range bboxes {
if obj != nil { if obj != nil {
c.geoSearch(cursor, bbox, func(id string, o geojson.Object, fields []float64) bool { keepon = c.geoSearch(bbox, func(id string, o geojson.Object, fields []float64) bool {
if o.Intersects(obj) { if o.Intersects(obj) {
if iterator(id, o, fields) { if iterator(id, o, fields) {
return false return false
@ -510,7 +501,8 @@ func (c *Collection) Intersects(cursor uint64, sparse uint8, obj geojson.Object,
return true return true
}) })
} }
c.geoSearch(cursor, bbox, func(id string, o geojson.Object, fields []float64) bool { if keepon {
keepon = c.geoSearch(bbox, func(id string, o geojson.Object, fields []float64) bool {
if o.IntersectsBBox(bbox) { if o.IntersectsBBox(bbox) {
if iterator(id, o, fields) { if iterator(id, o, fields) {
return false return false
@ -519,17 +511,21 @@ func (c *Collection) Intersects(cursor uint64, sparse uint8, obj geojson.Object,
return true return true
}) })
} }
return 0 if !keepon {
break
}
}
return keepon
} }
if obj != nil { if obj != nil {
return c.geoSearch(cursor, bbox, func(id string, o geojson.Object, fields []float64) bool { return c.geoSearch(bbox, func(id string, o geojson.Object, fields []float64) bool {
if o.Intersects(obj) { if o.Intersects(obj) {
return iterator(id, o, fields) return iterator(id, o, fields)
} }
return true return true
}) })
} }
return c.geoSearch(cursor, bbox, func(id string, o geojson.Object, fields []float64) bool { return c.geoSearch(bbox, func(id string, o geojson.Object, fields []float64) bool {
if o.IntersectsBBox(bbox) { if o.IntersectsBBox(bbox) {
return iterator(id, o, fields) return iterator(id, o, fields)
} }
@ -537,8 +533,8 @@ func (c *Collection) Intersects(cursor uint64, sparse uint8, obj geojson.Object,
}) })
} }
func (c *Collection) NearestNeighbors(k int, lat, lon float64, iterator func(id string, obj geojson.Object, fields []float64) bool) { func (c *Collection) NearestNeighbors(lat, lon float64, iterator func(id string, obj geojson.Object, fields []float64) bool) bool {
c.index.NearestNeighbors(k, lat, lon, func(item index.Item) bool { return c.index.NearestNeighbors(lat, lon, func(item index.Item) bool {
var iitm *itemT var iitm *itemT
iitm, ok := item.(*itemT) iitm, ok := item.(*itemT)
if !ok { if !ok {

View File

@ -32,7 +32,7 @@ func TestCollection(t *testing.T) {
} }
count := 0 count := 0
bbox := geojson.BBox{Min: geojson.Position{X: -180, Y: -90, Z: 0}, Max: geojson.Position{X: 180, Y: 90, Z: 0}} bbox := geojson.BBox{Min: geojson.Position{X: -180, Y: -90, Z: 0}, Max: geojson.Position{X: 180, Y: 90, Z: 0}}
c.geoSearch(0, bbox, func(id string, obj geojson.Object, field []float64) bool { c.geoSearch(bbox, func(id string, obj geojson.Object, field []float64) bool {
count++ count++
return true return true
}) })
@ -84,7 +84,7 @@ func TestManyCollections(t *testing.T) {
col := colsM["13"] col := colsM["13"]
//println(col.Count()) //println(col.Count())
bbox := geojson.BBox{Min: geojson.Position{X: -180, Y: 30, Z: 0}, Max: geojson.Position{X: 34, Y: 100, Z: 0}} bbox := geojson.BBox{Min: geojson.Position{X: -180, Y: 30, Z: 0}, Max: geojson.Position{X: 34, Y: 100, Z: 0}}
col.geoSearch(0, bbox, func(id string, obj geojson.Object, fields []float64) bool { col.geoSearch(bbox, func(id string, obj geojson.Object, fields []float64) bool {
//println(id) //println(id)
return true return true
}) })

View File

@ -378,9 +378,9 @@ func (c *Controller) cmdPdel(msg *server.Message) (res string, d commandDetailsT
if col != nil { if col != nil {
g := glob.Parse(d.pattern, false) g := glob.Parse(d.pattern, false)
if g.Limits[0] == "" && g.Limits[1] == "" { if g.Limits[0] == "" && g.Limits[1] == "" {
col.Scan(0, false, iter) col.Scan(false, iter)
} else { } else {
col.ScanRange(0, g.Limits[0], g.Limits[1], false, iter) col.ScanRange(g.Limits[0], g.Limits[1], false, iter)
} }
var atLeastOneNotDeleted bool var atLeastOneNotDeleted bool
for i, dc := range d.children { for i, dc := range d.children {

View File

@ -261,9 +261,9 @@ func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas
} }
g := glob.Parse(pattern, false) g := glob.Parse(pattern, false)
if g.Limits[0] == "" && g.Limits[1] == "" { if g.Limits[0] == "" && g.Limits[1] == "" {
col.Scan(0, false, iterator) col.Scan(false, iterator)
} else { } else {
col.ScanRange(0, g.Limits[0], g.Limits[1], false, iterator) col.ScanRange(g.Limits[0], g.Limits[1], false, iterator)
} }
} }
}() }()
@ -332,7 +332,7 @@ func fenceMatchRoam(c *Controller, fence *liveFenceSwitches, tkey, tid string, o
return return
} }
p := obj.CalculatedPoint() p := obj.CalculatedPoint()
col.Nearby(0, 0, p.Y, p.X, fence.roam.meters, math.Inf(-1), math.Inf(+1), col.Nearby(0, p.Y, p.X, fence.roam.meters, math.Inf(-1), math.Inf(+1),
func(id string, obj geojson.Object, fields []float64) bool { func(id string, obj geojson.Object, fields []float64) bool {
var match bool var match bool
if id == tid { if id == tid {

View File

@ -133,7 +133,7 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res string, d commandDetai
hook.cond = sync.NewCond(&hook.mu) hook.cond = sync.NewCond(&hook.mu)
var wr bytes.Buffer var wr bytes.Buffer
hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.limit, s.wheres, s.nofields) hook.ScanWriter, err = c.newScanWriter(&wr, cmsg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields)
if err != nil { if err != nil {
return "", d, err return "", d, err
} }

View File

@ -87,7 +87,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWrit
lb.key = s.key lb.key = s.key
lb.fence = &s lb.fence = &s
c.mu.RLock() c.mu.RLock()
sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.limit, s.wheres, s.nofields) sw, err = c.newScanWriter(&wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields)
c.mu.RUnlock() c.mu.RUnlock()
} }
// everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS // everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS

View File

@ -30,7 +30,7 @@ func (c *Controller) cmdScan(msg *server.Message) (res string, err error) {
if err != nil { if err != nil {
return "", err return "", err
} }
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.limit, s.wheres, s.nofields) sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -48,7 +48,7 @@ func (c *Controller) cmdScan(msg *server.Message) (res string, err error) {
} else { } else {
g := glob.Parse(sw.globPattern, s.desc) g := glob.Parse(sw.globPattern, s.desc)
if g.Limits[0] == "" && g.Limits[1] == "" { if g.Limits[0] == "" && g.Limits[1] == "" {
s.cursor = sw.col.Scan(s.cursor, s.desc, sw.col.Scan(s.desc,
func(id string, o geojson.Object, fields []float64) bool { func(id string, o geojson.Object, fields []float64) bool {
return sw.writeObject(ScanWriterParams{ return sw.writeObject(ScanWriterParams{
id: id, id: id,
@ -58,8 +58,7 @@ func (c *Controller) cmdScan(msg *server.Message) (res string, err error) {
}, },
) )
} else { } else {
s.cursor = sw.col.ScanRange( sw.col.ScanRange(g.Limits[0], g.Limits[1], s.desc,
s.cursor, g.Limits[0], g.Limits[1], s.desc,
func(id string, o geojson.Object, fields []float64) bool { func(id string, o geojson.Object, fields []float64) bool {
return sw.writeObject(ScanWriterParams{ return sw.writeObject(ScanWriterParams{
id: id, id: id,
@ -71,7 +70,7 @@ func (c *Controller) cmdScan(msg *server.Message) (res string, err error) {
} }
} }
} }
sw.writeFoot(s.cursor) sw.writeFoot()
if msg.OutputType == server.JSON { if msg.OutputType == server.JSON {
wr.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}") wr.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}")
} }

View File

@ -41,6 +41,7 @@ type scanWriter struct {
wheres []whereT wheres []whereT
numberItems uint64 numberItems uint64
nofields bool nofields bool
cursor uint64
limit uint64 limit uint64
hitLimit bool hitLimit bool
once bool once bool
@ -65,7 +66,7 @@ type ScanWriterParams struct {
func (c *Controller) newScanWriter( func (c *Controller) newScanWriter(
wr *bytes.Buffer, msg *server.Message, key string, output outputT, wr *bytes.Buffer, msg *server.Message, key string, output outputT,
precision uint64, globPattern string, matchValues bool, precision uint64, globPattern string, matchValues bool,
limit uint64, wheres []whereT, nofields bool, cursor, limit uint64, wheres []whereT, nofields bool,
) ( ) (
*scanWriter, error, *scanWriter, error,
) { ) {
@ -83,6 +84,7 @@ func (c *Controller) newScanWriter(
c: c, c: c,
wr: wr, wr: wr,
msg: msg, msg: msg,
cursor: cursor,
limit: limit, limit: limit,
wheres: wheres, wheres: wheres,
output: output, output: output,
@ -149,9 +151,10 @@ func (sw *scanWriter) writeHead() {
} }
} }
func (sw *scanWriter) writeFoot(cursor uint64) { func (sw *scanWriter) writeFoot() {
sw.mu.Lock() sw.mu.Lock()
defer sw.mu.Unlock() defer sw.mu.Unlock()
cursor := sw.numberItems + sw.cursor
if !sw.hitLimit { if !sw.hitLimit {
cursor = 0 cursor = 0
} }
@ -275,7 +278,6 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool {
if sw.output == outputCount { if sw.output == outputCount {
return true return true
} }
switch sw.msg.OutputType { switch sw.msg.OutputType {
case server.JSON: case server.JSON:
var wr bytes.Buffer var wr bytes.Buffer

View File

@ -296,7 +296,7 @@ func (c *Controller) cmdNearby(msg *server.Message) (res string, err error) {
return "", s return "", s
} }
minZ, maxZ := zMinMaxFromWheres(s.wheres) minZ, maxZ := zMinMaxFromWheres(s.wheres)
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.limit, s.wheres, s.nofields) sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -322,12 +322,12 @@ func (c *Controller) cmdNearby(msg *server.Message) (res string, err error) {
}) })
} }
if s.knn { if s.knn {
sw.col.NearestNeighbors(int(s.limit), s.lat, s.lon, iter) sw.col.NearestNeighbors(s.lat, s.lon, iter)
} else { } else {
s.cursor = sw.col.Nearby(s.cursor, s.sparse, s.lat, s.lon, s.meters, minZ, maxZ, iter) sw.col.Nearby(s.sparse, s.lat, s.lon, s.meters, minZ, maxZ, iter)
} }
} }
sw.writeFoot(s.cursor) sw.writeFoot()
if msg.OutputType == server.JSON { if msg.OutputType == server.JSON {
wr.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}") wr.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}")
} }
@ -355,7 +355,7 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
if s.fence { if s.fence {
return "", s return "", s
} }
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.limit, s.wheres, s.nofields) sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, false, s.cursor, s.limit, s.wheres, s.nofields)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -368,7 +368,7 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
sw.writeHead() sw.writeHead()
minZ, maxZ := zMinMaxFromWheres(s.wheres) minZ, maxZ := zMinMaxFromWheres(s.wheres)
if cmd == "within" { if cmd == "within" {
s.cursor = sw.col.Within(s.cursor, s.sparse, s.o, s.minLat, s.minLon, s.maxLat, s.maxLon, minZ, maxZ, sw.col.Within(s.sparse, s.o, s.minLat, s.minLon, s.maxLat, s.maxLon, minZ, maxZ,
func(id string, o geojson.Object, fields []float64) bool { func(id string, o geojson.Object, fields []float64) bool {
if c.hasExpired(s.key, id) { if c.hasExpired(s.key, id) {
return true return true
@ -381,7 +381,7 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
}, },
) )
} else if cmd == "intersects" { } else if cmd == "intersects" {
s.cursor = sw.col.Intersects(s.cursor, s.sparse, s.o, s.minLat, s.minLon, s.maxLat, s.maxLon, minZ, maxZ, sw.col.Intersects(s.sparse, s.o, s.minLat, s.minLon, s.maxLat, s.maxLon, minZ, maxZ,
func(id string, o geojson.Object, fields []float64) bool { func(id string, o geojson.Object, fields []float64) bool {
if c.hasExpired(s.key, id) { if c.hasExpired(s.key, id) {
return true return true
@ -394,7 +394,7 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res
}, },
) )
} }
sw.writeFoot(s.cursor) sw.writeFoot()
if msg.OutputType == server.JSON { if msg.OutputType == server.JSON {
wr.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}") wr.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}")
} }
@ -421,7 +421,7 @@ func (c *Controller) cmdSearch(msg *server.Message) (res string, err error) {
if err != nil { if err != nil {
return "", err return "", err
} }
sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.limit, s.wheres, s.nofields) sw, err := c.newScanWriter(wr, msg, s.key, s.output, s.precision, s.glob, true, s.cursor, s.limit, s.wheres, s.nofields)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -439,7 +439,7 @@ func (c *Controller) cmdSearch(msg *server.Message) (res string, err error) {
} else { } else {
g := glob.Parse(sw.globPattern, s.desc) g := glob.Parse(sw.globPattern, s.desc)
if g.Limits[0] == "" && g.Limits[1] == "" { if g.Limits[0] == "" && g.Limits[1] == "" {
s.cursor = sw.col.SearchValues(s.cursor, s.desc, sw.col.SearchValues(s.desc,
func(id string, o geojson.Object, fields []float64) bool { func(id string, o geojson.Object, fields []float64) bool {
return sw.writeObject(ScanWriterParams{ return sw.writeObject(ScanWriterParams{
id: id, id: id,
@ -452,8 +452,7 @@ func (c *Controller) cmdSearch(msg *server.Message) (res string, err error) {
// must disable globSingle for string value type matching because // must disable globSingle for string value type matching because
// globSingle is only for ID matches, not values. // globSingle is only for ID matches, not values.
sw.globSingle = false sw.globSingle = false
s.cursor = sw.col.SearchValuesRange( sw.col.SearchValuesRange(g.Limits[0], g.Limits[1], s.desc,
s.cursor, g.Limits[0], g.Limits[1], s.desc,
func(id string, o geojson.Object, fields []float64) bool { func(id string, o geojson.Object, fields []float64) bool {
return sw.writeObject(ScanWriterParams{ return sw.writeObject(ScanWriterParams{
id: id, id: id,
@ -465,7 +464,7 @@ func (c *Controller) cmdSearch(msg *server.Message) (res string, err error) {
} }
} }
} }
sw.writeFoot(s.cursor) sw.writeFoot()
if msg.OutputType == server.JSON { if msg.OutputType == server.JSON {
wr.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}") wr.WriteString(`,"elapsed":"` + time.Now().Sub(start).String() + "\"}")
} }

View File

@ -96,7 +96,7 @@ func (ix *Index) Remove(item Item) {
// Count counts all items in the index. // Count counts all items in the index.
func (ix *Index) Count() int { func (ix *Index) Count() int {
count := 0 count := 0
ix.Search(0, -90, -180, 90, 180, math.Inf(-1), math.Inf(+1), func(item Item) bool { ix.Search(-90, -180, 90, 180, math.Inf(-1), math.Inf(+1), func(item Item) bool {
count++ count++
return true return true
}) })
@ -123,72 +123,62 @@ func (ix *Index) getRTreeItem(item rtree.Item) Item {
return nil return nil
} }
func (ix *Index) NearestNeighbors(k int, lat, lon float64, iterator func(item Item) bool) { func (ix *Index) NearestNeighbors(lat, lon float64, iterator func(item Item) bool) bool {
x, y, _ := normPoint(lat, lon) x, y, _ := normPoint(lat, lon)
items := ix.r.NearestNeighbors(k, x, y, 0) return ix.r.NearestNeighbors(x, y, 0, func(item rtree.Item, dist float64) bool {
for _, item := range items {
iitm := ix.getRTreeItem(item) iitm := ix.getRTreeItem(item)
if item == nil { if item == nil {
continue return true
}
if !iterator(iitm) {
break
}
} }
return iterator(iitm)
})
} }
// Search returns all items that intersect the bounding box. // Search returns all items that intersect the bounding box.
func (ix *Index) Search(cursor uint64, swLat, swLon, neLat, neLon, minZ, maxZ float64, iterator func(item Item) bool) (ncursor uint64) { func (ix *Index) Search(swLat, swLon, neLat, neLon, minZ, maxZ float64, iterator func(item Item) bool) bool {
var idx uint64 var keepon = true
var active = true
var idm = make(map[Item]bool) var idm = make(map[Item]bool)
mins, maxs, _ := normRect(swLat, swLon, neLat, neLon) mins, maxs, _ := normRect(swLat, swLon, neLat, neLon)
// Points // Points
if len(mins) == 1 { if len(mins) == 1 {
// There is only one rectangle. // There is only one rectangle.
// It's possible that a r rect may span multiple entries. Check mulm map for spanning rects. // It's possible that a r rect may span multiple entries. Check mulm map for spanning rects.
if active { if keepon {
ix.r.Search(mins[0][0], mins[0][1], minZ, maxs[0][0], maxs[0][1], maxZ, func(item rtree.Item) bool { ix.r.Search(mins[0][0], mins[0][1], minZ, maxs[0][0], maxs[0][1], maxZ, func(item rtree.Item) bool {
if idx >= cursor {
iitm := ix.getRTreeItem(item) iitm := ix.getRTreeItem(item)
if iitm != nil { if iitm != nil {
if ix.mulm[iitm] { if ix.mulm[iitm] {
if !idm[iitm] { if !idm[iitm] {
idm[iitm] = true idm[iitm] = true
active = iterator(iitm) keepon = iterator(iitm)
} }
} else { } else {
active = iterator(iitm) keepon = iterator(iitm)
} }
} }
} return keepon
idx++
return active
}) })
} }
} else { } else {
// There are multiple rectangles. Duplicates might occur. // There are multiple rectangles. Duplicates might occur.
for i := range mins { for i := range mins {
if active { if keepon {
ix.r.Search(mins[i][0], mins[i][1], minZ, maxs[i][0], maxs[i][1], maxZ, func(item rtree.Item) bool { ix.r.Search(mins[i][0], mins[i][1], minZ, maxs[i][0], maxs[i][1], maxZ, func(item rtree.Item) bool {
if idx >= cursor {
iitm := ix.getRTreeItem(item) iitm := ix.getRTreeItem(item)
if iitm != nil { if iitm != nil {
if ix.mulm[iitm] { if ix.mulm[iitm] {
if !idm[iitm] { if !idm[iitm] {
idm[iitm] = true idm[iitm] = true
active = iterator(iitm) keepon = iterator(iitm)
} }
} else { } else {
active = iterator(iitm) keepon = iterator(iitm)
} }
} }
} return keepon
idx++
return active
}) })
} }
} }
} }
return idx return keepon
} }

View File

@ -60,7 +60,7 @@ func TestRandomInserts(t *testing.T) {
} }
count = 0 count = 0
items := make([]Item, 0, l) items := make([]Item, 0, l)
tr.Search(0, -90, -180, 90, 180, 0, 0, func(item Item) bool { tr.Search(-90, -180, 90, 180, 0, 0, func(item Item) bool {
count++ count++
items = append(items, item) items = append(items, item)
return true return true
@ -70,7 +70,7 @@ func TestRandomInserts(t *testing.T) {
} }
start = time.Now() start = time.Now()
count1 := 0 count1 := 0
tr.Search(0, 33, -115, 34, -114, 0, 0, func(item Item) bool { tr.Search(33, -115, 34, -114, 0, 0, func(item Item) bool {
count1++ count1++
return true return true
}) })
@ -79,7 +79,7 @@ func TestRandomInserts(t *testing.T) {
start = time.Now() start = time.Now()
count2 := 0 count2 := 0
tr.Search(0, 33-180, -115-360, 34-180, -114-360, 0, 0, func(item Item) bool { tr.Search(33-180, -115-360, 34-180, -114-360, 0, 0, func(item Item) bool {
count2++ count2++
return true return true
}) })
@ -87,7 +87,7 @@ func TestRandomInserts(t *testing.T) {
start = time.Now() start = time.Now()
count3 := 0 count3 := 0
tr.Search(0, -10, 170, 20, 200, 0, 0, func(item Item) bool { tr.Search(-10, 170, 20, 200, 0, 0, func(item Item) bool {
count3++ count3++
return true return true
}) })
@ -99,7 +99,7 @@ func TestRandomInserts(t *testing.T) {
fmt.Printf("Searched %d items in %s.\n", count2, searchdur2.String()) fmt.Printf("Searched %d items in %s.\n", count2, searchdur2.String())
fmt.Printf("Searched %d items in %s.\n", count3, searchdur3.String()) fmt.Printf("Searched %d items in %s.\n", count3, searchdur3.String())
tr.Search(0, -10, 170, 20, 200, 0, 0, func(item Item) bool { tr.Search(-10, 170, 20, 200, 0, 0, func(item Item) bool {
lat1, lon1, _, lat2, lon2, _ := item.Rect() lat1, lon1, _, lat2, lon2, _ := item.Rect()
if lat1 == lat2 && lon1 == lon2 { if lat1 == lat2 && lon1 == lon2 {
return false return false
@ -107,7 +107,7 @@ func TestRandomInserts(t *testing.T) {
return true return true
}) })
tr.Search(0, -10, 170, 20, 200, 0, 0, func(item Item) bool { tr.Search(-10, 170, 20, 200, 0, 0, func(item Item) bool {
lat1, lon1, _, lat2, lon2, _ := item.Rect() lat1, lon1, _, lat2, lon2, _ := item.Rect()
if lat1 != lat2 || lon1 != lon2 { if lat1 != lat2 || lon1 != lon2 {
return false return false
@ -173,7 +173,7 @@ func TestInsertVarious(t *testing.T) {
t.Fatalf("count = %d, expect 1", count) t.Fatalf("count = %d, expect 1", count)
} }
found := false found := false
tr.Search(0, -90, -180, 90, 180, 0, 0, func(item2 Item) bool { tr.Search(-90, -180, 90, 180, 0, 0, func(item2 Item) bool {
if item2 == item { if item2 == item {
found = true found = true
} }

View File

@ -1,205 +1,60 @@
// Much of the KNN code has been adapted from the
// github.com/dhconnelly/rtreego project.
//
// Copyright 2012 Daniel Connelly. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rtree package rtree
import ( import (
"math" "github.com/tidwall/tinyqueue"
"sort"
) )
type queueItem struct {
node *d3nodeT
data interface{}
isItem bool
dist float64
}
func (item *queueItem) Less(b tinyqueue.Item) bool {
return item.dist < b.(*queueItem).dist
}
func boxDistPoint(point []float64, childBox d3rectT) float64 {
var dist float64
for i := 0; i < len(point); i++ {
d := axisDist(point[i], float64(childBox.min[i]), float64(childBox.max[i]))
dist += d * d
}
return dist
}
func axisDist(k, min, max float64) float64 {
if k < min {
return min - k
}
if k <= max {
return 0
}
return k - max
}
// NearestNeighbors gets the closest Spatials to the Point. // NearestNeighbors gets the closest Spatials to the Point.
func (tr *RTree) NearestNeighbors(k int, x, y, z float64) []Item { func (tr *RTree) NearestNeighbors(x, y, z float64, iter func(item Item, dist float64) bool) bool {
if tr.tr.root == nil { knnPoint := []float64{x, y, z}
return nil queue := tinyqueue.New(nil)
node := tr.tr.root
for node != nil {
for i := 0; i < node.count; i++ {
child := node.branch[i]
dist := boxDistPoint(knnPoint, node.branch[i].rect)
queue.Push(&queueItem{node: child.child, data: child.data, isItem: node.isLeaf(), dist: dist})
} }
dists := make([]float64, k) for queue.Len() > 0 && queue.Peek().(*queueItem).isItem {
objs := make([]Item, k) item := queue.Pop().(*queueItem)
for i := 0; i < k; i++ { if !iter(item.data.(Item), item.dist) {
dists[i] = math.MaxFloat64 return false
objs[i] = nil
}
objs, _ = tr.nearestNeighbors(k, x, y, z, tr.tr.root, dists, objs)
//for i := 0; i < len(objs); i++ {
// fmt.Printf("%v\n", objs[i])
//}
for i := 0; i < len(objs); i++ {
if objs[i] == nil {
return objs[:i]
} }
} }
return objs last := queue.Pop()
} if last != nil {
node = last.(*queueItem).node
// minDist computes the square of the distance from a point to a rectangle.
// If the point is contained in the rectangle then the distance is zero.
//
// Implemented per Definition 2 of "Nearest Neighbor Queries" by
// N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995.
func minDist(x, y, z float64, r d3rectT) float64 {
sum := 0.0
p := [3]float64{x, y, z}
rp := [3]float64{
float64(r.min[0]), float64(r.min[1]), float64(r.min[2]),
}
rq := [3]float64{
float64(r.max[0]), float64(r.max[1]), float64(r.max[2]),
}
for i := 0; i < 3; i++ {
if p[i] < float64(rp[i]) {
d := p[i] - float64(rp[i])
sum += d * d
} else if p[i] > float64(rq[i]) {
d := p[i] - float64(rq[i])
sum += d * d
}
}
return sum
}
func (tr *RTree) nearestNeighbors(k int, x, y, z float64, n *d3nodeT, dists []float64, nearest []Item) ([]Item, []float64) {
if n.isLeaf() {
for i := 0; i < n.count; i++ {
e := n.branch[i]
dist := math.Sqrt(minDist(x, y, z, e.rect))
dists, nearest = insertNearest(k, dists, nearest, dist, e.data.(Item))
}
} else { } else {
branches, branchDists := sortEntries(x, y, z, n.branch[:n.count]) node = nil
branches = pruneEntries(x, y, z, branches, branchDists)
for _, e := range branches {
nearest, dists = tr.nearestNeighbors(k, x, y, z, e.child, dists, nearest)
} }
} }
return nearest, dists return true
}
// insert obj into nearest and return the first k elements in increasing order.
func insertNearest(k int, dists []float64, nearest []Item, dist float64, obj Item) ([]float64, []Item) {
i := 0
for i < k && dist >= dists[i] {
i++
}
if i >= k {
return dists, nearest
}
left, right := dists[:i], dists[i:k-1]
updatedDists := make([]float64, k)
copy(updatedDists, left)
updatedDists[i] = dist
copy(updatedDists[i+1:], right)
leftObjs, rightObjs := nearest[:i], nearest[i:k-1]
updatedNearest := make([]Item, k)
copy(updatedNearest, leftObjs)
updatedNearest[i] = obj
copy(updatedNearest[i+1:], rightObjs)
return updatedDists, updatedNearest
}
type entrySlice struct {
entries []d3branchT
dists []float64
x, y, z float64
}
func (s entrySlice) Len() int { return len(s.entries) }
func (s entrySlice) Swap(i, j int) {
s.entries[i], s.entries[j] = s.entries[j], s.entries[i]
s.dists[i], s.dists[j] = s.dists[j], s.dists[i]
}
func (s entrySlice) Less(i, j int) bool {
return s.dists[i] < s.dists[j]
}
func sortEntries(x, y, z float64, entries []d3branchT) ([]d3branchT, []float64) {
sorted := make([]d3branchT, len(entries))
dists := make([]float64, len(entries))
for i := 0; i < len(entries); i++ {
sorted[i] = entries[i]
dists[i] = minDist(x, y, z, entries[i].rect)
}
sort.Sort(entrySlice{sorted, dists, x, y, z})
return sorted, dists
}
func pruneEntries(x, y, z float64, entries []d3branchT, minDists []float64) []d3branchT {
minMinMaxDist := math.MaxFloat64
for i := range entries {
minMaxDist := minMaxDist(x, y, z, entries[i].rect)
if minMaxDist < minMinMaxDist {
minMinMaxDist = minMaxDist
}
}
// remove all entries with minDist > minMinMaxDist
pruned := []d3branchT{}
for i := range entries {
if minDists[i] <= minMinMaxDist {
pruned = append(pruned, entries[i])
}
}
return pruned
}
// minMaxDist computes the minimum of the maximum distances from p to points
// on r. If r is the bounding box of some geometric objects, then there is
// at least one object contained in r within minMaxDist(p, r) of p.
//
// Implemented per Definition 4 of "Nearest Neighbor Queries" by
// N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995.
func minMaxDist(x, y, z float64, r d3rectT) float64 {
p := [3]float64{x, y, z}
rp := [3]float64{
float64(r.min[0]), float64(r.min[1]), float64(r.min[2]),
}
rq := [3]float64{
float64(r.max[0]), float64(r.max[1]), float64(r.max[2]),
}
// by definition, MinMaxDist(p, r) =
// min{1<=k<=n}(|pk - rmk|^2 + sum{1<=i<=n, i != k}(|pi - rMi|^2))
// where rmk and rMk are defined as follows:
rm := func(k int) float64 {
if p[k] <= (rp[k]+rq[k])/2 {
return rp[k]
}
return rq[k]
}
rM := func(k int) float64 {
if p[k] >= (rp[k]+rq[k])/2 {
return rp[k]
}
return rq[k]
}
// This formula can be computed in linear time by precomputing
// S = sum{1<=i<=n}(|pi - rMi|^2).
S := 0.0
for i := range p {
d := p[i] - rM(i)
S += d * d
}
// Compute MinMaxDist using the precomputed S.
min := math.MaxFloat64
for k := range p {
d1 := p[k] - rM(k)
d2 := p[k] - rm(k)
d := S - d1*d1 + d2*d2
if d < min {
min = d
}
}
return min
} }

View File

@ -96,7 +96,11 @@ func TestKNN(t *testing.T) {
tr.Insert(wpp(12, 19, 0)) tr.Insert(wpp(12, 19, 0))
tr.Insert(wpp(-5, 5, 0)) tr.Insert(wpp(-5, 5, 0))
tr.Insert(wpp(33, 21, 0)) tr.Insert(wpp(33, 21, 0))
items := tr.NearestNeighbors(10, x, y, z) var items []Item
tr.NearestNeighbors(x, y, z, func(item Item, dist float64) bool {
items = append(items, item)
return true
})
var res string var res string
for i, item := range items { for i, item := range items {
ix, iy, _, _, _, _ := item.Rect() ix, iy, _, _, _, _ := item.Rect()

View File

@ -2,9 +2,9 @@ package rtree
import "math" import "math"
type float float32 type float float64
const d3roundValues = true // only set to true when using 32-bit floats const d3roundValues = false // only set to true when using 32-bit floats
func d3fmin(a, b float) float { func d3fmin(a, b float) float {
if a < b { if a < b {

15
vendor/github.com/tidwall/tinyqueue/LICENSE generated vendored Normal file
View File

@ -0,0 +1,15 @@
ISC License
Copyright (c) 2017, Vladimir Agafonkin
Permission to use, copy, modify, and/or distribute this software for any purpose
with or without fee is hereby granted, provided that the above copyright notice
and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER
TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF
THIS SOFTWARE.

7
vendor/github.com/tidwall/tinyqueue/README.md generated vendored Normal file
View File

@ -0,0 +1,7 @@
# tinyqueue
<a href="https://godoc.org/github.com/tidwall/tinyqueue"><img src="https://img.shields.io/badge/api-reference-blue.svg?style=flat-square" alt="GoDoc"></a>
tinyqueue is a Go package for binary heap priority queues.
Ported from the [tinyqueue](https://github.com/mourner/tinyqueue) Javascript library.

86
vendor/github.com/tidwall/tinyqueue/tinyqueue.go generated vendored Normal file
View File

@ -0,0 +1,86 @@
package tinyqueue
type Queue struct {
length int
data []Item
}
type Item interface {
Less(Item) bool
}
func New(data []Item) *Queue {
q := &Queue{}
q.data = data
q.length = len(data)
if q.length > 0 {
i := q.length >> 1
for ; i >= 0; i-- {
q.down(i)
}
}
return q
}
func (q *Queue) Push(item Item) {
q.data = append(q.data, item)
q.length++
q.up(q.length - 1)
}
func (q *Queue) Pop() Item {
if q.length == 0 {
return nil
}
top := q.data[0]
q.length--
if q.length > 0 {
q.data[0] = q.data[q.length]
q.down(0)
}
q.data = q.data[:len(q.data)-1]
return top
}
func (q *Queue) Peek() Item {
if q.length == 0 {
return nil
}
return q.data[0]
}
func (q *Queue) Len() int {
return q.length
}
func (q *Queue) down(pos int) {
data := q.data
halfLength := q.length >> 1
item := data[pos]
for pos < halfLength {
left := (pos << 1) + 1
right := left + 1
best := data[left]
if right < q.length && data[right].Less(best) {
left = right
best = data[right]
}
if !best.Less(item) {
break
}
data[pos] = best
pos = left
}
data[pos] = item
}
func (q *Queue) up(pos int) {
data := q.data
item := data[pos]
for pos > 0 {
parent := (pos - 1) >> 1
current := data[parent]
if !item.Less(current) {
break
}
data[pos] = current
pos = parent
}
data[pos] = item
}

65
vendor/github.com/tidwall/tinyqueue/tinyqueue_test.go generated vendored Normal file
View File

@ -0,0 +1,65 @@
package tinyqueue
import (
"math/rand"
"sort"
"testing"
"time"
"github.com/json-iterator/go/assert"
)
type floatValue float64
func (a floatValue) Less(b Item) bool {
return a < b.(floatValue)
}
var data, sorted = func() ([]Item, []Item) {
rand.Seed(time.Now().UnixNano())
var data []Item
for i := 0; i < 100; i++ {
data = append(data, floatValue(rand.Float64()*100))
}
sorted := make([]Item, len(data))
copy(sorted, data)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Less(sorted[j])
})
return data, sorted
}()
func TestMaintainsPriorityQueue(t *testing.T) {
q := New(nil)
for i := 0; i < len(data); i++ {
q.Push(data[i])
}
assert.Equal(t, q.Peek(), sorted[0])
var result []Item
for q.length > 0 {
result = append(result, q.Pop())
}
assert.Equal(t, result, sorted)
}
func TestAcceptsDataInConstructor(t *testing.T) {
q := New(data)
var result []Item
for q.length > 0 {
result = append(result, q.Pop())
}
assert.Equal(t, result, sorted)
}
func TestHandlesEdgeCasesWithFewElements(t *testing.T) {
q := New(nil)
q.Push(floatValue(2))
q.Push(floatValue(1))
q.Pop()
q.Pop()
q.Pop()
q.Push(floatValue(2))
q.Push(floatValue(1))
assert.Equal(t, float64(q.Pop().(floatValue)), 1.0)
assert.Equal(t, float64(q.Pop().(floatValue)), 2.0)
assert.Equal(t, q.Pop(), nil)
}