mirror of https://github.com/tidwall/rtred.git
226 lines
4.5 KiB
Go
226 lines
4.5 KiB
Go
package rtred
|
|
|
|
import (
|
|
"fmt"
|
|
"math/rand"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// tRect is a test rectangle stored as one flat slice: the first half of the
// elements are the minimum coordinates and the second half are the maximum
// coordinates.
type tRect []float64

// Arr returns the rectangle's coordinates as a plain []float64. The result
// shares storage with the receiver; nothing is copied.
func (r *tRect) Arr() []float64 {
	coords := *r
	return []float64(coords)
}

// Rect splits the coordinates at the midpoint and returns the two halves as
// the min and max corners. ctx is ignored. Both results alias the receiver's
// storage. For an odd length the extra element ends up in max.
func (r *tRect) Rect(ctx interface{}) (min, max []float64) {
	coords := r.Arr()
	half := len(coords) / 2
	min = coords[:half]
	max = coords[half:]
	return min, max
}

// String formats the rectangle as "<min>,<max>" using %v for each half.
func (r *tRect) String() string {
	lo, hi := r.Rect(nil)
	return fmt.Sprintf("%v,%v", lo, hi)
}
|
|
func tRandRect(dims int) *tRect {
|
|
if dims == -1 {
|
|
dims = rand.Int()%4 + 1
|
|
}
|
|
r := tRect(make([]float64, dims*2))
|
|
for j := 0; j < dims; j++ {
|
|
minf := rand.Float64()*200 - 100
|
|
maxf := rand.Float64()*200 - 100
|
|
if minf > maxf {
|
|
minf, maxf = maxf, minf
|
|
}
|
|
r[j] = minf
|
|
r[dims+j] = maxf
|
|
}
|
|
return &r
|
|
}
|
|
|
|
// tPoint is a 2-D test point. It presents itself to the tree as a zero-area
// rectangle.
type tPoint struct {
	x, y float64
}

// Rect reports the point as both the min and the max corner. ctx is ignored.
// The two results are separate allocations, so modifying one does not affect
// the other.
func (p *tPoint) Rect(ctx interface{}) (min, max []float64) {
	min = []float64{p.x, p.y}
	max = []float64{p.x, p.y}
	return min, max
}
|
|
func tRandPoint() *tPoint {
|
|
return &tPoint{
|
|
rand.Float64()*200 - 100,
|
|
rand.Float64()*200 - 100,
|
|
}
|
|
}
|
|
func TestRTree(t *testing.T) {
|
|
tr := New("hello")
|
|
zeroPoint := &tRect{0, 0, 0, 0}
|
|
tr.Insert(&tRect{10, 10, 10, 10, 20, 20, 20, 20})
|
|
tr.Insert(&tRect{10, 10, 10, 20, 20, 20})
|
|
tr.Insert(&tRect{10, 10, 20, 20})
|
|
tr.Insert(&tRect{10, 20})
|
|
tr.Insert(zeroPoint)
|
|
if tr.Count() != 5 {
|
|
t.Fatalf("expecting %v, got %v", 5, tr.Count())
|
|
}
|
|
|
|
var count int
|
|
tr.Search(&tRect{0, 0, 0, 100, 100, 5}, func(item Item) bool {
|
|
count++
|
|
return true
|
|
})
|
|
|
|
if count != 3 {
|
|
t.Fatalf("expecting %v, got %v", 3, count)
|
|
}
|
|
tr.Remove(zeroPoint)
|
|
count = 0
|
|
tr.Search(&tRect{0, 0, 0, 100, 100, 5}, func(item Item) bool {
|
|
count++
|
|
return true
|
|
})
|
|
if count != 2 {
|
|
t.Fatalf("expecting %v, got %v", 2, count)
|
|
}
|
|
}
|
|
|
|
func TestInsertDelete(t *testing.T) {
|
|
rand.Seed(time.Now().UnixNano())
|
|
n := 50000
|
|
tr := New(nil)
|
|
var r2arr []*tRect
|
|
for i := 0; i < n; i++ {
|
|
r := tRandRect(-1)
|
|
if len(r.Arr()) == 4 {
|
|
r2arr = append(r2arr, r)
|
|
}
|
|
tr.Insert(r)
|
|
}
|
|
if tr.Count() != n {
|
|
t.Fatalf("expecting %v, got %v", n, tr.Count())
|
|
}
|
|
var count int
|
|
tr.Search(&tRect{-100, -100, -100, -100, 100, 100, 100, 100}, func(item Item) bool {
|
|
if len(item.(*tRect).Arr()) == 4 {
|
|
count++
|
|
}
|
|
return true
|
|
})
|
|
p := float64(count) / float64(n)
|
|
|
|
if p < .23 || p > .27 {
|
|
t.Fatalf("bad random range, expected between 0.24-0.26, got %v", p)
|
|
}
|
|
for _, i := range rand.Perm(len(r2arr)) {
|
|
tr.Remove(r2arr[i])
|
|
}
|
|
total := tr.Count() + count
|
|
if total != n {
|
|
t.Fatalf("expected %v, got %v", n, total)
|
|
}
|
|
}
|
|
func TestPoints(t *testing.T) {
|
|
rand.Seed(time.Now().UnixNano())
|
|
n := 25000
|
|
tr := New(nil)
|
|
var points []*tPoint
|
|
for i := 0; i < n; i++ {
|
|
r := tRandPoint()
|
|
points = append(points, r)
|
|
tr.Insert(r)
|
|
}
|
|
if tr.Count() != n {
|
|
t.Fatalf("expecting %v, got %v", n, tr.Count())
|
|
}
|
|
var count int
|
|
tr.Search(&tRect{-100, -100, -100, -100, 100, 100, 100, 100}, func(item Item) bool {
|
|
count++
|
|
return true
|
|
})
|
|
|
|
if count != n {
|
|
t.Fatalf("expecting %v, got %v", n, count)
|
|
}
|
|
for _, i := range rand.Perm(len(points)) {
|
|
tr.Remove(points[i])
|
|
}
|
|
total := tr.Count() + count
|
|
if total != n {
|
|
t.Fatalf("expected %v, got %v", n, total)
|
|
}
|
|
}
|
|
func BenchmarkInsert(t *testing.B) {
|
|
t.StopTimer()
|
|
rand.Seed(time.Now().UnixNano())
|
|
tr := New(nil)
|
|
var points []*tPoint
|
|
for i := 0; i < t.N; i++ {
|
|
points = append(points, tRandPoint())
|
|
}
|
|
t.StartTimer()
|
|
for i := 0; i < t.N; i++ {
|
|
tr.Insert(points[i])
|
|
}
|
|
t.StopTimer()
|
|
count := tr.Count()
|
|
if count != t.N {
|
|
t.Fatalf("expected %v, got %v", t.N, count)
|
|
}
|
|
|
|
t.StartTimer()
|
|
}
|
|
|
|
func TestKNN(t *testing.T) {
|
|
n := 25000
|
|
tr := New(nil)
|
|
var points []*tPoint
|
|
rand.Seed(1)
|
|
for i := 0; i < n; i++ {
|
|
r := tRandPoint()
|
|
points = append(points, r)
|
|
tr.Insert(r)
|
|
}
|
|
if tr.Count() != n {
|
|
t.Fatalf("expecting %v, got %v", n, tr.Count())
|
|
}
|
|
var count int
|
|
tr.Search(&tRect{-100, -100, -100, -100, 100, 100, 100, 100}, func(item Item) bool {
|
|
count++
|
|
return true
|
|
})
|
|
var pdist float64
|
|
var i int
|
|
center := []float64{50, 50}
|
|
centerRect := &tRect{center[0], center[1], center[0], center[1]}
|
|
tr.KNN(centerRect, true, func(item Item, dist float64) bool {
|
|
dist2 := boxDistPoint(center, item)
|
|
if i > 0 && dist2 < pdist {
|
|
t.Fatal("out of order")
|
|
}
|
|
pdist = dist
|
|
i++
|
|
return true
|
|
})
|
|
if i != n {
|
|
t.Fatal("mismatch")
|
|
}
|
|
}
|
|
|
|
func boxDistPoint(point []float64, item Item) float64 {
|
|
var dist float64
|
|
min, max := item.Rect(nil)
|
|
for i := 0; i < len(point); i++ {
|
|
d := axisDist(point[i], min[i], max[i])
|
|
dist += d * d
|
|
}
|
|
return dist
|
|
}
|
|
|
|
// axisDist returns the distance along one axis from k to the closed interval
// [min, max]. The result is 0 when k lies inside the interval, min-k when k
// is below it, and k-max when k is above it. It assumes min <= max.
func axisDist(k, min, max float64) float64 {
	switch {
	case k < min:
		return min - k
	case k <= max:
		return 0
	default:
		return k - max
	}
}
|