diff --git a/internal/collection/collection.go b/internal/collection/collection.go index bfb46ee9..295fbe8a 100644 --- a/internal/collection/collection.go +++ b/internal/collection/collection.go @@ -1,7 +1,6 @@ package collection import ( - "reflect" "unsafe" "github.com/tidwall/btree" @@ -18,57 +17,6 @@ type Cursor interface { Step(count uint64) } -type itemT struct { - obj geojson.Object - _ uint32 - idLen uint32 - idData unsafe.Pointer - fields []float64 -} - -func (item *itemT) id() string { - return *(*string)((unsafe.Pointer)(&reflect.StringHeader{ - Data: uintptr(unsafe.Pointer(item.idData)), - Len: int(item.idLen), - })) -} - -func newItem(id string, obj geojson.Object) *itemT { - item := new(itemT) - item.obj = obj - item.idLen = uint32(len(id)) - if len(id) > 0 { - idData := make([]byte, len(id)) - copy(idData, id) - item.idData = unsafe.Pointer(&idData[0]) - } - return item -} - -func (item *itemT) weightAndPoints() (weight, points int) { - if objIsSpatial(item.obj) { - points = item.obj.NumPoints() - weight = points * 16 - } else { - weight = len(item.obj.String()) - } - weight += len(item.fields)*8 + len(item.id()) - return weight, points -} - -func (item *itemT) Less(other btree.Item, ctx interface{}) bool { - value1 := item.obj.String() - value2 := other.(*itemT).obj.String() - if value1 < value2 { - return true - } - if value1 > value2 { - return false - } - // the values match so we'll compare IDs, which are always unique. - return item.id() < other.(*itemT).id() -} - // Collection represents a collection of geojson objects. type Collection struct { items ptrbtree.BTree // items sorted by keys @@ -180,45 +128,40 @@ func (c *Collection) delItem(item *itemT) { func (c *Collection) Set( id string, obj geojson.Object, fields []string, values []float64, ) ( - oldObject geojson.Object, oldFields []float64, newFields []float64, + oldObj geojson.Object, oldFields []float64, newFields []float64, ) { - newItem := newItem(id, obj) + // create the new item + item := newItem(id, obj) // add the new item to main btree and remove the old one if needed - oldItemV, ok := c.items.Set(unsafe.Pointer(newItem)) + oldItemV, ok := c.items.Set(unsafe.Pointer(item)) if ok { oldItem := (*itemT)(oldItemV) + oldObj = oldItem.obj // remove old item from indexes c.delItem(oldItem) - oldObject = oldItem.obj - if len(oldItem.fields) > 0 { + if len(oldItem.fields()) > 0 { // merge old and new fields - oldFields = oldItem.fields - newItem.fields = make([]float64, len(oldFields)) - copy(newItem.fields, oldFields) + oldFields = oldItem.fields() + item.directSetFields(oldFields) } } if fields == nil && len(values) > 0 { // directly set the field values, from copy - newItem.fields = make([]float64, len(values)) - copy(newItem.fields, values) - + item.directSetFields(values) } else if len(fields) > 0 { // add new field to new item - if len(newItem.fields) == 0 { - // make exact room - newItem.fields = make([]float64, 0, len(fields)) - } - c.setFields(newItem, fields, values, false) + c.setFields(item, fields, values, false) } // add new item to indexes - c.addItem(newItem) + c.addItem(item) + // fmt.Printf("!!! %#v\n", oldObj) - return oldObject, oldFields, newItem.fields + return oldObj, oldFields, item.fields() } // Delete removes an object and returns it. @@ -234,7 +177,7 @@ func (c *Collection) Delete(id string) ( c.delItem(oldItem) - return oldItem.obj, oldItem.fields, true + return oldItem.obj, oldItem.fields(), true } // Get returns an object. @@ -248,7 +191,7 @@ func (c *Collection) Get(id string) ( } item := (*itemT)(itemV) - return item.obj, item.fields, true + return item.obj, item.fields(), true } // SetField set a field value for an object and returns that object. @@ -262,35 +205,7 @@ func (c *Collection) SetField(id, fieldName string, fieldValue float64) ( } item := (*itemT)(itemV) updated = c.setField(item, fieldName, fieldValue, true) - return item.obj, item.fields, updated, true -} - -func (c *Collection) setField( - item *itemT, fieldName string, fieldValue float64, updateWeight bool, -) (updated bool) { - idx, ok := c.fieldMap[fieldName] - if !ok { - idx = len(c.fieldMap) - c.fieldMap[fieldName] = idx - } - - if idx >= len(item.fields) { - // grow the fields slice - oldLen := len(item.fields) - for idx >= len(item.fields) { - item.fields = append(item.fields, 0) - } - if updateWeight { - c.weight += (len(item.fields) - oldLen) * 8 - } - item.fields[idx] = fieldValue - updated = true - } else if item.fields[idx] != fieldValue { - // existing field needs updating - item.fields[idx] = fieldValue - updated = true - } - return updated + return item.obj, item.fields(), updated, true } // SetFields is similar to SetField, just setting multiple fields at once @@ -305,23 +220,7 @@ func (c *Collection) SetFields( updatedCount = c.setFields(item, fieldNames, fieldValues, true) - return item.obj, item.fields, updatedCount, true -} - -func (c *Collection) setFields( - item *itemT, fieldNames []string, fieldValues []float64, updateWeight bool, -) (updatedCount int) { - - for i, fieldName := range fieldNames { - var fieldValue float64 - if i < len(fieldValues) { - fieldValue = fieldValues[i] - } - if c.setField(item, fieldName, fieldValue, updateWeight) { - updatedCount++ - } - } - return updatedCount + return item.obj, item.fields(), updatedCount, true } // FieldMap return a maps of the field names. @@ -358,7 +257,7 @@ func (c *Collection) Scan(desc bool, cursor Cursor, cursor.Step(1) } iitm := (*itemT)(ptr) - keepon = iterator(iitm.id(), iitm.obj, iitm.fields) + keepon = iterator(iitm.id(), iitm.obj, iitm.fields()) return keepon } if desc { @@ -398,7 +297,7 @@ func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor, return false } } - keepon = iterator(iitm.id(), iitm.obj, iitm.fields) + keepon = iterator(iitm.id(), iitm.obj, iitm.fields()) return keepon } @@ -430,7 +329,7 @@ func (c *Collection) SearchValues(desc bool, cursor Cursor, cursor.Step(1) } iitm := item.(*itemT) - keepon = iterator(iitm.id(), iitm.obj, iitm.fields) + keepon = iterator(iitm.id(), iitm.obj, iitm.fields()) return keepon } if desc { @@ -462,7 +361,7 @@ func (c *Collection) SearchValuesRange(start, end string, desc bool, cursor.Step(1) } iitm := item.(*itemT) - keepon = iterator(iitm.id(), iitm.obj, iitm.fields) + keepon = iterator(iitm.id(), iitm.obj, iitm.fields()) return keepon } if desc { @@ -498,7 +397,7 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool, cursor.Step(1) } iitm := (*itemT)(ptr) - keepon = iterator(iitm.id(), iitm.obj, iitm.fields) + keepon = iterator(iitm.id(), iitm.obj, iitm.fields()) return keepon } if desc { @@ -519,7 +418,7 @@ func (c *Collection) geoSearch( []float64{rect.Max.X, rect.Max.Y}, func(_, _ []float64, itemv unsafe.Pointer) bool { item := (*itemT)(itemv) - alive = iter(item.id(), item.obj, item.fields) + alive = iter(item.id(), item.obj, item.fields()) return alive }, ) @@ -744,7 +643,7 @@ func (c *Collection) Nearby( cursor.Step(1) } item := (*itemT)(itemv) - alive = iter(item.id(), item.obj, item.fields) + alive = iter(item.id(), item.obj, item.fields()) return alive }, ) diff --git a/internal/collection/collection_test.go b/internal/collection/collection_test.go index 21325676..019273cf 100644 --- a/internal/collection/collection_test.go +++ b/internal/collection/collection_test.go @@ -38,6 +38,56 @@ func bounds(c *Collection) geometry.Rect { } } +func TestStuff(t *testing.T) { + c := New() + key := "str" + str1 := String("hello") + str2 := String("jello") + { + println("A") + oldObj, _, _ := c.Set(key, str1, []string{"a", "b", "c"}, nil) + println("B") + expect(t, oldObj == nil) + } + { + println("C") + oldObj, _, _ := c.Set(key, str2, nil, nil) //[]float64{4, 5, 6}) + println("D") + expect(t, oldObj == str1) + // expect(t, reflect.DeepEqual(oldFlds, nil)) //[]float64{1, 2, 3})) + // expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6})) + } + { + // fValues := []float64{7, 8, 9, 10, 11, 12} + println("E") + oldObj, _, _ := c.Set(key, str1, nil, nil) + println("F") + expect(t, oldObj == str2) + // expect(t, reflect.DeepEqual(oldFlds, []float64{4, 5, 6})) + // expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12})) + } + + // var old geojson.Object + // c := New() + // old, _, _ = c.Set("hello1", String("world1"), nil, nil) + // expect(t, old == nil) + // old, _, _ = c.Set("hello2", String("world2"), nil, nil) + // expect(t, old == nil) + // old, _, _ = c.Set("hello3", String("world3"), nil, nil) + // expect(t, old == nil) + // old, _, _ = c.Set("hello4", String("world4"), nil, nil) + // expect(t, old == nil) + + // old, _, _ = c.Set("hello1", String("planet1"), nil, nil) + // expect(t, old == String("world1")) + // old, _, _ = c.Set("hello2", String("planet2"), nil, nil) + // expect(t, old == String("world2")) + // old, _, _ = c.Set("hello3", String("planet3"), nil, nil) + // expect(t, old == String("world3")) + // old, _, _ = c.Set("hello4", String("planet4"), nil, nil) + // expect(t, old == String("world4")) +} + func TestCollectionNewCollection(t *testing.T) { const numItems = 10000 objs := make(map[string]geojson.Object) @@ -114,24 +164,36 @@ func TestCollectionSet(t *testing.T) { t.Run("Fields", func(t *testing.T) { c := New() str1 := String("hello") - fNames := []string{"a", "b", "c"} - fValues := []float64{1, 2, 3} - oldObj, oldFlds, newFlds := c.Set("str", str1, fNames, fValues) - expect(t, oldObj == nil) - expect(t, len(oldFlds) == 0) - expect(t, reflect.DeepEqual(newFlds, fValues)) - str2 := String("hello") - fNames = []string{"d", "e", "f"} - fValues = []float64{4, 5, 6} - oldObj, oldFlds, newFlds = c.Set("str", str2, fNames, fValues) - expect(t, oldObj == str1) - expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3})) - expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6})) - fValues = []float64{7, 8, 9, 10, 11, 12} - oldObj, oldFlds, newFlds = c.Set("str", str1, nil, fValues) - expect(t, oldObj == str2) - expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3, 4, 5, 6})) - expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12})) + str2 := String("jello") + { + fNames := []string{"a", "b", "c"} + fValues := []float64{1, 2, 3} + println("A") + oldObj, oldFlds, newFlds := c.Set("str", str1, fNames, fValues) + println("B") + expect(t, oldObj == nil) + expect(t, len(oldFlds) == 0) + expect(t, reflect.DeepEqual(newFlds, fValues)) + } + { + fNames := []string{"d", "e", "f"} + fValues := []float64{4, 5, 6} + println("C") + oldObj, oldFlds, newFlds := c.Set("str", str2, fNames, fValues) + println("D") + expect(t, oldObj == str1) + expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3})) + expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6})) + } + { + fValues := []float64{7, 8, 9, 10, 11, 12} + println("E") + oldObj, oldFlds, newFlds := c.Set("str", str1, nil, fValues) + println("F") + expect(t, oldObj == str2) + expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3, 4, 5, 6})) + expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12})) + } }) t.Run("Delete", func(t *testing.T) { c := New() diff --git a/internal/collection/item.go b/internal/collection/item.go new file mode 100644 index 00000000..90d7662e --- /dev/null +++ b/internal/collection/item.go @@ -0,0 +1,140 @@ +package collection + +import ( + "reflect" + "unsafe" + + "github.com/tidwall/btree" + "github.com/tidwall/geojson" +) + +type itemT struct { + obj geojson.Object + idLen uint32 // id block size in bytes + fieldsLen uint32 // fields block size in bytes, not num of fields + data unsafe.Pointer +} + +func (item *itemT) id() string { + return *(*string)((unsafe.Pointer)(&reflect.StringHeader{ + Data: uintptr(unsafe.Pointer(item.data)), + Len: int(item.idLen), + })) +} + +func (item *itemT) fields() []float64 { + return *(*[]float64)((unsafe.Pointer)(&reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(item.data)) + uintptr(item.idLen), + Len: int(item.fieldsLen) / 8, + Cap: int(item.fieldsLen) / 8, + })) +} + +func (item *itemT) dataBytes() []byte { + return *(*[]byte)((unsafe.Pointer)(&reflect.SliceHeader{ + Data: uintptr(unsafe.Pointer(item.data)), + Len: int(item.fieldsLen) + int(item.idLen), + Cap: int(item.fieldsLen) + int(item.idLen), + })) +} + +func newItem(id string, obj geojson.Object) *itemT { + item := new(itemT) + item.obj = obj + item.idLen = uint32(len(id)) + if len(id) > 0 { + data := make([]byte, len(id)) + copy(data, id) + item.data = unsafe.Pointer(&data[0]) + } + return item +} + +func (item *itemT) weightAndPoints() (weight, points int) { + if objIsSpatial(item.obj) { + points = item.obj.NumPoints() + weight = points * 16 + } else { + weight = len(item.obj.String()) + } + weight += int(item.fieldsLen + item.idLen) + return weight, points +} + +func (item *itemT) Less(other btree.Item, ctx interface{}) bool { + value1 := item.obj.String() + value2 := other.(*itemT).obj.String() + if value1 < value2 { + return true + } + if value1 > value2 { + return false + } + // the values match so we'll compare IDs, which are always unique. + return item.id() < other.(*itemT).id() +} + +// directSetFields copies fields, overwriting previous fields +func (item *itemT) directSetFields(fields []float64) { + n := int(item.idLen) + len(fields)*8 + item.fieldsLen = uint32(len(fields) * 8) + if n > 0 { + newData := make([]byte, int(item.idLen)+len(fields)*8) + item.data = unsafe.Pointer(&newData[0]) + copy(newData, item.id()) + copy(item.fields(), fields) + } else { + item.data = nil + } +} + +func (c *Collection) setField( + item *itemT, fieldName string, fieldValue float64, updateWeight bool, +) (updated bool) { + idx, ok := c.fieldMap[fieldName] + if !ok { + idx = len(c.fieldMap) + c.fieldMap[fieldName] = idx + } + itemFields := item.fields() + if idx >= len(itemFields) { + // make room for new field + + oldLen := len(itemFields) + // print(c.weight) + data := make([]byte, int(item.idLen)+(idx+1)*8) + copy(data, item.dataBytes()) + item.fieldsLen = uint32((idx + 1) * 8) + item.data = unsafe.Pointer(&data[0]) + itemFields := item.fields() + if updateWeight { + c.weight += (len(itemFields) - oldLen) * 8 + } + // print(":") + // print(c.weight) + // println() + itemFields[idx] = fieldValue + updated = true + } else if itemFields[idx] != fieldValue { + // existing field needs updating + itemFields[idx] = fieldValue + updated = true + } + return updated +} +func (c *Collection) setFields( + item *itemT, fieldNames []string, fieldValues []float64, updateWeight bool, +) (updatedCount int) { + // TODO: optimize to predict the item data growth. + // TODO: do all sets here, instead of calling setFields in a loop + for i, fieldName := range fieldNames { + var fieldValue float64 + if i < len(fieldValues) { + fieldValue = fieldValues[i] + } + if c.setField(item, fieldName, fieldValue, updateWeight) { + updatedCount++ + } + } + return updatedCount +} diff --git a/internal/collection/ptrbtree/btree.go b/internal/collection/ptrbtree/btree.go index 9c09f9b3..65b04641 100644 --- a/internal/collection/ptrbtree/btree.go +++ b/internal/collection/ptrbtree/btree.go @@ -15,9 +15,9 @@ type btreeItem struct { // keyedItem must match layout of ../collection/itemT, otherwise // there's a risk for memory corruption. type keyedItem struct { - _ interface{} - _ uint32 + obj interface{} keyLen uint32 + _ uint32 data unsafe.Pointer } diff --git a/internal/collection/ptrbtree/btree_test.go b/internal/collection/ptrbtree/btree_test.go new file mode 100644 index 00000000..a12aa463 --- /dev/null +++ b/internal/collection/ptrbtree/btree_test.go @@ -0,0 +1,589 @@ +package ptrbtree + +import ( + "fmt" + "math/rand" + "sort" + "strings" + "testing" + "time" + "unsafe" +) + +func makeItem(key string, obj interface{}) unsafe.Pointer { + item := new(keyedItem) + item.obj = obj + if len(key) > 0 { + data := make([]byte, len(key)) + copy(data, key) + item.keyLen = uint32(len(key)) + item.data = unsafe.Pointer(&data[0]) + } + return unsafe.Pointer(item) +} + +func itemKey(ptr unsafe.Pointer) string { + return (btreeItem{ptr}).key() +} + +func itemValue(ptr unsafe.Pointer) interface{} { + return (*keyedItem)(ptr).obj +} + +func init() { + seed := time.Now().UnixNano() + fmt.Printf("seed: %d\n", seed) + rand.Seed(seed) +} + +func randKeys(N int) (keys []string) { + format := fmt.Sprintf("%%0%dd", len(fmt.Sprintf("%d", N-1))) + for _, i := range rand.Perm(N) { + keys = append(keys, fmt.Sprintf(format, i)) + } + return +} + +const flatLeaf = true + +func (tr *BTree) print() { + tr.root.print(0, tr.height) +} + +func (n *node) print(level, height int) { + if n == nil { + println("NIL") + return + } + if height == 0 && flatLeaf { + fmt.Printf("%s", strings.Repeat(" ", level)) + } + for i := 0; i < n.numItems; i++ { + if height > 0 { + n.children[i].print(level+1, height-1) + } + if height > 0 || (height == 0 && !flatLeaf) { + fmt.Printf("%s%v\n", strings.Repeat(" ", level), n.items[i].key()) + } else { + if i > 0 { + fmt.Printf(",") + } + fmt.Printf("%s", n.items[i].key()) + } + } + if height == 0 && flatLeaf { + fmt.Printf("\n") + } + if height > 0 { + n.children[n.numItems].print(level+1, height-1) + } +} + +func (tr *BTree) deepPrint() { + fmt.Printf("%#v\n", tr) + tr.root.deepPrint(0, tr.height) +} + +func (n *node) deepPrint(level, height int) { + if n == nil { + fmt.Printf("%s %#v\n", strings.Repeat(" ", level), n) + return + } + fmt.Printf("%s count: %v\n", strings.Repeat(" ", level), n.numItems) + fmt.Printf("%s items: %v\n", strings.Repeat(" ", level), n.items) + if height > 0 { + fmt.Printf("%s child: %v\n", strings.Repeat(" ", level), n.children) + } + if height > 0 { + for i := 0; i < n.numItems; i++ { + n.children[i].deepPrint(level+1, height-1) + } + n.children[n.numItems].deepPrint(level+1, height-1) + } +} + +func stringsEquals(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestDescend(t *testing.T) { + var tr BTree + var count int + tr.Descend("1", func(ptr unsafe.Pointer) bool { + count++ + return true + }) + if count > 0 { + t.Fatalf("expected 0, got %v", count) + } + var keys []string + for i := 0; i < 1000; i += 10 { + keys = append(keys, fmt.Sprintf("%03d", i)) + tr.Set(makeItem(keys[len(keys)-1], nil)) + } + var exp []string + tr.Reverse(func(ptr unsafe.Pointer) bool { + exp = append(exp, itemKey(ptr)) + return true + }) + for i := 999; i >= 0; i-- { + var key string + key = fmt.Sprintf("%03d", i) + var all []string + tr.Descend(key, func(ptr unsafe.Pointer) bool { + all = append(all, itemKey(ptr)) + return true + }) + for len(exp) > 0 && key < exp[0] { + exp = exp[1:] + } + var count int + tr.Descend(key, func(ptr unsafe.Pointer) bool { + if count == (i+1)%maxItems { + return false + } + count++ + return true + }) + if count > len(exp) { + t.Fatalf("expected 1, got %v", count) + } + + if !stringsEquals(exp, all) { + fmt.Printf("exp: %v\n", exp) + fmt.Printf("all: %v\n", all) + t.Fatal("mismatch") + } + } +} + +func TestAscend(t *testing.T) { + var tr BTree + var count int + tr.Ascend("1", func(ptr unsafe.Pointer) bool { + count++ + return true + }) + if count > 0 { + t.Fatalf("expected 0, got %v", count) + } + var keys []string + for i := 0; i < 1000; i += 10 { + keys = append(keys, fmt.Sprintf("%03d", i)) + tr.Set(makeItem(keys[len(keys)-1], nil)) + } + exp := keys + for i := -1; i < 1000; i++ { + var key string + if i == -1 { + key = "" + } else { + key = fmt.Sprintf("%03d", i) + } + var all []string + tr.Ascend(key, func(ptr unsafe.Pointer) bool { + all = append(all, itemKey(ptr)) + return true + }) + + for len(exp) > 0 && key > exp[0] { + exp = exp[1:] + } + var count int + tr.Ascend(key, func(ptr unsafe.Pointer) bool { + if count == (i+1)%maxItems { + return false + } + count++ + return true + }) + if count > len(exp) { + t.Fatalf("expected 1, got %v", count) + } + if !stringsEquals(exp, all) { + t.Fatal("mismatch") + } + } +} + +func TestBTree(t *testing.T) { + N := 10000 + var tr BTree + keys := randKeys(N) + + // insert all items + for _, key := range keys { + value, replaced := tr.Set(makeItem(key, key)) + if replaced { + t.Fatal("expected false") + } + if value != nil { + t.Fatal("expected nil") + } + } + + // check length + if tr.Len() != len(keys) { + t.Fatalf("expected %v, got %v", len(keys), tr.Len()) + } + + // get each value + for _, key := range keys { + value, gotten := tr.Get(key) + if !gotten { + t.Fatal("expected true") + } + if value == nil || itemValue(value) != key { + t.Fatalf("expected '%v', got '%v'", key, value) + } + } + + // scan all items + var last string + all := make(map[string]interface{}) + tr.Scan(func(ptr unsafe.Pointer) bool { + if itemKey(ptr) <= last { + t.Fatal("out of order") + } + if itemValue(ptr).(string) != itemKey(ptr) { + t.Fatalf("mismatch") + } + last = itemKey(ptr) + all[itemKey(ptr)] = itemValue(ptr) + return true + }) + if len(all) != len(keys) { + t.Fatalf("expected '%v', got '%v'", len(keys), len(all)) + } + + // reverse all items + var prev string + all = make(map[string]interface{}) + tr.Reverse(func(ptr unsafe.Pointer) bool { + if prev != "" && itemKey(ptr) >= prev { + t.Fatal("out of order") + } + if itemValue(ptr).(string) != itemKey(ptr) { + t.Fatalf("mismatch") + } + prev = itemKey(ptr) + all[itemKey(ptr)] = itemValue(ptr) + return true + }) + if len(all) != len(keys) { + t.Fatalf("expected '%v', got '%v'", len(keys), len(all)) + } + + // try to get an invalid item + value, gotten := tr.Get("invalid") + if gotten { + t.Fatal("expected false") + } + if value != nil { + t.Fatal("expected nil") + } + + // scan and quit at various steps + for i := 0; i < 100; i++ { + var j int + tr.Scan(func(ptr unsafe.Pointer) bool { + if j == i { + return false + } + j++ + return true + }) + } + + // reverse and quit at various steps + for i := 0; i < 100; i++ { + var j int + tr.Reverse(func(ptr unsafe.Pointer) bool { + if j == i { + return false + } + j++ + return true + }) + } + + // delete half the items + for _, key := range keys[:len(keys)/2] { + value, deleted := tr.Delete(key) + if !deleted { + t.Fatal("expected true") + } + if value == nil || itemValue(value).(string) != key { + t.Fatalf("expected '%v', got '%v'", key, value) + } + } + + // check length + if tr.Len() != len(keys)/2 { + t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len()) + } + + // try delete half again + for _, key := range keys[:len(keys)/2] { + value, deleted := tr.Delete(key) + if deleted { + t.Fatal("expected false") + } + if value != nil { + t.Fatalf("expected nil") + } + } + + // try delete half again + for _, key := range keys[:len(keys)/2] { + value, deleted := tr.Delete(key) + if deleted { + t.Fatal("expected false") + } + if value != nil { + t.Fatalf("expected nil") + } + } + + // check length + if tr.Len() != len(keys)/2 { + t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len()) + } + + // scan items + last = "" + all = make(map[string]interface{}) + tr.Scan(func(ptr unsafe.Pointer) bool { + if itemKey(ptr) <= last { + t.Fatal("out of order") + } + if itemValue(ptr).(string) != itemKey(ptr) { + t.Fatalf("mismatch") + } + last = itemKey(ptr) + all[itemKey(ptr)] = itemValue(ptr) + return true + }) + if len(all) != len(keys)/2 { + t.Fatalf("expected '%v', got '%v'", len(keys), len(all)) + } + + // replace second half + for _, key := range keys[len(keys)/2:] { + value, replaced := tr.Set(makeItem(key, key)) + if !replaced { + t.Fatal("expected true") + } + if value == nil || itemValue(value).(string) != key { + t.Fatalf("expected '%v', got '%v'", key, value) + } + } + + // delete next half the items + for _, key := range keys[len(keys)/2:] { + value, deleted := tr.Delete(key) + if !deleted { + t.Fatal("expected true") + } + if value == nil || itemValue(value).(string) != key { + t.Fatalf("expected '%v', got '%v'", key, value) + } + } + + // check length + if tr.Len() != 0 { + t.Fatalf("expected %v, got %v", 0, tr.Len()) + } + + // do some stuff on an empty tree + value, gotten = tr.Get(keys[0]) + if gotten { + t.Fatal("expected false") + } + if value != nil { + t.Fatal("expected nil") + } + tr.Scan(func(ptr unsafe.Pointer) bool { + t.Fatal("should not be reached") + return true + }) + tr.Reverse(func(ptr unsafe.Pointer) bool { + t.Fatal("should not be reached") + return true + }) + + var deleted bool + value, deleted = tr.Delete("invalid") + if deleted { + t.Fatal("expected false") + } + if value != nil { + t.Fatal("expected nil") + } +} + +func BenchmarkTidwallSequentialSet(b *testing.B) { + var tr BTree + keys := randKeys(b.N) + sort.Strings(keys) + b.ResetTimer() + for i := 0; i < b.N; i++ { + tr.Set(makeItem(keys[i], nil)) + } +} + +func BenchmarkTidwallSequentialGet(b *testing.B) { + var tr BTree + keys := randKeys(b.N) + sort.Strings(keys) + for i := 0; i < b.N; i++ { + tr.Set(makeItem(keys[i], nil)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + tr.Get(keys[i]) + } +} + +func BenchmarkTidwallRandomSet(b *testing.B) { + var tr BTree + keys := randKeys(b.N) + b.ResetTimer() + for i := 0; i < b.N; i++ { + tr.Set(makeItem(keys[i], nil)) + } +} + +func BenchmarkTidwallRandomGet(b *testing.B) { + var tr BTree + keys := randKeys(b.N) + for i := 0; i < b.N; i++ { + tr.Set(makeItem(keys[i], nil)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + tr.Get(keys[i]) + } +} + +// type googleKind struct { +// key string +// } + +// func (a *googleKind) Less(b btree.Item) bool { +// return a.key < b.(*googleKind).key +// } + +// func BenchmarkGoogleSequentialSet(b *testing.B) { +// tr := btree.New(32) +// keys := randKeys(b.N) +// sort.Strings(keys) +// gkeys := make([]*googleKind, len(keys)) +// for i := 0; i < b.N; i++ { +// gkeys[i] = &googleKind{keys[i]} +// } +// b.ResetTimer() +// for i := 0; i < b.N; i++ { +// tr.ReplaceOrInsert(gkeys[i]) +// } +// } + +// func BenchmarkGoogleSequentialGet(b *testing.B) { +// tr := btree.New(32) +// keys := randKeys(b.N) +// gkeys := make([]*googleKind, len(keys)) +// for i := 0; i < b.N; i++ { +// gkeys[i] = &googleKind{keys[i]} +// } +// for i := 0; i < b.N; i++ { +// tr.ReplaceOrInsert(gkeys[i]) +// } +// sort.Strings(keys) +// b.ResetTimer() +// for i := 0; i < b.N; i++ { +// tr.Get(gkeys[i]) +// } +// } + +// func BenchmarkGoogleRandomSet(b *testing.B) { +// tr := btree.New(32) +// keys := randKeys(b.N) +// gkeys := make([]*googleKind, len(keys)) +// for i := 0; i < b.N; i++ { +// gkeys[i] = &googleKind{keys[i]} +// } +// b.ResetTimer() +// for i := 0; i < b.N; i++ { +// tr.ReplaceOrInsert(gkeys[i]) +// } +// } + +// func BenchmarkGoogleRandomGet(b *testing.B) { +// tr := btree.New(32) +// keys := randKeys(b.N) +// gkeys := make([]*googleKind, len(keys)) +// for i := 0; i < b.N; i++ { +// gkeys[i] = &googleKind{keys[i]} +// } +// for i := 0; i < b.N; i++ { +// tr.ReplaceOrInsert(gkeys[i]) +// } +// b.ResetTimer() +// for i := 0; i < b.N; i++ { +// tr.Get(gkeys[i]) +// } +// } + +func TestBTreeOne(t *testing.T) { + var tr BTree + tr.Set(makeItem("1", "1")) + tr.Delete("1") + tr.Set(makeItem("1", "1")) + tr.Delete("1") + tr.Set(makeItem("1", "1")) + tr.Delete("1") +} + +func TestBTree256(t *testing.T) { + var tr BTree + var n int + for j := 0; j < 2; j++ { + for _, i := range rand.Perm(256) { + tr.Set(makeItem(fmt.Sprintf("%d", i), i)) + n++ + if tr.Len() != n { + t.Fatalf("expected 256, got %d", n) + } + } + for _, i := range rand.Perm(256) { + v, ok := tr.Get(fmt.Sprintf("%d", i)) + if !ok { + t.Fatal("expected true") + } + if itemValue(v).(int) != i { + t.Fatalf("expected %d, got %d", i, itemValue(v).(int)) + } + } + for _, i := range rand.Perm(256) { + tr.Delete(fmt.Sprintf("%d", i)) + n-- + if tr.Len() != n { + t.Fatalf("expected 256, got %d", n) + } + } + for _, i := range rand.Perm(256) { + _, ok := tr.Get(fmt.Sprintf("%d", i)) + if ok { + t.Fatal("expected false") + } + } + } +}