Merge branch 'housecanary-fix-knn'

This commit is contained in:
tidwall 2021-07-11 10:02:59 -07:00
commit 579a41abae
7 changed files with 284 additions and 96 deletions

View File

@ -1,6 +1,7 @@
package collection package collection
import ( import (
"math"
"runtime" "runtime"
"github.com/tidwall/btree" "github.com/tidwall/btree"
@ -699,7 +700,7 @@ func (c *Collection) Nearby(
target geojson.Object, target geojson.Object,
cursor Cursor, cursor Cursor,
deadline *deadline.Deadline, deadline *deadline.Deadline,
iter func(id string, obj geojson.Object, fields []float64) bool, iter func(id string, obj geojson.Object, fields []float64, dist float64) bool,
) bool { ) bool {
// First look to see if there's at least one candidate in the circle's // First look to see if there's at least one candidate in the circle's
// outer rectangle. This is a fast-fail operation. // outer rectangle. This is a fast-fail operation.
@ -734,18 +735,15 @@ func (c *Collection) Nearby(
cursor.Step(offset) cursor.Step(offset)
} }
c.index.Nearby( c.index.Nearby(
algo.Box( geodeticDistAlgo([2]float64{center.X, center.Y}),
[2]float64{center.X, center.Y}, func(_, _ [2]float64, itemv interface{}, dist float64) bool {
[2]float64{center.X, center.Y},
false, nil),
func(_, _ [2]float64, itemv interface{}, _ float64) bool {
count++ count++
if count <= offset { if count <= offset {
return true return true
} }
nextStep(count, cursor, deadline) nextStep(count, cursor, deadline)
item := itemv.(*itemT) item := itemv.(*itemT)
alive = iter(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot)) alive = iter(item.id, item.obj, c.getFieldValues(item.id), dist)
return alive return alive
}, },
) )
@ -761,3 +759,123 @@ func nextStep(step uint64, cursor Cursor, deadline *deadline.Deadline) {
cursor.Step(1) cursor.Step(1)
} }
} }
// geodeticDistAlgo builds the distance callback handed to the spatial index's
// nearest-neighbor search.
//
// center is the query point as [x, y], i.e. [longitude, latitude] in degrees.
// For every index entry (inner node rectangle or leaf item) the returned
// function reports, through add, the geodetic distance from center to the
// entry's bounding rectangle: the unit-sphere distance from
// pointRectDistGeodeticDeg scaled by the mean Earth radius (6371e3, meters).
func geodeticDistAlgo(center [2]float64) func(
	min, max [2]float64, data interface{}, item bool,
	add func(min, max [2]float64, data interface{}, item bool, dist float64),
) {
	const earthRadius = 6371e3
	// The index stores coordinates as [x, y]; the geodetic helper wants
	// latitude first, so unpack once up front.
	qLng, qLat := center[0], center[1]
	scorer := func(
		min, max [2]float64, data interface{}, item bool,
		add func(min, max [2]float64, data interface{}, item bool, dist float64),
	) {
		radians := pointRectDistGeodeticDeg(
			qLat, qLng,
			min[1], min[0],
			max[1], max[0],
		)
		add(min, max, data, item, earthRadius*radians)
	}
	return scorer
}
// pointRectDistGeodeticDeg is the degree-based front end to
// pointRectDistGeodeticRad. It converts the query point (pLat, pLng) and the
// rectangle corners (minLat, minLng)-(maxLat, maxLng) from degrees to radians
// and returns whatever the radian version computes for them.
func pointRectDistGeodeticDeg(pLat, pLng, minLat, minLng, maxLat, maxLng float64) float64 {
	// Multiply before dividing so the rounding matches a plain
	// deg*math.Pi/180 expression.
	rad := func(deg float64) float64 { return deg * math.Pi / 180 }
	return pointRectDistGeodeticRad(
		rad(pLat), rad(pLng),
		rad(minLat), rad(minLng),
		rad(maxLat), rad(maxLng),
	)
}
// pointRectDistGeodeticRad returns the minimum distance, measured as an angle
// in radians on the unit sphere, between the query point q and a
// latitude/longitude rectangle.
//
// All arguments are in radians:
//
//	φq, λq - latitude and longitude of the query point
//	φl, λl - low (south / west) latitude and longitude of the rectangle
//	φh, λh - high (north / east) latitude and longitude of the rectangle
//
// The result is 0 when q lies inside the rectangle. Multiply by a sphere
// radius to obtain a length.
func pointRectDistGeodeticRad(φq, λq, φl, λl, φh, λh float64) float64 {
	// Algorithm from:
	// Schubert, E., Zimek, A., & Kriegel, H.-P. (2013).
	// Geodetic Distance Queries on R-Trees for Indexing Geographic Data.
	// Lecture Notes in Computer Science, 146–164.
	// doi:10.1007/978-3-642-40235-7_9
	const (
		twoΠ  = 2 * math.Pi
		halfΠ = math.Pi / 2
	)

	// distance on the unit sphere computed using Haversine formula
	distRad := func(φa, λa, φb, λb float64) float64 {
		if φa == φb && λa == λb {
			return 0
		}

		Δφ := φa - φb
		Δλ := λa - λb
		sinΔφ := math.Sin(Δφ / 2)
		sinΔλ := math.Sin(Δλ / 2)
		cosφa := math.Cos(φa)
		cosφb := math.Cos(φb)

		// Rounding can push the haversine term a hair above 1 for
		// (near-)antipodal points, and math.Asin returns NaN outside
		// [-1, 1]. Clamp so the result tops out at π instead.
		h := math.Min(1, math.Sqrt(sinΔφ*sinΔφ+sinΔλ*sinΔλ*cosφa*cosφb))
		return 2 * math.Asin(h)
	}

	// Simple case, point or invalid rect
	if φl >= φh && λl >= λh {
		return distRad(φl, λl, φq, λq)
	}

	if λl <= λq && λq <= λh {
		// q is between the bounding meridians of r
		// hence, q is north, south or within r
		if φl <= φq && φq <= φh { // Inside
			return 0
		}

		// Along a meridian the angular distance is just the latitude
		// difference.
		if φq < φl { // South
			return φl - φq
		}

		return φq - φh // North
	}

	// determine if q is closer to the east or west edge of r to select edge for
	// tests below
	Δλe := λl - λq
	Δλw := λq - λh
	// Normalize both differences into [0, 2π) so they can be compared.
	if Δλe < 0 {
		Δλe += twoΠ
	}
	if Δλw < 0 {
		Δλw += twoΠ
	}
	var Δλ float64    // distance to closest edge
	var λedge float64 // longitude of closest edge
	if Δλe <= Δλw {
		Δλ = Δλe
		λedge = λl
	} else {
		Δλ = Δλw
		λedge = λh
	}

	sinΔλ, cosΔλ := math.Sincos(Δλ)
	tanφq := math.Tan(φq)

	if Δλ >= halfΠ {
		// If Δλ >= 90 degrees (1/2 pi in radians) we're in one of the corners
		// (NW/SW or NE/SE depending on the edge selected). Compare against the
		// center line to decide which case we fall into
		φmid := (φh + φl) / 2
		if tanφq >= math.Tan(φmid)*cosΔλ {
			return distRad(φq, λq, φh, λedge) // North corner
		}
		return distRad(φq, λq, φl, λedge) // South corner
	}

	if tanφq >= math.Tan(φh)*cosΔλ {
		return distRad(φq, λq, φh, λedge) // North corner
	}

	if tanφq <= math.Tan(φl)*cosΔλ {
		return distRad(φq, λq, φl, λedge) // South corner
	}

	// We're to the East or West of the rect, compute distance using cross-track
	// Note that this is a simplification of the cross track distance formula
	// valid since the track in question is a meridian.
	return math.Asin(math.Cos(φq) * sinΔλ)
}

