Improve kNN behavior

The current KNN implementation has two areas that can be improved:

- The current behavior is somewhat incorrect. When performing a kNN
query, the current code fetches k items from the index, and then sorts
these items according to Haversine distance. The problem with this
approach is that since the items fetched from the index are ordered by
a Euclidean metric, there is no guarantee that item k + 1 is not closer
than item k in great circle distance, and hence incorrect results can be
returned when closer items beyond k exist.

- The secondary sort is a performance killer. This requires buffering
all k items (again...they were already run through a priority queue in)
the index, and then a sort. Since the items are mostly sorted, and
Go's sort implementation is a quickSort this is the worst case for the
sort algorithm.

Both of these can be fixed by applying a proper distance metric in
the index nearby operation. In addition, this cleans up the code
considerably, removing a number of special cases that applied only
to NEARBY operations.

This change implements a geodetic distance metric that ensures that
the order from the index is correct, eliminating the need for the
secondary sort and special filtering cases in the ScanWriter code.
This commit is contained in:
Mike Poindexter 2020-04-07 20:10:58 -07:00
parent f02dee3db2
commit 2a4272c95f
4 changed files with 161 additions and 84 deletions

View File

@ -1,6 +1,7 @@
package collection package collection
import ( import (
"math"
"runtime" "runtime"
"github.com/tidwall/btree" "github.com/tidwall/btree"
@ -697,7 +698,7 @@ func (c *Collection) Nearby(
target geojson.Object, target geojson.Object,
cursor Cursor, cursor Cursor,
deadline *deadline.Deadline, deadline *deadline.Deadline,
iter func(id string, obj geojson.Object, fields []float64) bool, iter func(id string, obj geojson.Object, fields []float64, dist float64) bool,
) bool { ) bool {
// First look to see if there's at least one candidate in the circle's // First look to see if there's at least one candidate in the circle's
// outer rectangle. This is a fast-fail operation. // outer rectangle. This is a fast-fail operation.
@ -732,18 +733,15 @@ func (c *Collection) Nearby(
cursor.Step(offset) cursor.Step(offset)
} }
c.index.Nearby( c.index.Nearby(
geoindex.SimpleBoxAlgo( geodeticDistAlgo([2]float64{center.X, center.Y}),
[2]float64{center.X, center.Y}, func(_, _ [2]float64, itemv interface{}, dist float64) bool {
[2]float64{center.X, center.Y},
),
func(_, _ [2]float64, itemv interface{}, _ float64) bool {
count++ count++
if count <= offset { if count <= offset {
return true return true
} }
nextStep(count, cursor, deadline) nextStep(count, cursor, deadline)
item := itemv.(*itemT) item := itemv.(*itemT)
alive = iter(item.id, item.obj, c.getFieldValues(item.id)) alive = iter(item.id, item.obj, c.getFieldValues(item.id), dist)
return alive return alive
}, },
) )
@ -759,3 +757,121 @@ func nextStep(step uint64, cursor Cursor, deadline *deadline.Deadline) {
cursor.Step(1) cursor.Step(1)
} }
} }
func geodeticDistAlgo(center [2]float64) func(
min, max [2]float64, data interface{}, item bool,
add func(min, max [2]float64, data interface{}, item bool, dist float64),
) {
const earthRadius = 6371e3
return func(
min, max [2]float64, data interface{}, item bool,
add func(min, max [2]float64, data interface{}, item bool, dist float64),
) {
add(min, max, data, item, earthRadius*pointRectDistGeodeticDeg(
center[1], center[0],
min[1], min[0],
max[1], max[0],
))
}
}
func pointRectDistGeodeticDeg(pLat, pLng, minLat, minLng, maxLat, maxLng float64) float64 {
result := pointRectDistGeodeticRad(
pLat*math.Pi/180, pLng*math.Pi/180,
minLat*math.Pi/180, minLng*math.Pi/180,
maxLat*math.Pi/180, maxLng*math.Pi/180,
)
return result
}
func pointRectDistGeodeticRad(φq, λq, φl, λl, φh, λh float64) float64 {
// Algorithm from:
// Schubert, E., Zimek, A., & Kriegel, H.-P. (2013).
// Geodetic Distance Queries on R-Trees for Indexing Geographic Data.
// Lecture Notes in Computer Science, 146164.
// doi:10.1007/978-3-642-40235-7_9
const (
twoΠ = 2 * math.Pi
halfΠ = math.Pi / 2
)
// distance on the unit sphere computed using Haversine formula
distRad := func(φa, λa, φb, λb float64) float64 {
if φa == φb && λa == λb {
return 0
}
Δφ := φa - φb
Δλ := λa - λb
sinΔφ := math.Sin(Δφ / 2)
sinΔλ := math.Sin(Δλ / 2)
cosφa := math.Cos(φa)
cosφb := math.Cos(φb)
return 2 * math.Asin(math.Sqrt(sinΔφ*sinΔφ+sinΔλ*sinΔλ*cosφa*cosφb))
}
// Simple case, point or invalid rect
if φl >= φh && λl >= λh {
return distRad(φl, λl, φq, λq)
}
if λl <= λq && λq <= λh { // q is north or south of r
if φl <= φq && φq <= φh { // Inside
return 0
}
if φq < φl { // South
return φl - φq
}
return φq - φh // North
}
// determine if q is closer to the east or west edge of r to select edge for
// tests below
Δλe := λl - λq
Δλw := λq - λh
if Δλe < 0 {
Δλe += twoΠ
}
if Δλw < 0 {
Δλw += twoΠ
}
var Δλ float64 // distance to closest edge
var λedge float64 // longitude of closest edge
if Δλe <= Δλw {
Δλ = Δλe
λedge = λl
} else {
Δλ = Δλw
λedge = λh
}
sinΔλ, cosΔλ := math.Sincos(Δλ)
tanφq := math.Tan(φq)
if Δλ >= halfΠ {
// If Δλ > 90 degrees (1/2 pi in radians) we're in one of the corners
// (NW/SW or NE/SE depending on the edge selected). Compare against the
// center line to decide which case we fall into
φmid := (φh + φl) / 2
if tanφq >= math.Tan(φmid)*cosΔλ {
return distRad(φq, λq, φh, λedge) // North corner
}
return distRad(φq, λq, φl, λedge) // South corner
}
if tanφq >= math.Tan(φh)*cosΔλ {
return distRad(φq, λq, φh, λedge) // North corner
}
if tanφq <= math.Tan(φl)*cosΔλ {
return distRad(φq, λq, φl, λedge) // South corner
}
// We're to the East or West of the rect, compute distance using cross-track
// Note that this is a simplification of the cross track distance formula
// valid since the track in question is a meridian.
return math.Asin(math.Cos(φq) * sinΔλ)
}

