Switch BTree implementation

This commit is contained in:
tidwall 2018-08-16 13:07:55 -07:00
parent 2e7130cda0
commit 3ae26e3479
6 changed files with 900 additions and 54 deletions

View File

@ -4,6 +4,7 @@ import (
"math"
"github.com/tidwall/btree"
"github.com/tidwall/tile38/pkg/ds"
"github.com/tidwall/tile38/pkg/geojson"
"github.com/tidwall/tile38/pkg/index"
)
@ -49,7 +50,7 @@ func (i *itemT) Point() (x, y float64) {
// Collection represents a collection of geojson objects.
type Collection struct {
items *btree.BTree // items sorted by keys
items ds.BTree // items sorted by keys
values *btree.BTree // items sorted by value+key
index *index.Index // items geospatially indexed
fieldMap map[string]int
@ -66,8 +67,7 @@ var counter uint64
func New() *Collection {
col := &Collection{
index: index.New(),
items: btree.New(128, idOrdered),
values: btree.New(128, valueOrdered),
values: btree.New(16, valueOrdered),
fieldMap: make(map[string]int),
}
return col
@ -117,15 +117,15 @@ func (c *Collection) Bounds() (minX, minY, maxX, maxY float64) {
return c.index.Bounds()
}
// ReplaceOrInsert adds or replaces an object in the collection and returns the fields array.
// Set adds or replaces an object in the collection and returns the fields array.
// If an item with the same id is already in the collection then the new item will adopt the old item's fields.
// The fields argument is optional.
// The return values are the old object, the old fields, and the new fields
func (c *Collection) ReplaceOrInsert(id string, obj geojson.Object, fields []string, values []float64) (oldObject geojson.Object, oldFields []float64, newFields []float64) {
func (c *Collection) Set(id string, obj geojson.Object, fields []string, values []float64) (oldObject geojson.Object, oldFields []float64, newFields []float64) {
var oldItem *itemT
var newItem *itemT = &itemT{id: id, object: obj}
newItem := &itemT{id: id, object: obj}
// add the new item to main btree and remove the old one if needed
oldItemPtr := c.items.ReplaceOrInsert(newItem)
oldItemPtr, _ := c.items.Set(id, newItem)
if oldItemPtr != nil {
// the old item was removed, now let's remove from the rtree
// or strings tree.
@ -186,14 +186,14 @@ func (c *Collection) ReplaceOrInsert(id string, obj geojson.Object, fields []str
return oldObject, oldFields, newFields
}
// Remove removes an object and returns it.
// Delete removes an object and returns it.
// If the object does not exist then the 'ok' return value will be false.
func (c *Collection) Remove(id string) (obj geojson.Object, fields []float64, ok bool) {
i := c.items.Delete(&itemT{id: id})
if i == nil {
func (c *Collection) Delete(id string) (obj geojson.Object, fields []float64, ok bool) {
old, _ := c.items.Delete(id)
if old == nil {
return nil, nil, false
}
item := i.(*itemT)
item := old.(*itemT)
if item.object.IsGeometry() {
c.index.Remove(item)
c.objects--
@ -212,43 +212,43 @@ func (c *Collection) Remove(id string) (obj geojson.Object, fields []float64, ok
// Get returns an object.
// If the object does not exist then the 'ok' return value will be false.
func (c *Collection) Get(id string) (obj geojson.Object, fields []float64, ok bool) {
i := c.items.Get(&itemT{id: id})
if i == nil {
val, _ := c.items.Get(id)
if val == nil {
return nil, nil, false
}
item := i.(*itemT)
item := val.(*itemT)
return item.object, c.getFieldValues(id), true
}
// SetField set a field value for an object and returns that object.
// If the object does not exist then the 'ok' return value will be false.
func (c *Collection) SetField(id, field string, value float64) (obj geojson.Object, fields []float64, updated bool, ok bool) {
i := c.items.Get(&itemT{id: id})
if i == nil {
val, _ := c.items.Get(id)
if val == nil {
ok = false
return
}
item := i.(*itemT)
item := val.(*itemT)
updated = c.setField(item, field, value)
return item.object, c.getFieldValues(id), updated, true
}
// SetFields is similar to SetField, just setting multiple fields at once
func (c *Collection) SetFields(id string, in_fields []string, in_values []float64) (
obj geojson.Object, fields []float64, updated_count int, ok bool,
func (c *Collection) SetFields(id string, inFields []string, inValues []float64) (
obj geojson.Object, fields []float64, updatedCount int, ok bool,
) {
i := c.items.Get(&itemT{id: id})
if i == nil {
val, _ := c.items.Get(id)
if val == nil {
ok = false
return
}
item := i.(*itemT)
for idx, field := range in_fields {
if c.setField(item, field, in_values[idx]) {
updated_count++
item := val.(*itemT)
for idx, field := range inFields {
if c.setField(item, field, inValues[idx]) {
updatedCount++
}
}
return item.object, c.getFieldValues(id), updated_count, true
return item.object, c.getFieldValues(id), updatedCount, true
}
func (c *Collection) setField(item *itemT, field string, value float64) (updated bool) {
@ -288,34 +288,43 @@ func (c *Collection) Scan(desc bool,
iterator func(id string, obj geojson.Object, fields []float64) bool,
) bool {
var keepon = true
iter := func(item btree.Item) bool {
iitm := item.(*itemT)
iter := func(key string, value interface{}) bool {
iitm := value.(*itemT)
keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id))
return keepon
}
if desc {
c.items.Descend(iter)
c.items.Reverse(iter)
} else {
c.items.Ascend(iter)
c.items.Scan(iter)
}
return keepon
}
// ScanGreaterOrEqual iterates though the collection starting with specified id.
// ScanRange iterates though the collection starting with specified id.
func (c *Collection) ScanRange(start, end string, desc bool,
iterator func(id string, obj geojson.Object, fields []float64) bool,
) bool {
var keepon = true
iter := func(item btree.Item) bool {
iitm := item.(*itemT)
iter := func(key string, value interface{}) bool {
if !desc {
if key >= end {
return false
}
} else {
if key <= end {
return false
}
}
iitm := value.(*itemT)
keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id))
return keepon
}
if desc {
c.items.DescendRange(&itemT{id: start}, &itemT{id: end}, iter)
c.items.Descend(start, iter)
} else {
c.items.AscendRange(&itemT{id: start}, &itemT{id: end}, iter)
c.items.Ascend(start, iter)
}
return keepon
}
@ -361,15 +370,15 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool,
iterator func(id string, obj geojson.Object, fields []float64) bool,
) bool {
var keepon = true
iter := func(item btree.Item) bool {
iitm := item.(*itemT)
iter := func(key string, value interface{}) bool {
iitm := value.(*itemT)
keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id))
return keepon
}
if desc {
c.items.DescendLessOrEqual(&itemT{id: id}, iter)
c.items.Descend(id, iter)
} else {
c.items.AscendGreaterOrEqual(&itemT{id: id}, iter)
c.items.Ascend(id, iter)
}
return keepon
}
@ -593,7 +602,11 @@ func (c *Collection) Intersects(
})
}
func (c *Collection) NearestNeighbors(lat, lon float64, iterator func(id string, obj geojson.Object, fields []float64) bool) bool {
// NearestNeighbors returns the nearest neighbors
func (c *Collection) NearestNeighbors(
lat, lon float64,
iterator func(id string, obj geojson.Object, fields []float64) bool,
) bool {
return c.index.KNN(lon, lat, func(item interface{}) bool {
var iitm *itemT
iitm, ok := item.(*itemT)

View File

@ -28,7 +28,7 @@ func TestCollection(t *testing.T) {
}}
}
objs[id] = obj
c.ReplaceOrInsert(id, obj, nil, nil)
c.Set(id, obj, nil, nil)
}
count := 0
bbox := geojson.BBox{Min: geojson.Position{X: -180, Y: -90, Z: 0}, Max: geojson.Position{X: 180, Y: 90, Z: 0}}
@ -76,7 +76,7 @@ func TestManyCollections(t *testing.T) {
col = New()
colsM[key] = col
}
col.ReplaceOrInsert(id, obj, nil, nil)
col.Set(id, obj, nil, nil)
k++
}
}
@ -110,7 +110,7 @@ func BenchmarkInsert(t *testing.B) {
col := New()
t.ResetTimer()
for i := 0; i < t.N; i++ {
col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil)
col.Set(items[i].id, items[i].object, nil, nil)
}
}
@ -128,11 +128,11 @@ func BenchmarkReplace(t *testing.B) {
}
col := New()
for i := 0; i < t.N; i++ {
col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil)
col.Set(items[i].id, items[i].object, nil, nil)
}
t.ResetTimer()
for _, i := range rand.Perm(t.N) {
o, _, _ := col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil)
o, _, _ := col.Set(items[i].id, items[i].object, nil, nil)
if o != items[i].object {
t.Fatal("shoot!")
}
@ -153,7 +153,7 @@ func BenchmarkGet(t *testing.B) {
}
col := New()
for i := 0; i < t.N; i++ {
col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil)
col.Set(items[i].id, items[i].object, nil, nil)
}
t.ResetTimer()
for _, i := range rand.Perm(t.N) {
@ -178,11 +178,11 @@ func BenchmarkRemove(t *testing.B) {
}
col := New()
for i := 0; i < t.N; i++ {
col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil)
col.Set(items[i].id, items[i].object, nil, nil)
}
t.ResetTimer()
for _, i := range rand.Perm(t.N) {
o, _, _ := col.Remove(items[i].id)
o, _, _ := col.Delete(items[i].id)
if o != items[i].object {
t.Fatal("shoot!")
}

View File

@ -307,7 +307,7 @@ func (c *Controller) cmdDel(msg *server.Message) (res resp.Value, d commandDetai
found := false
col := c.getCol(d.key)
if col != nil {
d.obj, d.fields, ok = col.Remove(d.id)
d.obj, d.fields, ok = col.Delete(d.id)
if ok {
if col.Count() == 0 {
c.deleteCol(d.key)
@ -373,7 +373,7 @@ func (c *Controller) cmdPdel(msg *server.Message) (res resp.Value, d commandDeta
}
var atLeastOneNotDeleted bool
for i, dc := range d.children {
dc.obj, dc.fields, ok = col.Remove(dc.id)
dc.obj, dc.fields, ok = col.Delete(dc.id)
if !ok {
d.children[i].command = "?"
atLeastOneNotDeleted = true
@ -740,7 +740,7 @@ func (c *Controller) cmdSet(msg *server.Message) (res resp.Value, d commandDetai
}
}
c.clearIDExpires(d.key, d.id)
d.oldObj, d.oldFields, d.fields = col.ReplaceOrInsert(d.id, d.obj, fields, values)
d.oldObj, d.oldFields, d.fields = col.Set(d.id, d.obj, fields, values)
d.command = "set"
d.updated = true // perhaps we should do a diff on the previous object?
d.timestamp = time.Now()

View File

@ -213,7 +213,7 @@ func (c *Controller) cmdJset(msg *server.Message) (res resp.Value, d commandDeta
d.updated = true
c.clearIDExpires(key, id)
col.ReplaceOrInsert(d.id, d.obj, nil, nil)
col.Set(d.id, d.obj, nil, nil)
switch msg.OutputType {
case server.JSON:
var buf bytes.Buffer
@ -287,7 +287,7 @@ func (c *Controller) cmdJdel(msg *server.Message) (res resp.Value, d commandDeta
d.updated = true
c.clearIDExpires(d.key, d.id)
col.ReplaceOrInsert(d.id, d.obj, nil, nil)
col.Set(d.id, d.obj, nil, nil)
switch msg.OutputType {
case server.JSON:
var buf bytes.Buffer

403
pkg/ds/btree.go Normal file
View File

@ -0,0 +1,403 @@
package ds
const maxItems = 31 // use an odd number
const minItems = maxItems / 2
type item struct {
key string
value interface{}
}
type node struct {
numItems int
items [maxItems]item
children [maxItems + 1]*node
}
type leaf struct {
numItems int
items [maxItems]item
}
// BTree is an ordered set of key/value pairs where the key is a string
// and the value is an interface{}
type BTree struct {
height int
root *node
length int
}
func (n *node) find(key string) (index int, found bool) {
i, j := 0, n.numItems
for i < j {
h := i + (j-i)/2
if key >= n.items[h].key {
i = h + 1
} else {
j = h
}
}
if i > 0 && n.items[i-1].key >= key {
return i - 1, true
}
return i, false
}
// Set or replace a value for a key
func (tr *BTree) Set(key string, value interface{}) (
prev interface{}, replaced bool,
) {
if tr.root == nil {
tr.root = new(node)
tr.root.items[0] = item{key, value}
tr.root.numItems = 1
tr.length = 1
return
}
prev, replaced = tr.root.set(key, value, tr.height)
if replaced {
return
}
if tr.root.numItems == maxItems {
n := tr.root
right, median := n.split(tr.height)
tr.root = new(node)
tr.root.children[0] = n
tr.root.items[0] = median
tr.root.children[1] = right
tr.root.numItems = 1
tr.height++
}
tr.length++
return
}
func (n *node) split(height int) (right *node, median item) {
right = new(node)
median = n.items[maxItems/2]
copy(right.items[:maxItems/2], n.items[maxItems/2+1:])
if height > 0 {
copy(right.children[:maxItems/2+1], n.children[maxItems/2+1:])
}
right.numItems = maxItems / 2
if height > 0 {
for i := maxItems/2 + 1; i < maxItems+1; i++ {
n.children[i] = nil
}
}
for i := maxItems / 2; i < maxItems; i++ {
n.items[i] = item{}
}
n.numItems = maxItems / 2
return
}
func (n *node) set(key string, value interface{}, height int) (
prev interface{}, replaced bool,
) {
i, found := n.find(key)
if found {
prev = n.items[i].value
n.items[i].value = value
return prev, true
}
if height == 0 {
for j := n.numItems; j > i; j-- {
n.items[j] = n.items[j-1]
}
n.items[i] = item{key, value}
n.numItems++
return nil, false
}
prev, replaced = n.children[i].set(key, value, height-1)
if replaced {
return
}
if n.children[i].numItems == maxItems {
right, median := n.children[i].split(height - 1)
copy(n.children[i+1:], n.children[i:])
copy(n.items[i+1:], n.items[i:])
n.items[i] = median
n.children[i+1] = right
n.numItems++
}
return
}
// Scan all items in tree
func (tr *BTree) Scan(iter func(key string, value interface{}) bool) {
if tr.root != nil {
tr.root.scan(iter, tr.height)
}
}
func (n *node) scan(
iter func(key string, value interface{}) bool, height int,
) bool {
if height == 0 {
for i := 0; i < n.numItems; i++ {
if !iter(n.items[i].key, n.items[i].value) {
return false
}
}
return true
}
for i := 0; i < n.numItems; i++ {
if !n.children[i].scan(iter, height-1) {
return false
}
if !iter(n.items[i].key, n.items[i].value) {
return false
}
}
return n.children[n.numItems].scan(iter, height-1)
}
// Get a value for key
func (tr *BTree) Get(key string) (value interface{}, gotten bool) {
if tr.root == nil {
return
}
return tr.root.get(key, tr.height)
}
func (n *node) get(key string, height int) (value interface{}, gotten bool) {
i, found := n.find(key)
if found {
return n.items[i].value, true
}
if height == 0 {
return nil, false
}
return n.children[i].get(key, height-1)
}
// Len returns the number of items in the tree
func (tr *BTree) Len() int {
return tr.length
}
// Delete a value for a key
func (tr *BTree) Delete(key string) (prev interface{}, deleted bool) {
if tr.root == nil {
return
}
var prevItem item
prevItem, deleted = tr.root.delete(false, key, tr.height)
if !deleted {
return
}
prev = prevItem.value
if tr.root.numItems == 0 {
tr.root = tr.root.children[0]
tr.height--
}
tr.length--
if tr.length == 0 {
tr.root = nil
}
return
}
func (n *node) delete(max bool, key string, height int) (
prev item, deleted bool,
) {
i, found := 0, false
if max {
i, found = n.numItems-1, true
} else {
i, found = n.find(key)
}
if height == 0 {
if found {
prev = n.items[i]
// found the items at the leaf, remove it and return.
copy(n.items[i:], n.items[i+1:n.numItems])
n.items[n.numItems-1] = item{}
n.children[n.numItems] = nil
n.numItems--
return prev, true
}
return item{}, false
}
if found {
if max {
i++
prev, deleted = n.children[i].delete(true, "", height-1)
} else {
prev = n.items[i]
maxItem, _ := n.children[i].delete(true, "", height-1)
n.items[i] = maxItem
deleted = true
}
} else {
prev, deleted = n.children[i].delete(max, key, height-1)
}
if !deleted {
return
}
if n.children[i].numItems < minItems {
if i == n.numItems {
i--
}
if n.children[i].numItems+n.children[i+1].numItems+1 < maxItems {
// merge left + item + right
n.children[i].items[n.children[i].numItems] = n.items[i]
copy(n.children[i].items[n.children[i].numItems+1:],
n.children[i+1].items[:n.children[i+1].numItems])
if height > 1 {
copy(n.children[i].children[n.children[i].numItems+1:],
n.children[i+1].children[:n.children[i+1].numItems+1])
}
n.children[i].numItems += n.children[i+1].numItems + 1
copy(n.items[i:], n.items[i+1:n.numItems])
copy(n.children[i+1:], n.children[i+2:n.numItems+1])
n.items[n.numItems] = item{}
n.children[n.numItems+1] = nil
n.numItems--
} else if n.children[i].numItems > n.children[i+1].numItems {
// move left -> right
copy(n.children[i+1].items[1:],
n.children[i+1].items[:n.children[i+1].numItems])
if height > 1 {
copy(n.children[i+1].children[1:],
n.children[i+1].children[:n.children[i+1].numItems+1])
}
n.children[i+1].items[0] = n.items[i]
if height > 1 {
n.children[i+1].children[0] =
n.children[i].children[n.children[i].numItems]
}
n.children[i+1].numItems++
n.items[i] = n.children[i].items[n.children[i].numItems-1]
n.children[i].items[n.children[i].numItems-1] = item{}
if height > 1 {
n.children[i].children[n.children[i].numItems] = nil
}
n.children[i].numItems--
} else {
// move right -> left
n.children[i].items[n.children[i].numItems] = n.items[i]
if height > 1 {
n.children[i].children[n.children[i].numItems+1] =
n.children[i+1].children[0]
}
n.children[i].numItems++
n.items[i] = n.children[i+1].items[0]
copy(n.children[i+1].items[:],
n.children[i+1].items[1:n.children[i+1].numItems])
if height > 1 {
copy(n.children[i+1].children[:],
n.children[i+1].children[1:n.children[i+1].numItems+1])
}
n.children[i+1].numItems--
}
}
return
}
// Ascend the tree within the range [pivot, last]
func (tr *BTree) Ascend(
pivot string,
iter func(key string, value interface{}) bool,
) {
if tr.root != nil {
tr.root.ascend(pivot, iter, tr.height)
}
}
func (n *node) ascend(
pivot string,
iter func(key string, value interface{}) bool,
height int,
) bool {
i, found := n.find(pivot)
if !found {
if height > 0 {
if !n.children[i].ascend(pivot, iter, height-1) {
return false
}
}
}
for ; i < n.numItems; i++ {
if !iter(n.items[i].key, n.items[i].value) {
return false
}
if height > 0 {
if !n.children[i+1].scan(iter, height-1) {
return false
}
}
}
return true
}
// Reverse all items in tree
func (tr *BTree) Reverse(iter func(key string, value interface{}) bool) {
if tr.root != nil {
tr.root.reverse(iter, tr.height)
}
}
func (n *node) reverse(
iter func(key string, value interface{}) bool, height int,
) bool {
if height == 0 {
for i := n.numItems - 1; i >= 0; i-- {
if !iter(n.items[i].key, n.items[i].value) {
return false
}
}
return true
}
if !n.children[n.numItems].reverse(iter, height-1) {
return false
}
for i := n.numItems - 1; i >= 0; i-- {
if !iter(n.items[i].key, n.items[i].value) {
return false
}
if !n.children[i].reverse(iter, height-1) {
return false
}
}
return true
}
// Descend the tree within the range [pivot, first]
func (tr *BTree) Descend(
pivot string,
iter func(key string, value interface{}) bool,
) {
if tr.root != nil {
tr.root.descend(pivot, iter, tr.height)
}
}
func (n *node) descend(
pivot string,
iter func(key string, value interface{}) bool,
height int,
) bool {
i, found := n.find(pivot)
if !found {
if height > 0 {
if !n.children[i].descend(pivot, iter, height-1) {
return false
}
}
i--
}
for ; i >= 0; i-- {
if !iter(n.items[i].key, n.items[i].value) {
return false
}
if height > 0 {
if !n.children[i].reverse(iter, height-1) {
return false
}
}
}
return true
}

430
pkg/ds/btree_test.go Normal file
View File

@ -0,0 +1,430 @@
package ds
import (
"fmt"
"math/rand"
"strings"
"testing"
"time"
)
func init() {
seed := time.Now().UnixNano()
fmt.Printf("seed: %d\n", seed)
rand.Seed(seed)
}
func randKeys(N int) (keys []string) {
format := fmt.Sprintf("%%0%dd", len(fmt.Sprintf("%d", N-1)))
for _, i := range rand.Perm(N) {
keys = append(keys, fmt.Sprintf(format, i))
}
return
}
const flatLeaf = true
func (tr *BTree) print() {
tr.root.print(0, tr.height)
}
func (n *node) print(level, height int) {
if n == nil {
println("NIL")
return
}
if height == 0 && flatLeaf {
fmt.Printf("%s", strings.Repeat(" ", level))
}
for i := 0; i < n.numItems; i++ {
if height > 0 {
n.children[i].print(level+1, height-1)
}
if height > 0 || (height == 0 && !flatLeaf) {
fmt.Printf("%s%v\n", strings.Repeat(" ", level), n.items[i].key)
} else {
if i > 0 {
fmt.Printf(",")
}
fmt.Printf("%s", n.items[i].key)
}
}
if height == 0 && flatLeaf {
fmt.Printf("\n")
}
if height > 0 {
n.children[n.numItems].print(level+1, height-1)
}
}
func (tr *BTree) deepPrint() {
fmt.Printf("%#v\n", tr)
tr.root.deepPrint(0, tr.height)
}
func (n *node) deepPrint(level, height int) {
if n == nil {
fmt.Printf("%s %#v\n", strings.Repeat(" ", level), n)
return
}
fmt.Printf("%s count: %v\n", strings.Repeat(" ", level), n.numItems)
fmt.Printf("%s items: %v\n", strings.Repeat(" ", level), n.items)
if height > 0 {
fmt.Printf("%s child: %v\n", strings.Repeat(" ", level), n.children)
}
if height > 0 {
for i := 0; i < n.numItems; i++ {
n.children[i].deepPrint(level+1, height-1)
}
n.children[n.numItems].deepPrint(level+1, height-1)
}
}
func stringsEquals(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return false
}
}
return true
}
func TestDescend(t *testing.T) {
var tr BTree
var count int
tr.Descend("1", func(key string, value interface{}) bool {
count++
return true
})
if count > 0 {
t.Fatalf("expected 0, got %v", count)
}
var keys []string
for i := 0; i < 1000; i += 10 {
keys = append(keys, fmt.Sprintf("%03d", i))
tr.Set(keys[len(keys)-1], nil)
}
var exp []string
tr.Reverse(func(key string, _ interface{}) bool {
exp = append(exp, key)
return true
})
for i := 999; i >= 0; i-- {
var key string
key = fmt.Sprintf("%03d", i)
var all []string
tr.Descend(key, func(key string, value interface{}) bool {
all = append(all, key)
return true
})
for len(exp) > 0 && key < exp[0] {
exp = exp[1:]
}
var count int
tr.Descend(key, func(key string, value interface{}) bool {
if count == (i+1)%maxItems {
return false
}
count++
return true
})
if count > len(exp) {
t.Fatalf("expected 1, got %v", count)
}
if !stringsEquals(exp, all) {
fmt.Printf("exp: %v\n", exp)
fmt.Printf("all: %v\n", all)
t.Fatal("mismatch")
}
}
}
func TestAscend(t *testing.T) {
var tr BTree
var count int
tr.Ascend("1", func(key string, value interface{}) bool {
count++
return true
})
if count > 0 {
t.Fatalf("expected 0, got %v", count)
}
var keys []string
for i := 0; i < 1000; i += 10 {
keys = append(keys, fmt.Sprintf("%03d", i))
tr.Set(keys[len(keys)-1], nil)
}
exp := keys
for i := -1; i < 1000; i++ {
var key string
if i == -1 {
key = ""
} else {
key = fmt.Sprintf("%03d", i)
}
var all []string
tr.Ascend(key, func(key string, value interface{}) bool {
all = append(all, key)
return true
})
for len(exp) > 0 && key > exp[0] {
exp = exp[1:]
}
var count int
tr.Ascend(key, func(key string, value interface{}) bool {
if count == (i+1)%maxItems {
return false
}
count++
return true
})
if count > len(exp) {
t.Fatalf("expected 1, got %v", count)
}
if !stringsEquals(exp, all) {
t.Fatal("mismatch")
}
}
}
func TestBTree(t *testing.T) {
N := 10000
var tr BTree
keys := randKeys(N)
// insert all items
for _, key := range keys {
value, replaced := tr.Set(key, key)
if replaced {
t.Fatal("expected false")
}
if value != nil {
t.Fatal("expected nil")
}
}
// check length
if tr.Len() != len(keys) {
t.Fatalf("expected %v, got %v", len(keys), tr.Len())
}
// get each value
for _, key := range keys {
value, gotten := tr.Get(key)
if !gotten {
t.Fatal("expected true")
}
if value == nil || value.(string) != key {
t.Fatalf("expected '%v', got '%v'", key, value)
}
}
// scan all items
var last string
all := make(map[string]interface{})
tr.Scan(func(key string, value interface{}) bool {
if key <= last {
t.Fatal("out of order")
}
if value.(string) != key {
t.Fatalf("mismatch")
}
last = key
all[key] = value
return true
})
if len(all) != len(keys) {
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
}
// reverse all items
var prev string
all = make(map[string]interface{})
tr.Reverse(func(key string, value interface{}) bool {
if prev != "" && key >= prev {
t.Fatal("out of order")
}
if value.(string) != key {
t.Fatalf("mismatch")
}
prev = key
all[key] = value
return true
})
if len(all) != len(keys) {
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
}
// try to get an invalid item
value, gotten := tr.Get("invalid")
if gotten {
t.Fatal("expected false")
}
if value != nil {
t.Fatal("expected nil")
}
// scan and quit at various steps
for i := 0; i < 100; i++ {
var j int
tr.Scan(func(key string, value interface{}) bool {
if j == i {
return false
}
j++
return true
})
}
// reverse and quit at various steps
for i := 0; i < 100; i++ {
var j int
tr.Reverse(func(key string, value interface{}) bool {
if j == i {
return false
}
j++
return true
})
}
// delete half the items
for _, key := range keys[:len(keys)/2] {
value, deleted := tr.Delete(key)
if !deleted {
t.Fatal("expected true")
}
if value == nil || value.(string) != key {
t.Fatalf("expected '%v', got '%v'", key, value)
}
}
// check length
if tr.Len() != len(keys)/2 {
t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len())
}
// try delete half again
for _, key := range keys[:len(keys)/2] {
value, deleted := tr.Delete(key)
if deleted {
t.Fatal("expected false")
}
if value != nil {
t.Fatalf("expected nil")
}
}
// try delete half again
for _, key := range keys[:len(keys)/2] {
value, deleted := tr.Delete(key)
if deleted {
t.Fatal("expected false")
}
if value != nil {
t.Fatalf("expected nil")
}
}
// check length
if tr.Len() != len(keys)/2 {
t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len())
}
// scan items
last = ""
all = make(map[string]interface{})
tr.Scan(func(key string, value interface{}) bool {
if key <= last {
t.Fatal("out of order")
}
if value.(string) != key {
t.Fatalf("mismatch")
}
last = key
all[key] = value
return true
})
if len(all) != len(keys)/2 {
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
}
// replace second half
for _, key := range keys[len(keys)/2:] {
value, replaced := tr.Set(key, key)
if !replaced {
t.Fatal("expected true")
}
if value == nil || value.(string) != key {
t.Fatalf("expected '%v', got '%v'", key, value)
}
}
// delete next half the items
for _, key := range keys[len(keys)/2:] {
value, deleted := tr.Delete(key)
if !deleted {
t.Fatal("expected true")
}
if value == nil || value.(string) != key {
t.Fatalf("expected '%v', got '%v'", key, value)
}
}
// check length
if tr.Len() != 0 {
t.Fatalf("expected %v, got %v", 0, tr.Len())
}
// do some stuff on an empty tree
value, gotten = tr.Get(keys[0])
if gotten {
t.Fatal("expected false")
}
if value != nil {
t.Fatal("expected nil")
}
tr.Scan(func(key string, value interface{}) bool {
t.Fatal("should not be reached")
return true
})
tr.Reverse(func(key string, value interface{}) bool {
t.Fatal("should not be reached")
return true
})
var deleted bool
value, deleted = tr.Delete("invalid")
if deleted {
t.Fatal("expected false")
}
if value != nil {
t.Fatal("expected nil")
}
}
func BenchmarkTidwallSet(b *testing.B) {
var tr BTree
keys := randKeys(b.N)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tr.Set(keys[i], nil)
}
}
func BenchmarkTidwallGet(b *testing.B) {
var tr BTree
keys := randKeys(b.N)
for i := 0; i < b.N; i++ {
tr.Set(keys[i], nil)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tr.Get(keys[i])
}
}