diff --git a/pkg/collection/collection.go b/pkg/collection/collection.go index ba3c6e32..444bb80e 100644 --- a/pkg/collection/collection.go +++ b/pkg/collection/collection.go @@ -4,6 +4,7 @@ import ( "math" "github.com/tidwall/btree" + "github.com/tidwall/tile38/pkg/ds" "github.com/tidwall/tile38/pkg/geojson" "github.com/tidwall/tile38/pkg/index" ) @@ -49,7 +50,7 @@ func (i *itemT) Point() (x, y float64) { // Collection represents a collection of geojson objects. type Collection struct { - items *btree.BTree // items sorted by keys + items ds.BTree // items sorted by keys values *btree.BTree // items sorted by value+key index *index.Index // items geospatially indexed fieldMap map[string]int @@ -66,8 +67,7 @@ var counter uint64 func New() *Collection { col := &Collection{ index: index.New(), - items: btree.New(128, idOrdered), - values: btree.New(128, valueOrdered), + values: btree.New(16, valueOrdered), fieldMap: make(map[string]int), } return col @@ -117,15 +117,15 @@ func (c *Collection) Bounds() (minX, minY, maxX, maxY float64) { return c.index.Bounds() } -// ReplaceOrInsert adds or replaces an object in the collection and returns the fields array. +// Set adds or replaces an object in the collection and returns the fields array. // If an item with the same id is already in the collection then the new item will adopt the old item's fields. // The fields argument is optional. // The return values are the old object, the old fields, and the new fields -func (c *Collection) ReplaceOrInsert(id string, obj geojson.Object, fields []string, values []float64) (oldObject geojson.Object, oldFields []float64, newFields []float64) { +func (c *Collection) Set(id string, obj geojson.Object, fields []string, values []float64) (oldObject geojson.Object, oldFields []float64, newFields []float64) { var oldItem *itemT - var newItem *itemT = &itemT{id: id, object: obj} + newItem := &itemT{id: id, object: obj} // add the new item to main btree and remove the old one if needed - oldItemPtr := c.items.ReplaceOrInsert(newItem) + oldItemPtr, _ := c.items.Set(id, newItem) if oldItemPtr != nil { // the old item was removed, now let's remove from the rtree // or strings tree. @@ -186,14 +186,14 @@ func (c *Collection) ReplaceOrInsert(id string, obj geojson.Object, fields []str return oldObject, oldFields, newFields } -// Remove removes an object and returns it. +// Delete removes an object and returns it. // If the object does not exist then the 'ok' return value will be false. -func (c *Collection) Remove(id string) (obj geojson.Object, fields []float64, ok bool) { - i := c.items.Delete(&itemT{id: id}) - if i == nil { +func (c *Collection) Delete(id string) (obj geojson.Object, fields []float64, ok bool) { + old, _ := c.items.Delete(id) + if old == nil { return nil, nil, false } - item := i.(*itemT) + item := old.(*itemT) if item.object.IsGeometry() { c.index.Remove(item) c.objects-- @@ -212,43 +212,43 @@ func (c *Collection) Remove(id string) (obj geojson.Object, fields []float64, ok // Get returns an object. // If the object does not exist then the 'ok' return value will be false. func (c *Collection) Get(id string) (obj geojson.Object, fields []float64, ok bool) { - i := c.items.Get(&itemT{id: id}) - if i == nil { + val, _ := c.items.Get(id) + if val == nil { return nil, nil, false } - item := i.(*itemT) + item := val.(*itemT) return item.object, c.getFieldValues(id), true } // SetField set a field value for an object and returns that object. // If the object does not exist then the 'ok' return value will be false. func (c *Collection) SetField(id, field string, value float64) (obj geojson.Object, fields []float64, updated bool, ok bool) { - i := c.items.Get(&itemT{id: id}) - if i == nil { + val, _ := c.items.Get(id) + if val == nil { ok = false return } - item := i.(*itemT) + item := val.(*itemT) updated = c.setField(item, field, value) return item.object, c.getFieldValues(id), updated, true } // SetFields is similar to SetField, just setting multiple fields at once -func (c *Collection) SetFields(id string, in_fields []string, in_values []float64) ( - obj geojson.Object, fields []float64, updated_count int, ok bool, +func (c *Collection) SetFields(id string, inFields []string, inValues []float64) ( + obj geojson.Object, fields []float64, updatedCount int, ok bool, ) { - i := c.items.Get(&itemT{id: id}) - if i == nil { + val, _ := c.items.Get(id) + if val == nil { ok = false return } - item := i.(*itemT) - for idx, field := range in_fields { - if c.setField(item, field, in_values[idx]) { - updated_count++ + item := val.(*itemT) + for idx, field := range inFields { + if c.setField(item, field, inValues[idx]) { + updatedCount++ } } - return item.object, c.getFieldValues(id), updated_count, true + return item.object, c.getFieldValues(id), updatedCount, true } func (c *Collection) setField(item *itemT, field string, value float64) (updated bool) { @@ -288,34 +288,43 @@ func (c *Collection) Scan(desc bool, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true - iter := func(item btree.Item) bool { - iitm := item.(*itemT) + iter := func(key string, value interface{}) bool { + iitm := value.(*itemT) keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) return keepon } if desc { - c.items.Descend(iter) + c.items.Reverse(iter) } else { - c.items.Ascend(iter) + c.items.Scan(iter) } return keepon } -// ScanGreaterOrEqual iterates though the collection starting with specified id. +// ScanRange iterates though the collection starting with specified id. func (c *Collection) ScanRange(start, end string, desc bool, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true - iter := func(item btree.Item) bool { - iitm := item.(*itemT) + iter := func(key string, value interface{}) bool { + if !desc { + if key >= end { + return false + } + } else { + if key <= end { + return false + } + } + iitm := value.(*itemT) keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) return keepon } if desc { - c.items.DescendRange(&itemT{id: start}, &itemT{id: end}, iter) + c.items.Descend(start, iter) } else { - c.items.AscendRange(&itemT{id: start}, &itemT{id: end}, iter) + c.items.Ascend(start, iter) } return keepon } @@ -361,15 +370,15 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool, iterator func(id string, obj geojson.Object, fields []float64) bool, ) bool { var keepon = true - iter := func(item btree.Item) bool { - iitm := item.(*itemT) + iter := func(key string, value interface{}) bool { + iitm := value.(*itemT) keepon = iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) return keepon } if desc { - c.items.DescendLessOrEqual(&itemT{id: id}, iter) + c.items.Descend(id, iter) } else { - c.items.AscendGreaterOrEqual(&itemT{id: id}, iter) + c.items.Ascend(id, iter) } return keepon } @@ -593,7 +602,11 @@ func (c *Collection) Intersects( }) } -func (c *Collection) NearestNeighbors(lat, lon float64, iterator func(id string, obj geojson.Object, fields []float64) bool) bool { +// NearestNeighbors returns the nearest neighbors +func (c *Collection) NearestNeighbors( + lat, lon float64, + iterator func(id string, obj geojson.Object, fields []float64) bool, +) bool { return c.index.KNN(lon, lat, func(item interface{}) bool { var iitm *itemT iitm, ok := item.(*itemT) diff --git a/pkg/collection/collection_test.go b/pkg/collection/collection_test.go index dcc34bf1..1ee48077 100644 --- a/pkg/collection/collection_test.go +++ b/pkg/collection/collection_test.go @@ -28,7 +28,7 @@ func TestCollection(t *testing.T) { }} } objs[id] = obj - c.ReplaceOrInsert(id, obj, nil, nil) + c.Set(id, obj, nil, nil) } count := 0 bbox := geojson.BBox{Min: geojson.Position{X: -180, Y: -90, Z: 0}, Max: geojson.Position{X: 180, Y: 90, Z: 0}} @@ -76,7 +76,7 @@ func TestManyCollections(t *testing.T) { col = New() colsM[key] = col } - col.ReplaceOrInsert(id, obj, nil, nil) + col.Set(id, obj, nil, nil) k++ } } @@ -110,7 +110,7 @@ func BenchmarkInsert(t *testing.B) { col := New() t.ResetTimer() for i := 0; i < t.N; i++ { - col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil) + col.Set(items[i].id, items[i].object, nil, nil) } } @@ -128,11 +128,11 @@ func BenchmarkReplace(t *testing.B) { } col := New() for i := 0; i < t.N; i++ { - col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil) + col.Set(items[i].id, items[i].object, nil, nil) } t.ResetTimer() for _, i := range rand.Perm(t.N) { - o, _, _ := col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil) + o, _, _ := col.Set(items[i].id, items[i].object, nil, nil) if o != items[i].object { t.Fatal("shoot!") } @@ -153,7 +153,7 @@ func BenchmarkGet(t *testing.B) { } col := New() for i := 0; i < t.N; i++ { - col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil) + col.Set(items[i].id, items[i].object, nil, nil) } t.ResetTimer() for _, i := range rand.Perm(t.N) { @@ -178,11 +178,11 @@ func BenchmarkRemove(t *testing.B) { } col := New() for i := 0; i < t.N; i++ { - col.ReplaceOrInsert(items[i].id, items[i].object, nil, nil) + col.Set(items[i].id, items[i].object, nil, nil) } t.ResetTimer() for _, i := range rand.Perm(t.N) { - o, _, _ := col.Remove(items[i].id) + o, _, _ := col.Delete(items[i].id) if o != items[i].object { t.Fatal("shoot!") } diff --git a/pkg/controller/crud.go b/pkg/controller/crud.go index 116500cc..a77a1478 100644 --- a/pkg/controller/crud.go +++ b/pkg/controller/crud.go @@ -307,7 +307,7 @@ func (c *Controller) cmdDel(msg *server.Message) (res resp.Value, d commandDetai found := false col := c.getCol(d.key) if col != nil { - d.obj, d.fields, ok = col.Remove(d.id) + d.obj, d.fields, ok = col.Delete(d.id) if ok { if col.Count() == 0 { c.deleteCol(d.key) @@ -373,7 +373,7 @@ func (c *Controller) cmdPdel(msg *server.Message) (res resp.Value, d commandDeta } var atLeastOneNotDeleted bool for i, dc := range d.children { - dc.obj, dc.fields, ok = col.Remove(dc.id) + dc.obj, dc.fields, ok = col.Delete(dc.id) if !ok { d.children[i].command = "?" atLeastOneNotDeleted = true @@ -740,7 +740,7 @@ func (c *Controller) cmdSet(msg *server.Message) (res resp.Value, d commandDetai } } c.clearIDExpires(d.key, d.id) - d.oldObj, d.oldFields, d.fields = col.ReplaceOrInsert(d.id, d.obj, fields, values) + d.oldObj, d.oldFields, d.fields = col.Set(d.id, d.obj, fields, values) d.command = "set" d.updated = true // perhaps we should do a diff on the previous object? d.timestamp = time.Now() diff --git a/pkg/controller/json.go b/pkg/controller/json.go index d0b9b228..abf31537 100644 --- a/pkg/controller/json.go +++ b/pkg/controller/json.go @@ -213,7 +213,7 @@ func (c *Controller) cmdJset(msg *server.Message) (res resp.Value, d commandDeta d.updated = true c.clearIDExpires(key, id) - col.ReplaceOrInsert(d.id, d.obj, nil, nil) + col.Set(d.id, d.obj, nil, nil) switch msg.OutputType { case server.JSON: var buf bytes.Buffer @@ -287,7 +287,7 @@ func (c *Controller) cmdJdel(msg *server.Message) (res resp.Value, d commandDeta d.updated = true c.clearIDExpires(d.key, d.id) - col.ReplaceOrInsert(d.id, d.obj, nil, nil) + col.Set(d.id, d.obj, nil, nil) switch msg.OutputType { case server.JSON: var buf bytes.Buffer diff --git a/pkg/ds/btree.go b/pkg/ds/btree.go new file mode 100644 index 00000000..7b2d011e --- /dev/null +++ b/pkg/ds/btree.go @@ -0,0 +1,403 @@ +package ds + +const maxItems = 31 // use an odd number +const minItems = maxItems / 2 + +type item struct { + key string + value interface{} +} + +type node struct { + numItems int + items [maxItems]item + children [maxItems + 1]*node +} + +type leaf struct { + numItems int + items [maxItems]item +} + +// BTree is an ordered set of key/value pairs where the key is a string +// and the value is an interface{} +type BTree struct { + height int + root *node + length int +} + +func (n *node) find(key string) (index int, found bool) { + i, j := 0, n.numItems + for i < j { + h := i + (j-i)/2 + if key >= n.items[h].key { + i = h + 1 + } else { + j = h + } + } + if i > 0 && n.items[i-1].key >= key { + return i - 1, true + } + return i, false +} + +// Set or replace a value for a key +func (tr *BTree) Set(key string, value interface{}) ( + prev interface{}, replaced bool, +) { + if tr.root == nil { + tr.root = new(node) + tr.root.items[0] = item{key, value} + tr.root.numItems = 1 + tr.length = 1 + return + } + prev, replaced = tr.root.set(key, value, tr.height) + if replaced { + return + } + if tr.root.numItems == maxItems { + n := tr.root + right, median := n.split(tr.height) + tr.root = new(node) + tr.root.children[0] = n + tr.root.items[0] = median + tr.root.children[1] = right + tr.root.numItems = 1 + tr.height++ + } + tr.length++ + return +} + +func (n *node) split(height int) (right *node, median item) { + right = new(node) + median = n.items[maxItems/2] + copy(right.items[:maxItems/2], n.items[maxItems/2+1:]) + if height > 0 { + copy(right.children[:maxItems/2+1], n.children[maxItems/2+1:]) + } + right.numItems = maxItems / 2 + if height > 0 { + for i := maxItems/2 + 1; i < maxItems+1; i++ { + n.children[i] = nil + } + } + for i := maxItems / 2; i < maxItems; i++ { + n.items[i] = item{} + } + n.numItems = maxItems / 2 + return +} + +func (n *node) set(key string, value interface{}, height int) ( + prev interface{}, replaced bool, +) { + i, found := n.find(key) + if found { + prev = n.items[i].value + n.items[i].value = value + return prev, true + } + if height == 0 { + for j := n.numItems; j > i; j-- { + n.items[j] = n.items[j-1] + } + n.items[i] = item{key, value} + n.numItems++ + return nil, false + } + prev, replaced = n.children[i].set(key, value, height-1) + if replaced { + return + } + if n.children[i].numItems == maxItems { + right, median := n.children[i].split(height - 1) + copy(n.children[i+1:], n.children[i:]) + copy(n.items[i+1:], n.items[i:]) + n.items[i] = median + n.children[i+1] = right + n.numItems++ + } + return +} + +// Scan all items in tree +func (tr *BTree) Scan(iter func(key string, value interface{}) bool) { + if tr.root != nil { + tr.root.scan(iter, tr.height) + } +} + +func (n *node) scan( + iter func(key string, value interface{}) bool, height int, +) bool { + if height == 0 { + for i := 0; i < n.numItems; i++ { + if !iter(n.items[i].key, n.items[i].value) { + return false + } + } + return true + } + for i := 0; i < n.numItems; i++ { + if !n.children[i].scan(iter, height-1) { + return false + } + if !iter(n.items[i].key, n.items[i].value) { + return false + } + } + return n.children[n.numItems].scan(iter, height-1) +} + +// Get a value for key +func (tr *BTree) Get(key string) (value interface{}, gotten bool) { + if tr.root == nil { + return + } + return tr.root.get(key, tr.height) +} + +func (n *node) get(key string, height int) (value interface{}, gotten bool) { + i, found := n.find(key) + if found { + return n.items[i].value, true + } + if height == 0 { + return nil, false + } + return n.children[i].get(key, height-1) +} + +// Len returns the number of items in the tree +func (tr *BTree) Len() int { + return tr.length +} + +// Delete a value for a key +func (tr *BTree) Delete(key string) (prev interface{}, deleted bool) { + if tr.root == nil { + return + } + var prevItem item + prevItem, deleted = tr.root.delete(false, key, tr.height) + if !deleted { + return + } + prev = prevItem.value + if tr.root.numItems == 0 { + tr.root = tr.root.children[0] + tr.height-- + } + tr.length-- + if tr.length == 0 { + tr.root = nil + } + return +} + +func (n *node) delete(max bool, key string, height int) ( + prev item, deleted bool, +) { + i, found := 0, false + if max { + i, found = n.numItems-1, true + } else { + i, found = n.find(key) + } + if height == 0 { + if found { + prev = n.items[i] + // found the items at the leaf, remove it and return. + copy(n.items[i:], n.items[i+1:n.numItems]) + n.items[n.numItems-1] = item{} + n.children[n.numItems] = nil + n.numItems-- + return prev, true + } + return item{}, false + } + + if found { + if max { + i++ + prev, deleted = n.children[i].delete(true, "", height-1) + } else { + prev = n.items[i] + maxItem, _ := n.children[i].delete(true, "", height-1) + n.items[i] = maxItem + deleted = true + } + } else { + prev, deleted = n.children[i].delete(max, key, height-1) + } + if !deleted { + return + } + if n.children[i].numItems < minItems { + if i == n.numItems { + i-- + } + if n.children[i].numItems+n.children[i+1].numItems+1 < maxItems { + // merge left + item + right + n.children[i].items[n.children[i].numItems] = n.items[i] + copy(n.children[i].items[n.children[i].numItems+1:], + n.children[i+1].items[:n.children[i+1].numItems]) + if height > 1 { + copy(n.children[i].children[n.children[i].numItems+1:], + n.children[i+1].children[:n.children[i+1].numItems+1]) + } + n.children[i].numItems += n.children[i+1].numItems + 1 + copy(n.items[i:], n.items[i+1:n.numItems]) + copy(n.children[i+1:], n.children[i+2:n.numItems+1]) + n.items[n.numItems] = item{} + n.children[n.numItems+1] = nil + n.numItems-- + } else if n.children[i].numItems > n.children[i+1].numItems { + // move left -> right + copy(n.children[i+1].items[1:], + n.children[i+1].items[:n.children[i+1].numItems]) + if height > 1 { + copy(n.children[i+1].children[1:], + n.children[i+1].children[:n.children[i+1].numItems+1]) + } + n.children[i+1].items[0] = n.items[i] + if height > 1 { + n.children[i+1].children[0] = + n.children[i].children[n.children[i].numItems] + } + n.children[i+1].numItems++ + n.items[i] = n.children[i].items[n.children[i].numItems-1] + n.children[i].items[n.children[i].numItems-1] = item{} + if height > 1 { + n.children[i].children[n.children[i].numItems] = nil + } + n.children[i].numItems-- + } else { + // move right -> left + n.children[i].items[n.children[i].numItems] = n.items[i] + if height > 1 { + n.children[i].children[n.children[i].numItems+1] = + n.children[i+1].children[0] + } + n.children[i].numItems++ + n.items[i] = n.children[i+1].items[0] + copy(n.children[i+1].items[:], + n.children[i+1].items[1:n.children[i+1].numItems]) + if height > 1 { + copy(n.children[i+1].children[:], + n.children[i+1].children[1:n.children[i+1].numItems+1]) + } + n.children[i+1].numItems-- + } + } + return +} + +// Ascend the tree within the range [pivot, last] +func (tr *BTree) Ascend( + pivot string, + iter func(key string, value interface{}) bool, +) { + if tr.root != nil { + tr.root.ascend(pivot, iter, tr.height) + } +} + +func (n *node) ascend( + pivot string, + iter func(key string, value interface{}) bool, + height int, +) bool { + i, found := n.find(pivot) + if !found { + if height > 0 { + if !n.children[i].ascend(pivot, iter, height-1) { + return false + } + } + } + for ; i < n.numItems; i++ { + if !iter(n.items[i].key, n.items[i].value) { + return false + } + if height > 0 { + if !n.children[i+1].scan(iter, height-1) { + return false + } + } + } + return true +} + +// Reverse all items in tree +func (tr *BTree) Reverse(iter func(key string, value interface{}) bool) { + if tr.root != nil { + tr.root.reverse(iter, tr.height) + } +} + +func (n *node) reverse( + iter func(key string, value interface{}) bool, height int, +) bool { + if height == 0 { + for i := n.numItems - 1; i >= 0; i-- { + if !iter(n.items[i].key, n.items[i].value) { + return false + } + } + return true + } + if !n.children[n.numItems].reverse(iter, height-1) { + return false + } + for i := n.numItems - 1; i >= 0; i-- { + if !iter(n.items[i].key, n.items[i].value) { + return false + } + if !n.children[i].reverse(iter, height-1) { + return false + } + } + return true +} + +// Descend the tree within the range [pivot, first] +func (tr *BTree) Descend( + pivot string, + iter func(key string, value interface{}) bool, +) { + if tr.root != nil { + tr.root.descend(pivot, iter, tr.height) + } +} + +func (n *node) descend( + pivot string, + iter func(key string, value interface{}) bool, + height int, +) bool { + i, found := n.find(pivot) + if !found { + if height > 0 { + if !n.children[i].descend(pivot, iter, height-1) { + return false + } + } + i-- + } + for ; i >= 0; i-- { + if !iter(n.items[i].key, n.items[i].value) { + return false + } + if height > 0 { + if !n.children[i].reverse(iter, height-1) { + return false + } + } + } + return true +} diff --git a/pkg/ds/btree_test.go b/pkg/ds/btree_test.go new file mode 100644 index 00000000..cb6a04ee --- /dev/null +++ b/pkg/ds/btree_test.go @@ -0,0 +1,430 @@ +package ds + +import ( + "fmt" + "math/rand" + "strings" + "testing" + "time" +) + +func init() { + seed := time.Now().UnixNano() + fmt.Printf("seed: %d\n", seed) + rand.Seed(seed) +} + +func randKeys(N int) (keys []string) { + format := fmt.Sprintf("%%0%dd", len(fmt.Sprintf("%d", N-1))) + for _, i := range rand.Perm(N) { + keys = append(keys, fmt.Sprintf(format, i)) + } + return +} + +const flatLeaf = true + +func (tr *BTree) print() { + tr.root.print(0, tr.height) +} + +func (n *node) print(level, height int) { + if n == nil { + println("NIL") + return + } + if height == 0 && flatLeaf { + fmt.Printf("%s", strings.Repeat(" ", level)) + } + for i := 0; i < n.numItems; i++ { + if height > 0 { + n.children[i].print(level+1, height-1) + } + if height > 0 || (height == 0 && !flatLeaf) { + fmt.Printf("%s%v\n", strings.Repeat(" ", level), n.items[i].key) + } else { + if i > 0 { + fmt.Printf(",") + } + fmt.Printf("%s", n.items[i].key) + } + } + if height == 0 && flatLeaf { + fmt.Printf("\n") + } + if height > 0 { + n.children[n.numItems].print(level+1, height-1) + } +} + +func (tr *BTree) deepPrint() { + fmt.Printf("%#v\n", tr) + tr.root.deepPrint(0, tr.height) +} + +func (n *node) deepPrint(level, height int) { + if n == nil { + fmt.Printf("%s %#v\n", strings.Repeat(" ", level), n) + return + } + fmt.Printf("%s count: %v\n", strings.Repeat(" ", level), n.numItems) + fmt.Printf("%s items: %v\n", strings.Repeat(" ", level), n.items) + if height > 0 { + fmt.Printf("%s child: %v\n", strings.Repeat(" ", level), n.children) + } + if height > 0 { + for i := 0; i < n.numItems; i++ { + n.children[i].deepPrint(level+1, height-1) + } + n.children[n.numItems].deepPrint(level+1, height-1) + } +} + +func stringsEquals(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestDescend(t *testing.T) { + var tr BTree + var count int + tr.Descend("1", func(key string, value interface{}) bool { + count++ + return true + }) + if count > 0 { + t.Fatalf("expected 0, got %v", count) + } + var keys []string + for i := 0; i < 1000; i += 10 { + keys = append(keys, fmt.Sprintf("%03d", i)) + tr.Set(keys[len(keys)-1], nil) + } + var exp []string + tr.Reverse(func(key string, _ interface{}) bool { + exp = append(exp, key) + return true + }) + for i := 999; i >= 0; i-- { + var key string + key = fmt.Sprintf("%03d", i) + var all []string + tr.Descend(key, func(key string, value interface{}) bool { + all = append(all, key) + return true + }) + for len(exp) > 0 && key < exp[0] { + exp = exp[1:] + } + var count int + tr.Descend(key, func(key string, value interface{}) bool { + if count == (i+1)%maxItems { + return false + } + count++ + return true + }) + if count > len(exp) { + t.Fatalf("expected 1, got %v", count) + } + + if !stringsEquals(exp, all) { + fmt.Printf("exp: %v\n", exp) + fmt.Printf("all: %v\n", all) + t.Fatal("mismatch") + } + } +} + +func TestAscend(t *testing.T) { + var tr BTree + var count int + tr.Ascend("1", func(key string, value interface{}) bool { + count++ + return true + }) + if count > 0 { + t.Fatalf("expected 0, got %v", count) + } + var keys []string + for i := 0; i < 1000; i += 10 { + keys = append(keys, fmt.Sprintf("%03d", i)) + tr.Set(keys[len(keys)-1], nil) + } + exp := keys + for i := -1; i < 1000; i++ { + var key string + if i == -1 { + key = "" + } else { + key = fmt.Sprintf("%03d", i) + } + var all []string + tr.Ascend(key, func(key string, value interface{}) bool { + all = append(all, key) + return true + }) + + for len(exp) > 0 && key > exp[0] { + exp = exp[1:] + } + var count int + tr.Ascend(key, func(key string, value interface{}) bool { + if count == (i+1)%maxItems { + return false + } + count++ + return true + }) + if count > len(exp) { + t.Fatalf("expected 1, got %v", count) + } + if !stringsEquals(exp, all) { + t.Fatal("mismatch") + } + } +} + +func TestBTree(t *testing.T) { + N := 10000 + var tr BTree + keys := randKeys(N) + + // insert all items + for _, key := range keys { + value, replaced := tr.Set(key, key) + if replaced { + t.Fatal("expected false") + } + if value != nil { + t.Fatal("expected nil") + } + } + + // check length + if tr.Len() != len(keys) { + t.Fatalf("expected %v, got %v", len(keys), tr.Len()) + } + + // get each value + for _, key := range keys { + value, gotten := tr.Get(key) + if !gotten { + t.Fatal("expected true") + } + if value == nil || value.(string) != key { + t.Fatalf("expected '%v', got '%v'", key, value) + } + } + + // scan all items + var last string + all := make(map[string]interface{}) + tr.Scan(func(key string, value interface{}) bool { + if key <= last { + t.Fatal("out of order") + } + if value.(string) != key { + t.Fatalf("mismatch") + } + last = key + all[key] = value + return true + }) + if len(all) != len(keys) { + t.Fatalf("expected '%v', got '%v'", len(keys), len(all)) + } + + // reverse all items + var prev string + all = make(map[string]interface{}) + tr.Reverse(func(key string, value interface{}) bool { + if prev != "" && key >= prev { + t.Fatal("out of order") + } + if value.(string) != key { + t.Fatalf("mismatch") + } + prev = key + all[key] = value + return true + }) + if len(all) != len(keys) { + t.Fatalf("expected '%v', got '%v'", len(keys), len(all)) + } + + // try to get an invalid item + value, gotten := tr.Get("invalid") + if gotten { + t.Fatal("expected false") + } + if value != nil { + t.Fatal("expected nil") + } + + // scan and quit at various steps + for i := 0; i < 100; i++ { + var j int + tr.Scan(func(key string, value interface{}) bool { + if j == i { + return false + } + j++ + return true + }) + } + + // reverse and quit at various steps + for i := 0; i < 100; i++ { + var j int + tr.Reverse(func(key string, value interface{}) bool { + if j == i { + return false + } + j++ + return true + }) + } + + // delete half the items + for _, key := range keys[:len(keys)/2] { + value, deleted := tr.Delete(key) + if !deleted { + t.Fatal("expected true") + } + if value == nil || value.(string) != key { + t.Fatalf("expected '%v', got '%v'", key, value) + } + } + + // check length + if tr.Len() != len(keys)/2 { + t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len()) + } + + // try delete half again + for _, key := range keys[:len(keys)/2] { + value, deleted := tr.Delete(key) + if deleted { + t.Fatal("expected false") + } + if value != nil { + t.Fatalf("expected nil") + } + } + + // try delete half again + for _, key := range keys[:len(keys)/2] { + value, deleted := tr.Delete(key) + if deleted { + t.Fatal("expected false") + } + if value != nil { + t.Fatalf("expected nil") + } + } + + // check length + if tr.Len() != len(keys)/2 { + t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len()) + } + + // scan items + last = "" + all = make(map[string]interface{}) + tr.Scan(func(key string, value interface{}) bool { + if key <= last { + t.Fatal("out of order") + } + if value.(string) != key { + t.Fatalf("mismatch") + } + last = key + all[key] = value + return true + }) + if len(all) != len(keys)/2 { + t.Fatalf("expected '%v', got '%v'", len(keys), len(all)) + } + + // replace second half + for _, key := range keys[len(keys)/2:] { + value, replaced := tr.Set(key, key) + if !replaced { + t.Fatal("expected true") + } + if value == nil || value.(string) != key { + t.Fatalf("expected '%v', got '%v'", key, value) + } + } + + // delete next half the items + for _, key := range keys[len(keys)/2:] { + value, deleted := tr.Delete(key) + if !deleted { + t.Fatal("expected true") + } + if value == nil || value.(string) != key { + t.Fatalf("expected '%v', got '%v'", key, value) + } + } + + // check length + if tr.Len() != 0 { + t.Fatalf("expected %v, got %v", 0, tr.Len()) + } + + // do some stuff on an empty tree + value, gotten = tr.Get(keys[0]) + if gotten { + t.Fatal("expected false") + } + if value != nil { + t.Fatal("expected nil") + } + tr.Scan(func(key string, value interface{}) bool { + t.Fatal("should not be reached") + return true + }) + tr.Reverse(func(key string, value interface{}) bool { + t.Fatal("should not be reached") + return true + }) + + var deleted bool + value, deleted = tr.Delete("invalid") + if deleted { + t.Fatal("expected false") + } + if value != nil { + t.Fatal("expected nil") + } +} + +func BenchmarkTidwallSet(b *testing.B) { + var tr BTree + keys := randKeys(b.N) + b.ResetTimer() + for i := 0; i < b.N; i++ { + tr.Set(keys[i], nil) + } +} + +func BenchmarkTidwallGet(b *testing.B) { + var tr BTree + keys := randKeys(b.N) + for i := 0; i < b.N; i++ { + tr.Set(keys[i], nil) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + tr.Get(keys[i]) + } +}