View File

@ -500,15 +500,22 @@ func TestSpatialSearch(t *testing.T) {
var items []geojson.Object var items []geojson.Object
exitems := []geojson.Object{ exitems := []geojson.Object{
r2, p1, p4, r1, p3, r3, p2, r2, p4, p1, r1, r3, p3, p2,
} }
lastDist := float64(-1)
distsMonotonic := true
c.Nearby(q4, nil, nil, c.Nearby(q4, nil, nil,
func(id string, obj geojson.Object, fields []float64) bool { func(id string, obj geojson.Object, fields []float64, dist float64) bool {
if dist < lastDist {
distsMonotonic = false
}
items = append(items, obj) items = append(items, obj)
return true return true
}, },
) )
expect(t, len(items) == 7) expect(t, len(items) == 7)
expect(t, distsMonotonic)
expect(t, reflect.DeepEqual(items, exitems)) expect(t, reflect.DeepEqual(items, exitems))
} }

View File

@ -62,14 +62,12 @@ type scanWriter struct {
// ScanWriterParams ... // ScanWriterParams ...
type ScanWriterParams struct { type ScanWriterParams struct {
id string id string
o geojson.Object o geojson.Object
fields []float64 fields []float64
distance float64 distance float64
noLock bool noLock bool
ignoreGlobMatch bool clip geojson.Object
clip geojson.Object
skipTesting bool
} }
func (s *Server) newScanWriter( func (s *Server) newScanWriter(
@ -337,13 +335,11 @@ func (sw *scanWriter) Step(n uint64) {
// ok is whether the object passes the test and should be written // ok is whether the object passes the test and should be written
// keepGoing is whether there could be more objects to test // keepGoing is whether there could be more objects to test
func (sw *scanWriter) testObject(id string, o geojson.Object, fields []float64, ignoreGlobMatch bool) ( func (sw *scanWriter) testObject(id string, o geojson.Object, fields []float64) (
ok, keepGoing bool, fieldVals []float64) { ok, keepGoing bool, fieldVals []float64) {
if !ignoreGlobMatch { match, kg := sw.globMatch(id, o)
match, kg := sw.globMatch(id, o) if !match {
if !match { return false, kg, fieldVals
return false, kg, fieldVals
}
} }
nf, ok := sw.fieldMatch(fields, o) nf, ok := sw.fieldMatch(fields, o)
return ok, true, nf return ok, true, nf
@ -355,13 +351,9 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool {
sw.mu.Lock() sw.mu.Lock()
defer sw.mu.Unlock() defer sw.mu.Unlock()
} }
var ok bool ok, keepGoing, _ := sw.testObject(opts.id, opts.o, opts.fields)
keepGoing := true if !ok {
if !opts.skipTesting { return keepGoing
ok, keepGoing, _ = sw.testObject(opts.id, opts.o, opts.fields, opts.ignoreGlobMatch)
if !ok {
return keepGoing
}
} }
sw.count++ sw.count++
if sw.output == outputCount { if sw.output == outputCount {

View File

@ -3,18 +3,15 @@ package server
import ( import (
"bytes" "bytes"
"errors" "errors"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/mmcloughlin/geohash" "github.com/mmcloughlin/geohash"
"github.com/tidwall/geojson" "github.com/tidwall/geojson"
"github.com/tidwall/geojson/geo"
"github.com/tidwall/geojson/geometry" "github.com/tidwall/geojson/geometry"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"github.com/tidwall/tile38/internal/bing" "github.com/tidwall/tile38/internal/bing"
"github.com/tidwall/tile38/internal/deadline"
"github.com/tidwall/tile38/internal/glob" "github.com/tidwall/tile38/internal/glob"
) )
@ -370,22 +367,29 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) {
} }
sw.writeHead() sw.writeHead()
if sw.col != nil { if sw.col != nil {
maxDist := s.obj.(*geojson.Circle).Meters()
iter := func(id string, o geojson.Object, fields []float64, dist float64) bool { iter := func(id string, o geojson.Object, fields []float64, dist float64) bool {
if server.hasExpired(s.key, id) {
return true
}
if maxDist > 0 && dist > maxDist {
return false
}
meters := 0.0 meters := 0.0
if s.distance { if s.distance {
meters = geo.DistanceFromHaversine(dist) meters = dist
} }
return sw.writeObject(ScanWriterParams{ return sw.writeObject(ScanWriterParams{
id: id, id: id,
o: o, o: o,
fields: fields, fields: fields,
distance: meters, distance: meters,
noLock: true, noLock: true,
ignoreGlobMatch: true,
skipTesting: true,
}) })
} }
server.nearestNeighbors(&s, sw, msg.Deadline, s.obj.(*geojson.Circle), iter) sw.col.Nearby(s.obj, sw, msg.Deadline, iter)
} }
sw.writeFoot() sw.writeFoot()
if msg.OutputType == JSON { if msg.OutputType == JSON {
@ -395,48 +399,6 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) {
return sw.respOut, nil return sw.respOut, nil
} }
type iterItem struct {
id string
o geojson.Object
fields []float64
dist float64
}
func (server *Server) nearestNeighbors(
s *liveFenceSwitches, sw *scanWriter, dl *deadline.Deadline,
target *geojson.Circle,
iter func(id string, o geojson.Object, fields []float64, dist float64,
) bool) {
maxDist := target.Haversine()
var items []iterItem
sw.col.Nearby(target, sw, dl, func(id string, o geojson.Object, fields []float64) bool {
if server.hasExpired(s.key, id) {
return true
}
ok, keepGoing, _ := sw.testObject(id, o, fields, false)
if !ok {
return true
}
dist := target.HaversineTo(o.Center())
if maxDist > 0 && dist > maxDist {
return false
}
items = append(items, iterItem{id: id, o: o, fields: fields, dist: dist})
if !keepGoing {
return false
}
return uint64(len(items)) < sw.limit
})
sort.Slice(items, func(i, j int) bool {
return items[i].dist < items[j].dist
})
for _, item := range items {
if !iter(item.id, item.o, item.fields, item.dist) {
return
}
}
}
func (server *Server) cmdWithin(msg *Message) (res resp.Value, err error) { func (server *Server) cmdWithin(msg *Message) (res resp.Value, err error) {
return server.cmdWithinOrIntersects("within", msg) return server.cmdWithinOrIntersects("within", msg)
} }