View File

@ -499,15 +499,22 @@ func TestSpatialSearch(t *testing.T) {
var items []geojson.Object var items []geojson.Object
exitems := []geojson.Object{ exitems := []geojson.Object{
r2, p1, p4, r1, p3, r3, p2, r2, p4, p1, r1, r3, p3, p2,
} }
lastDist := float64(-1)
distsMonotonic := true
c.Nearby(q4, nil, nil, c.Nearby(q4, nil, nil,
func(id string, obj geojson.Object, fields []float64) bool { func(id string, obj geojson.Object, fields []float64, dist float64) bool {
if dist < lastDist {
distsMonotonic = false
}
items = append(items, obj) items = append(items, obj)
return true return true
}, },
) )
expect(t, len(items) == 7) expect(t, len(items) == 7)
expect(t, distsMonotonic)
expect(t, reflect.DeepEqual(items, exitems)) expect(t, reflect.DeepEqual(items, exitems))
} }

View File

@ -62,15 +62,12 @@ type scanWriter struct {
// ScanWriterParams ... // ScanWriterParams ...
type ScanWriterParams struct { type ScanWriterParams struct {
id string id string
o geojson.Object o geojson.Object
fields []float64 fields []float64
distance float64 distance float64
distOutput bool // query or fence requested distance output noLock bool
noLock bool clip geojson.Object
ignoreGlobMatch bool
clip geojson.Object
skipTesting bool
} }
func (s *Server) newScanWriter( func (s *Server) newScanWriter(
@ -347,13 +344,11 @@ func (sw *scanWriter) Step(n uint64) {
// ok is whether the object passes the test and should be written // ok is whether the object passes the test and should be written
// keepGoing is whether there could be more objects to test // keepGoing is whether there could be more objects to test
func (sw *scanWriter) testObject(id string, o geojson.Object, fields []float64, ignoreGlobMatch bool) ( func (sw *scanWriter) testObject(id string, o geojson.Object, fields []float64) (
ok, keepGoing bool, fieldVals []float64) { ok, keepGoing bool, fieldVals []float64) {
if !ignoreGlobMatch { match, kg := sw.globMatch(id, o)
match, kg := sw.globMatch(id, o) if !match {
if !match { return false, kg, fieldVals
return false, kg, fieldVals
}
} }
nf, ok := sw.fieldMatch(fields, o) nf, ok := sw.fieldMatch(fields, o)
return ok, true, nf return ok, true, nf
@ -365,13 +360,9 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool {
sw.mu.Lock() sw.mu.Lock()
defer sw.mu.Unlock() defer sw.mu.Unlock()
} }
var ok bool ok, keepGoing, _ := sw.testObject(opts.id, opts.o, opts.fields)
keepGoing := true if !ok {
if !opts.skipTesting { return keepGoing
ok, keepGoing, _ = sw.testObject(opts.id, opts.o, opts.fields, opts.ignoreGlobMatch)
if !ok {
return keepGoing
}
} }
sw.count++ sw.count++
if sw.output == outputCount { if sw.output == outputCount {

View File

@ -3,19 +3,15 @@ package server
import ( import (
"bytes" "bytes"
"errors" "errors"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/mmcloughlin/geohash" "github.com/mmcloughlin/geohash"
"github.com/tidwall/geojson" "github.com/tidwall/geojson"
"github.com/tidwall/geojson/geo"
"github.com/tidwall/geojson/geometry" "github.com/tidwall/geojson/geometry"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"github.com/tidwall/tile38/internal/bing" "github.com/tidwall/tile38/internal/bing"
"github.com/tidwall/tile38/internal/clip"
"github.com/tidwall/tile38/internal/deadline"
"github.com/tidwall/tile38/internal/glob" "github.com/tidwall/tile38/internal/glob"
) )
@ -418,23 +414,29 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) {
} }
sw.writeHead() sw.writeHead()
if sw.col != nil { if sw.col != nil {
maxDist := s.obj.(*geojson.Circle).Meters()
iter := func(id string, o geojson.Object, fields []float64, dist float64) bool { iter := func(id string, o geojson.Object, fields []float64, dist float64) bool {
if server.hasExpired(s.key, id) {
return true
}
if maxDist > 0 && dist > maxDist {
return false
}
meters := 0.0 meters := 0.0
if s.distance { if s.distance {
meters = geo.DistanceFromHaversine(dist) meters = dist
} }
return sw.writeObject(ScanWriterParams{ return sw.writeObject(ScanWriterParams{
id: id, id: id,
o: o, o: o,
fields: fields, fields: fields,
distance: meters, distance: meters,
distOutput: s.distance, noLock: true,
noLock: true,
ignoreGlobMatch: true,
skipTesting: true,
}) })
} }
server.nearestNeighbors(&s, sw, msg.Deadline, s.obj.(*geojson.Circle), iter) sw.col.Nearby(s.obj, sw, msg.Deadline, iter)
} }
sw.writeFoot() sw.writeFoot()
if msg.OutputType == JSON { if msg.OutputType == JSON {
@ -444,48 +446,6 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) {
return sw.respOut, nil return sw.respOut, nil
} }
type iterItem struct {
id string
o geojson.Object
fields []float64
dist float64
}
func (server *Server) nearestNeighbors(
s *liveFenceSwitches, sw *scanWriter, dl *deadline.Deadline,
target *geojson.Circle,
iter func(id string, o geojson.Object, fields []float64, dist float64,
) bool) {
maxDist := target.Haversine()
var items []iterItem
sw.col.Nearby(target, sw, dl, func(id string, o geojson.Object, fields []float64) bool {
if server.hasExpired(s.key, id) {
return true
}
ok, keepGoing, _ := sw.testObject(id, o, fields, false)
if !ok {
return true
}
dist := target.HaversineTo(o.Center())
if maxDist > 0 && dist > maxDist {
return false
}
items = append(items, iterItem{id: id, o: o, fields: fields, dist: dist})
if !keepGoing {
return false
}
return uint64(len(items)) < sw.limit
})
sort.Slice(items, func(i, j int) bool {
return items[i].dist < items[j].dist
})
for _, item := range items {
if !iter(item.id, item.o, item.fields, item.dist) {
return
}
}
}
func (server *Server) cmdWithin(msg *Message) (res resp.Value, err error) { func (server *Server) cmdWithin(msg *Message) (res resp.Value, err error) {
return server.cmdWithinOrIntersects("within", msg) return server.cmdWithinOrIntersects("within", msg)
} }

View File

@ -2,6 +2,7 @@ package tests
import ( import (
"fmt" "fmt"
"math/rand"
"sort" "sort"
"testing" "testing"
) )
@ -522,3 +523,19 @@ func match(expectIn string) func(org, v interface{}) (resp, expect interface{})
return fmt.Sprintf("%v", org), expectIn return fmt.Sprintf("%v", org), expectIn
} }
} }
// subBenchSearch registers the search-related benchmark steps against the
// shared mock server. Currently that is a single "KNN" step driven by
// keys_KNN_bench.
func subBenchSearch(b *testing.B, mc *mockServer) {
	const knnStep = "KNN"
	runBenchStep(b, mc, knnStep, keys_KNN_bench)
}
// keys_KNN_bench is one benchmark iteration: it issues a NEARBY query on
// "mykey" around a uniformly random latitude/longitude, limited to 50
// results, asking for DISTANCE and POINTS output. The reply is discarded;
// only the command error is returned.
func keys_KNN_bench(mc *mockServer) error {
	// Latitude is drawn before longitude.
	lat := rand.Float64()*180 - 90
	lon := rand.Float64()*360 - 180
	args := []interface{}{
		"mykey",
		"LIMIT", 50,
		"DISTANCE",
		"POINTS",
		"POINT", lat, lon,
	}
	_, err := mc.conn.Do("NEARBY", args...)
	return err
}

View File

@ -20,15 +20,19 @@ import (
var errTimeout = errors.New("timeout") var errTimeout = errors.New("timeout")
func mockCleanup() { func mockCleanup(silent bool) {
fmt.Printf("Cleanup: may take some time... ") if !silent {
fmt.Printf("Cleanup: may take some time... ")
}
files, _ := ioutil.ReadDir(".") files, _ := ioutil.ReadDir(".")
for _, file := range files { for _, file := range files {
if strings.HasPrefix(file.Name(), "data-mock-") { if strings.HasPrefix(file.Name(), "data-mock-") {
os.RemoveAll(file.Name()) os.RemoveAll(file.Name())
} }
} }
fmt.Printf("OK\n") if !silent {
fmt.Printf("OK\n")
}
} }
type mockServer struct { type mockServer struct {
@ -39,11 +43,13 @@ type mockServer struct {
conn redis.Conn conn redis.Conn
} }
func mockOpenServer() (*mockServer, error) { func mockOpenServer(silent bool) (*mockServer, error) {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
port := rand.Int()%20000 + 20000 port := rand.Int()%20000 + 20000
dir := fmt.Sprintf("data-mock-%d", port) dir := fmt.Sprintf("data-mock-%d", port)
fmt.Printf("Starting test server at port %d\n", port) if !silent {
fmt.Printf("Starting test server at port %d\n", port)
}
logOutput := ioutil.Discard logOutput := ioutil.Discard
if os.Getenv("PRINTLOG") == "1" { if os.Getenv("PRINTLOG") == "1" {
logOutput = os.Stderr logOutput = os.Stderr

View File

@ -2,10 +2,14 @@ package tests
import ( import (
"fmt" "fmt"
"math/rand"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"testing" "testing"
"time"
"github.com/gomodule/redigo/redis"
) )
const ( const (
@ -23,18 +27,18 @@ const (
) )
func TestAll(t *testing.T) { func TestAll(t *testing.T) {
mockCleanup() mockCleanup(false)
defer mockCleanup() defer mockCleanup(false)
ch := make(chan os.Signal) ch := make(chan os.Signal)
signal.Notify(ch, os.Interrupt, syscall.SIGTERM) signal.Notify(ch, os.Interrupt, syscall.SIGTERM)
go func() { go func() {
<-ch <-ch
mockCleanup() mockCleanup(false)
os.Exit(1) os.Exit(1)
}() }()
mc, err := mockOpenServer() mc, err := mockOpenServer(false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -84,3 +88,88 @@ func runStep(t *testing.T, mc *mockServer, name string, step func(mc *mockServer
fmt.Printf("["+green+"ok"+clear+"]: %s\n", name) fmt.Printf("["+green+"ok"+clear+"]: %s\n", name)
}) })
} }
// BenchmarkAll is the entry point for the integration benchmarks. It removes
// stale mock data directories, starts a mock server in silent mode, and runs
// every benchmark group against it. The mock data is cleaned up again on
// return, and also if the process receives an interrupt or SIGTERM.
func BenchmarkAll(b *testing.B) {
	mockCleanup(true)
	defer mockCleanup(true)

	// signal.Notify never blocks when delivering to the channel, so an
	// unbuffered channel can silently drop a signal that arrives while the
	// goroutine below is not yet parked on the receive. One slot is enough
	// because only the first signal matters.
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-ch
		// Deferred calls do not run on os.Exit, so clean up explicitly.
		mockCleanup(true)
		os.Exit(1)
	}()

	mc, err := mockOpenServer(true)
	if err != nil {
		b.Fatal(err)
	}
	defer mc.Close()
	runSubBenchmark(b, "search", mc, subBenchSearch)
}
// loadBenchmarkPoints fills "mykey" with 200000 points named "val:<i>", each
// at a uniformly random latitude/longitude and carrying a random "foo" field
// value in [0, 1). It stops at the first SET that errors or whose reply is
// not "OK" and returns that failure; on success it returns nil.
func loadBenchmarkPoints(b *testing.B, mc *mockServer) (err error) {
	const nPoints = 200000
	rand.Seed(time.Now().UnixNano())

	for i := 0; i < nPoints; i++ {
		id := fmt.Sprintf("val:%d", i)
		// Draw order: field value, then latitude, then longitude.
		fval := rand.Float64()
		lat := rand.Float64()*180 - 90
		lon := rand.Float64()*360 - 180

		var reply string
		reply, err = redis.String(mc.conn.Do("SET",
			"mykey", id,
			"FIELD", "foo", fval,
			"POINT", lat, lon))
		if err != nil {
			return err
		}
		if reply != "OK" {
			return fmt.Errorf("expected 'OK', got '%s'", reply)
		}
	}
	return nil
}
// runSubBenchmark runs bench as a named sub-benchmark of b, passing it the
// sub-benchmark's own *testing.B together with the shared mock server.
func runSubBenchmark(b *testing.B, name string, mc *mockServer, bench func(t *testing.B, mc *mockServer)) {
	group := func(sub *testing.B) {
		bench(sub, mc)
	}
	b.Run(name, group)
}
// runBenchStep runs step as the sub-benchmark name. Each run resets the
// server connection, switches output to RESP, flushes the database, reloads
// the benchmark points, and only then starts the timer and calls step b.N
// times. The first error from setup or from step fails the benchmark.
func runBenchStep(b *testing.B, mc *mockServer, name string, step func(mc *mockServer) error) {
	b.Helper()
	b.Run(name, func(b *testing.B) {
		b.Helper()

		// Kept as its own function so the deferred connection reset has
		// already run by the time a failure is reported below.
		measure := func() error {
			// reset the current server
			mc.ResetConn()
			defer mc.ResetConn()

			// clear the database so the test is consistent
			setup := [][]interface{}{
				{"OUTPUT", "resp"}, {"OK"},
				{"FLUSHDB"}, {"OK"},
			}
			if err := mc.DoBatch(setup); err != nil {
				return err
			}

			if err := loadBenchmarkPoints(b, mc); err != nil {
				return err
			}

			// Everything above is setup; exclude it from the measurement.
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				if err := step(mc); err != nil {
					return err
				}
			}
			return nil
		}

		if err := measure(); err != nil {
			b.Fatal(err)
		}
	})
}