diff --git a/internal/collection/collection.go b/internal/collection/collection.go index 3d3cb5e5..df1f75b4 100644 --- a/internal/collection/collection.go +++ b/internal/collection/collection.go @@ -1,6 +1,7 @@ package collection import ( + "math" "runtime" "github.com/tidwall/btree" @@ -699,7 +700,7 @@ func (c *Collection) Nearby( target geojson.Object, cursor Cursor, deadline *deadline.Deadline, - iter func(id string, obj geojson.Object, fields []float64) bool, + iter func(id string, obj geojson.Object, fields []float64, dist float64) bool, ) bool { // First look to see if there's at least one candidate in the circle's // outer rectangle. This is a fast-fail operation. @@ -734,18 +735,15 @@ func (c *Collection) Nearby( cursor.Step(offset) } c.index.Nearby( - algo.Box( - [2]float64{center.X, center.Y}, - [2]float64{center.X, center.Y}, - false, nil), - func(_, _ [2]float64, itemv interface{}, _ float64) bool { + geodeticDistAlgo([2]float64{center.X, center.Y}), + func(_, _ [2]float64, itemv interface{}, dist float64) bool { count++ if count <= offset { return true } nextStep(count, cursor, deadline) item := itemv.(*itemT) - alive = iter(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot)) + alive = iter(item.id, item.obj, c.getFieldValues(item.id), dist) return alive }, ) @@ -761,3 +759,123 @@ func nextStep(step uint64, cursor Cursor, deadline *deadline.Deadline) { cursor.Step(1) } } + +func geodeticDistAlgo(center [2]float64) func( + min, max [2]float64, data interface{}, item bool, + add func(min, max [2]float64, data interface{}, item bool, dist float64), +) { + const earthRadius = 6371e3 + return func( + min, max [2]float64, data interface{}, item bool, + add func(min, max [2]float64, data interface{}, item bool, dist float64), + ) { + add(min, max, data, item, earthRadius*pointRectDistGeodeticDeg( + center[1], center[0], + min[1], min[0], + max[1], max[0], + )) + } +} + +func pointRectDistGeodeticDeg(pLat, pLng, minLat, minLng, maxLat, maxLng float64) float64 { + result := pointRectDistGeodeticRad( + pLat*math.Pi/180, pLng*math.Pi/180, + minLat*math.Pi/180, minLng*math.Pi/180, + maxLat*math.Pi/180, maxLng*math.Pi/180, + ) + return result +} + +func pointRectDistGeodeticRad(φq, λq, φl, λl, φh, λh float64) float64 { + // Algorithm from: + // Schubert, E., Zimek, A., & Kriegel, H.-P. (2013). + // Geodetic Distance Queries on R-Trees for Indexing Geographic Data. + // Lecture Notes in Computer Science, 146–164. + // doi:10.1007/978-3-642-40235-7_9 + const ( + twoΠ = 2 * math.Pi + halfΠ = math.Pi / 2 + ) + + // distance on the unit sphere computed using Haversine formula + distRad := func(φa, λa, φb, λb float64) float64 { + if φa == φb && λa == λb { + return 0 + } + + Δφ := φa - φb + Δλ := λa - λb + sinΔφ := math.Sin(Δφ / 2) + sinΔλ := math.Sin(Δλ / 2) + cosφa := math.Cos(φa) + cosφb := math.Cos(φb) + + return 2 * math.Asin(math.Sqrt(sinΔφ*sinΔφ+sinΔλ*sinΔλ*cosφa*cosφb)) + } + + // Simple case, point or invalid rect + if φl >= φh && λl >= λh { + return distRad(φl, λl, φq, λq) + } + + if λl <= λq && λq <= λh { + // q is between the bounding meridians of r + // hence, q is north, south or within r + if φl <= φq && φq <= φh { // Inside + return 0 + } + + if φq < φl { // South + return φl - φq + } + + return φq - φh // North + } + + // determine if q is closer to the east or west edge of r to select edge for + // tests below + Δλe := λl - λq + Δλw := λq - λh + if Δλe < 0 { + Δλe += twoΠ + } + if Δλw < 0 { + Δλw += twoΠ + } + var Δλ float64 // distance to closest edge + var λedge float64 // longitude of closest edge + if Δλe <= Δλw { + Δλ = Δλe + λedge = λl + } else { + Δλ = Δλw + λedge = λh + } + + sinΔλ, cosΔλ := math.Sincos(Δλ) + tanφq := math.Tan(φq) + + if Δλ >= halfΠ { + // If Δλ > 90 degrees (1/2 pi in radians) we're in one of the corners + // (NW/SW or NE/SE depending on the edge selected). Compare against the + // center line to decide which case we fall into + φmid := (φh + φl) / 2 + if tanφq >= math.Tan(φmid)*cosΔλ { + return distRad(φq, λq, φh, λedge) // North corner + } + return distRad(φq, λq, φl, λedge) // South corner + } + + if tanφq >= math.Tan(φh)*cosΔλ { + return distRad(φq, λq, φh, λedge) // North corner + } + + if tanφq <= math.Tan(φl)*cosΔλ { + return distRad(φq, λq, φl, λedge) // South corner + } + + // We're to the East or West of the rect, compute distance using cross-track + // Note that this is a simplification of the cross track distance formula + // valid since the track in question is a meridian. + return math.Asin(math.Cos(φq) * sinΔλ) +} diff --git a/internal/collection/collection_test.go b/internal/collection/collection_test.go index 2f5b0ceb..70edd6a5 100644 --- a/internal/collection/collection_test.go +++ b/internal/collection/collection_test.go @@ -499,15 +499,22 @@ func TestSpatialSearch(t *testing.T) { var items []geojson.Object exitems := []geojson.Object{ - r2, p1, p4, r1, p3, r3, p2, + r2, p4, p1, r1, r3, p3, p2, } + + lastDist := float64(-1) + distsMonotonic := true c.Nearby(q4, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields []float64, dist float64) bool { + if dist < lastDist { + distsMonotonic = false + } items = append(items, obj) return true }, ) expect(t, len(items) == 7) + expect(t, distsMonotonic) expect(t, reflect.DeepEqual(items, exitems)) } diff --git a/internal/server/scanner.go b/internal/server/scanner.go index 7e157490..6ec6dc2c 100644 --- a/internal/server/scanner.go +++ b/internal/server/scanner.go @@ -62,15 +62,12 @@ type scanWriter struct { // ScanWriterParams ... type ScanWriterParams struct { - id string - o geojson.Object - fields []float64 - distance float64 - distOutput bool // query or fence requested distance output - noLock bool - ignoreGlobMatch bool - clip geojson.Object - skipTesting bool + id string + o geojson.Object + fields []float64 + distance float64 + noLock bool + clip geojson.Object } func (s *Server) newScanWriter( @@ -347,13 +344,11 @@ func (sw *scanWriter) Step(n uint64) { // ok is whether the object passes the test and should be written // keepGoing is whether there could be more objects to test -func (sw *scanWriter) testObject(id string, o geojson.Object, fields []float64, ignoreGlobMatch bool) ( +func (sw *scanWriter) testObject(id string, o geojson.Object, fields []float64) ( ok, keepGoing bool, fieldVals []float64) { - if !ignoreGlobMatch { - match, kg := sw.globMatch(id, o) - if !match { - return false, kg, fieldVals - } + match, kg := sw.globMatch(id, o) + if !match { + return false, kg, fieldVals } nf, ok := sw.fieldMatch(fields, o) return ok, true, nf @@ -365,13 +360,9 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool { sw.mu.Lock() defer sw.mu.Unlock() } - var ok bool - keepGoing := true - if !opts.skipTesting { - ok, keepGoing, _ = sw.testObject(opts.id, opts.o, opts.fields, opts.ignoreGlobMatch) - if !ok { - return keepGoing - } + ok, keepGoing, _ := sw.testObject(opts.id, opts.o, opts.fields) + if !ok { + return keepGoing } sw.count++ if sw.output == outputCount { diff --git a/internal/server/search.go b/internal/server/search.go index 75e59d5d..7d4b3e90 100644 --- a/internal/server/search.go +++ b/internal/server/search.go @@ -3,19 +3,15 @@ package server import ( "bytes" "errors" - "sort" "strconv" "strings" "time" "github.com/mmcloughlin/geohash" "github.com/tidwall/geojson" - "github.com/tidwall/geojson/geo" "github.com/tidwall/geojson/geometry" "github.com/tidwall/resp" "github.com/tidwall/tile38/internal/bing" - "github.com/tidwall/tile38/internal/clip" - "github.com/tidwall/tile38/internal/deadline" "github.com/tidwall/tile38/internal/glob" ) @@ -418,23 +414,29 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) { } sw.writeHead() if sw.col != nil { + maxDist := s.obj.(*geojson.Circle).Meters() iter := func(id string, o geojson.Object, fields []float64, dist float64) bool { + if server.hasExpired(s.key, id) { + return true + } + + if maxDist > 0 && dist > maxDist { + return false + } + meters := 0.0 if s.distance { - meters = geo.DistanceFromHaversine(dist) + meters = dist } return sw.writeObject(ScanWriterParams{ - id: id, - o: o, - fields: fields, - distance: meters, - distOutput: s.distance, - noLock: true, - ignoreGlobMatch: true, - skipTesting: true, + id: id, + o: o, + fields: fields, + distance: meters, + noLock: true, }) } - server.nearestNeighbors(&s, sw, msg.Deadline, s.obj.(*geojson.Circle), iter) + sw.col.Nearby(s.obj, sw, msg.Deadline, iter) } sw.writeFoot() if msg.OutputType == JSON { @@ -444,48 +446,6 @@ func (server *Server) cmdNearby(msg *Message) (res resp.Value, err error) { return sw.respOut, nil } -type iterItem struct { - id string - o geojson.Object - fields []float64 - dist float64 -} - -func (server *Server) nearestNeighbors( - s *liveFenceSwitches, sw *scanWriter, dl *deadline.Deadline, - target *geojson.Circle, - iter func(id string, o geojson.Object, fields []float64, dist float64, - ) bool) { - maxDist := target.Haversine() - var items []iterItem - sw.col.Nearby(target, sw, dl, func(id string, o geojson.Object, fields []float64) bool { - if server.hasExpired(s.key, id) { - return true - } - ok, keepGoing, _ := sw.testObject(id, o, fields, false) - if !ok { - return true - } - dist := target.HaversineTo(o.Center()) - if maxDist > 0 && dist > maxDist { - return false - } - items = append(items, iterItem{id: id, o: o, fields: fields, dist: dist}) - if !keepGoing { - return false - } - return uint64(len(items)) < sw.limit - }) - sort.Slice(items, func(i, j int) bool { - return items[i].dist < items[j].dist - }) - for _, item := range items { - if !iter(item.id, item.o, item.fields, item.dist) { - return - } - } -} - func (server *Server) cmdWithin(msg *Message) (res resp.Value, err error) { return server.cmdWithinOrIntersects("within", msg) } diff --git a/tests/keys_search_test.go b/tests/keys_search_test.go index 8b5d3fa2..82b8c573 100644 --- a/tests/keys_search_test.go +++ b/tests/keys_search_test.go @@ -2,6 +2,7 @@ package tests import ( "fmt" + "math/rand" "sort" "testing" ) @@ -522,3 +523,19 @@ func match(expectIn string) func(org, v interface{}) (resp, expect interface{}) return fmt.Sprintf("%v", org), expectIn } } + +func subBenchSearch(b *testing.B, mc *mockServer) { + runBenchStep(b, mc, "KNN", keys_KNN_bench) +} + +func keys_KNN_bench(mc *mockServer) error { + lat := rand.Float64()*180 - 90 + lon := rand.Float64()*360 - 180 + _, err := mc.conn.Do("NEARBY", + "mykey", + "LIMIT", 50, + "DISTANCE", + "POINTS", + "POINT", lat, lon) + return err +} diff --git a/tests/mock_test.go b/tests/mock_test.go index 9855c1e7..327629eb 100644 --- a/tests/mock_test.go +++ b/tests/mock_test.go @@ -20,15 +20,19 @@ import ( var errTimeout = errors.New("timeout") -func mockCleanup() { - fmt.Printf("Cleanup: may take some time... ") +func mockCleanup(silent bool) { + if !silent { + fmt.Printf("Cleanup: may take some time... ") + } files, _ := ioutil.ReadDir(".") for _, file := range files { if strings.HasPrefix(file.Name(), "data-mock-") { os.RemoveAll(file.Name()) } } - fmt.Printf("OK\n") + if !silent { + fmt.Printf("OK\n") + } } type mockServer struct { @@ -39,11 +43,13 @@ type mockServer struct { conn redis.Conn } -func mockOpenServer() (*mockServer, error) { +func mockOpenServer(silent bool) (*mockServer, error) { rand.Seed(time.Now().UnixNano()) port := rand.Int()%20000 + 20000 dir := fmt.Sprintf("data-mock-%d", port) - fmt.Printf("Starting test server at port %d\n", port) + if !silent { + fmt.Printf("Starting test server at port %d\n", port) + } logOutput := ioutil.Discard if os.Getenv("PRINTLOG") == "1" { logOutput = os.Stderr diff --git a/tests/tests_test.go b/tests/tests_test.go index 5ce39a1d..b4f1b3ee 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -2,10 +2,14 @@ package tests import ( "fmt" + "math/rand" "os" "os/signal" "syscall" "testing" + "time" + + "github.com/gomodule/redigo/redis" ) const ( @@ -23,18 +27,18 @@ const ( ) func TestAll(t *testing.T) { - mockCleanup() - defer mockCleanup() + mockCleanup(false) + defer mockCleanup(false) ch := make(chan os.Signal) signal.Notify(ch, os.Interrupt, syscall.SIGTERM) go func() { <-ch - mockCleanup() + mockCleanup(false) os.Exit(1) }() - mc, err := mockOpenServer() + mc, err := mockOpenServer(false) if err != nil { t.Fatal(err) } @@ -84,3 +88,88 @@ func runStep(t *testing.T, mc *mockServer, name string, step func(mc *mockServer fmt.Printf("["+green+"ok"+clear+"]: %s\n", name) }) } + +func BenchmarkAll(b *testing.B) { + mockCleanup(true) + defer mockCleanup(true) + + ch := make(chan os.Signal) + signal.Notify(ch, os.Interrupt, syscall.SIGTERM) + go func() { + <-ch + mockCleanup(true) + os.Exit(1) + }() + + mc, err := mockOpenServer(true) + if err != nil { + b.Fatal(err) + } + defer mc.Close() + runSubBenchmark(b, "search", mc, subBenchSearch) +} + +func loadBenchmarkPoints(b *testing.B, mc *mockServer) (err error) { + const nPoints = 200000 + rand.Seed(time.Now().UnixNano()) + + // add a bunch of points + for i := 0; i < nPoints; i++ { + val := fmt.Sprintf("val:%d", i) + var resp string + var lat, lon, fval float64 + fval = rand.Float64() + lat = rand.Float64()*180 - 90 + lon = rand.Float64()*360 - 180 + resp, err = redis.String(mc.conn.Do("SET", + "mykey", val, + "FIELD", "foo", fval, + "POINT", lat, lon)) + if err != nil { + return + } + if resp != "OK" { + err = fmt.Errorf("expected 'OK', got '%s'", resp) + return + } + } + return +} + +func runSubBenchmark(b *testing.B, name string, mc *mockServer, bench func(t *testing.B, mc *mockServer)) { + b.Run(name, func(b *testing.B) { + bench(b, mc) + }) +} + +func runBenchStep(b *testing.B, mc *mockServer, name string, step func(mc *mockServer) error) { + b.Helper() + b.Run(name, func(b *testing.B) { + b.Helper() + if err := func() error { + // reset the current server + mc.ResetConn() + defer mc.ResetConn() + // clear the database so the test is consistent + if err := mc.DoBatch([][]interface{}{ + {"OUTPUT", "resp"}, {"OK"}, + {"FLUSHDB"}, {"OK"}, + }); err != nil { + return err + } + err := loadBenchmarkPoints(b, mc) + if err != nil { + return err + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := step(mc); err != nil { + return err + } + } + return nil + }(); err != nil { + b.Fatal(err) + } + }) +}