Updated tests

This commit is contained in:
tidwall 2019-02-12 15:24:22 -07:00
parent d115b40d71
commit b5dcb18c54
5 changed files with 835 additions and 145 deletions

View File

@ -1,7 +1,6 @@
package collection package collection
import ( import (
"reflect"
"unsafe" "unsafe"
"github.com/tidwall/btree" "github.com/tidwall/btree"
@ -18,57 +17,6 @@ type Cursor interface {
Step(count uint64) Step(count uint64)
} }
type itemT struct {
obj geojson.Object
_ uint32
idLen uint32
idData unsafe.Pointer
fields []float64
}
func (item *itemT) id() string {
return *(*string)((unsafe.Pointer)(&reflect.StringHeader{
Data: uintptr(unsafe.Pointer(item.idData)),
Len: int(item.idLen),
}))
}
func newItem(id string, obj geojson.Object) *itemT {
item := new(itemT)
item.obj = obj
item.idLen = uint32(len(id))
if len(id) > 0 {
idData := make([]byte, len(id))
copy(idData, id)
item.idData = unsafe.Pointer(&idData[0])
}
return item
}
func (item *itemT) weightAndPoints() (weight, points int) {
if objIsSpatial(item.obj) {
points = item.obj.NumPoints()
weight = points * 16
} else {
weight = len(item.obj.String())
}
weight += len(item.fields)*8 + len(item.id())
return weight, points
}
func (item *itemT) Less(other btree.Item, ctx interface{}) bool {
value1 := item.obj.String()
value2 := other.(*itemT).obj.String()
if value1 < value2 {
return true
}
if value1 > value2 {
return false
}
// the values match so we'll compare IDs, which are always unique.
return item.id() < other.(*itemT).id()
}
// Collection represents a collection of geojson objects. // Collection represents a collection of geojson objects.
type Collection struct { type Collection struct {
items ptrbtree.BTree // items sorted by keys items ptrbtree.BTree // items sorted by keys
@ -180,45 +128,40 @@ func (c *Collection) delItem(item *itemT) {
func (c *Collection) Set( func (c *Collection) Set(
id string, obj geojson.Object, fields []string, values []float64, id string, obj geojson.Object, fields []string, values []float64,
) ( ) (
oldObject geojson.Object, oldFields []float64, newFields []float64, oldObj geojson.Object, oldFields []float64, newFields []float64,
) { ) {
newItem := newItem(id, obj) // create the new item
item := newItem(id, obj)
// add the new item to main btree and remove the old one if needed // add the new item to main btree and remove the old one if needed
oldItemV, ok := c.items.Set(unsafe.Pointer(newItem)) oldItemV, ok := c.items.Set(unsafe.Pointer(item))
if ok { if ok {
oldItem := (*itemT)(oldItemV) oldItem := (*itemT)(oldItemV)
oldObj = oldItem.obj
// remove old item from indexes // remove old item from indexes
c.delItem(oldItem) c.delItem(oldItem)
oldObject = oldItem.obj if len(oldItem.fields()) > 0 {
if len(oldItem.fields) > 0 {
// merge old and new fields // merge old and new fields
oldFields = oldItem.fields oldFields = oldItem.fields()
newItem.fields = make([]float64, len(oldFields)) item.directSetFields(oldFields)
copy(newItem.fields, oldFields)
} }
} }
if fields == nil && len(values) > 0 { if fields == nil && len(values) > 0 {
// directly set the field values, from copy // directly set the field values, from copy
newItem.fields = make([]float64, len(values)) item.directSetFields(values)
copy(newItem.fields, values)
} else if len(fields) > 0 { } else if len(fields) > 0 {
// add new field to new item // add new field to new item
if len(newItem.fields) == 0 { c.setFields(item, fields, values, false)
// make exact room
newItem.fields = make([]float64, 0, len(fields))
}
c.setFields(newItem, fields, values, false)
} }
// add new item to indexes // add new item to indexes
c.addItem(newItem) c.addItem(item)
// fmt.Printf("!!! %#v\n", oldObj)
return oldObject, oldFields, newItem.fields return oldObj, oldFields, item.fields()
} }
// Delete removes an object and returns it. // Delete removes an object and returns it.
@ -234,7 +177,7 @@ func (c *Collection) Delete(id string) (
c.delItem(oldItem) c.delItem(oldItem)
return oldItem.obj, oldItem.fields, true return oldItem.obj, oldItem.fields(), true
} }
// Get returns an object. // Get returns an object.
@ -248,7 +191,7 @@ func (c *Collection) Get(id string) (
} }
item := (*itemT)(itemV) item := (*itemT)(itemV)
return item.obj, item.fields, true return item.obj, item.fields(), true
} }
// SetField set a field value for an object and returns that object. // SetField set a field value for an object and returns that object.
@ -262,35 +205,7 @@ func (c *Collection) SetField(id, fieldName string, fieldValue float64) (
} }
item := (*itemT)(itemV) item := (*itemT)(itemV)
updated = c.setField(item, fieldName, fieldValue, true) updated = c.setField(item, fieldName, fieldValue, true)
return item.obj, item.fields, updated, true return item.obj, item.fields(), updated, true
}
func (c *Collection) setField(
item *itemT, fieldName string, fieldValue float64, updateWeight bool,
) (updated bool) {
idx, ok := c.fieldMap[fieldName]
if !ok {
idx = len(c.fieldMap)
c.fieldMap[fieldName] = idx
}
if idx >= len(item.fields) {
// grow the fields slice
oldLen := len(item.fields)
for idx >= len(item.fields) {
item.fields = append(item.fields, 0)
}
if updateWeight {
c.weight += (len(item.fields) - oldLen) * 8
}
item.fields[idx] = fieldValue
updated = true
} else if item.fields[idx] != fieldValue {
// existing field needs updating
item.fields[idx] = fieldValue
updated = true
}
return updated
} }
// SetFields is similar to SetField, just setting multiple fields at once // SetFields is similar to SetField, just setting multiple fields at once
@ -305,23 +220,7 @@ func (c *Collection) SetFields(
updatedCount = c.setFields(item, fieldNames, fieldValues, true) updatedCount = c.setFields(item, fieldNames, fieldValues, true)
return item.obj, item.fields, updatedCount, true return item.obj, item.fields(), updatedCount, true
}
func (c *Collection) setFields(
item *itemT, fieldNames []string, fieldValues []float64, updateWeight bool,
) (updatedCount int) {
for i, fieldName := range fieldNames {
var fieldValue float64
if i < len(fieldValues) {
fieldValue = fieldValues[i]
}
if c.setField(item, fieldName, fieldValue, updateWeight) {
updatedCount++
}
}
return updatedCount
} }
// FieldMap return a maps of the field names. // FieldMap return a maps of the field names.
@ -358,7 +257,7 @@ func (c *Collection) Scan(desc bool, cursor Cursor,
cursor.Step(1) cursor.Step(1)
} }
iitm := (*itemT)(ptr) iitm := (*itemT)(ptr)
keepon = iterator(iitm.id(), iitm.obj, iitm.fields) keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
return keepon return keepon
} }
if desc { if desc {
@ -398,7 +297,7 @@ func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor,
return false return false
} }
} }
keepon = iterator(iitm.id(), iitm.obj, iitm.fields) keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
return keepon return keepon
} }
@ -430,7 +329,7 @@ func (c *Collection) SearchValues(desc bool, cursor Cursor,
cursor.Step(1) cursor.Step(1)
} }
iitm := item.(*itemT) iitm := item.(*itemT)
keepon = iterator(iitm.id(), iitm.obj, iitm.fields) keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
return keepon return keepon
} }
if desc { if desc {
@ -462,7 +361,7 @@ func (c *Collection) SearchValuesRange(start, end string, desc bool,
cursor.Step(1) cursor.Step(1)
} }
iitm := item.(*itemT) iitm := item.(*itemT)
keepon = iterator(iitm.id(), iitm.obj, iitm.fields) keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
return keepon return keepon
} }
if desc { if desc {
@ -498,7 +397,7 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool,
cursor.Step(1) cursor.Step(1)
} }
iitm := (*itemT)(ptr) iitm := (*itemT)(ptr)
keepon = iterator(iitm.id(), iitm.obj, iitm.fields) keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
return keepon return keepon
} }
if desc { if desc {
@ -519,7 +418,7 @@ func (c *Collection) geoSearch(
[]float64{rect.Max.X, rect.Max.Y}, []float64{rect.Max.X, rect.Max.Y},
func(_, _ []float64, itemv unsafe.Pointer) bool { func(_, _ []float64, itemv unsafe.Pointer) bool {
item := (*itemT)(itemv) item := (*itemT)(itemv)
alive = iter(item.id(), item.obj, item.fields) alive = iter(item.id(), item.obj, item.fields())
return alive return alive
}, },
) )
@ -744,7 +643,7 @@ func (c *Collection) Nearby(
cursor.Step(1) cursor.Step(1)
} }
item := (*itemT)(itemv) item := (*itemT)(itemv)
alive = iter(item.id(), item.obj, item.fields) alive = iter(item.id(), item.obj, item.fields())
return alive return alive
}, },
) )

View File

@ -38,6 +38,56 @@ func bounds(c *Collection) geometry.Rect {
} }
} }
func TestStuff(t *testing.T) {
c := New()
key := "str"
str1 := String("hello")
str2 := String("jello")
{
println("A")
oldObj, _, _ := c.Set(key, str1, []string{"a", "b", "c"}, nil)
println("B")
expect(t, oldObj == nil)
}
{
println("C")
oldObj, _, _ := c.Set(key, str2, nil, nil) //[]float64{4, 5, 6})
println("D")
expect(t, oldObj == str1)
// expect(t, reflect.DeepEqual(oldFlds, nil)) //[]float64{1, 2, 3}))
// expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6}))
}
{
// fValues := []float64{7, 8, 9, 10, 11, 12}
println("E")
oldObj, _, _ := c.Set(key, str1, nil, nil)
println("F")
expect(t, oldObj == str2)
// expect(t, reflect.DeepEqual(oldFlds, []float64{4, 5, 6}))
// expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12}))
}
// var old geojson.Object
// c := New()
// old, _, _ = c.Set("hello1", String("world1"), nil, nil)
// expect(t, old == nil)
// old, _, _ = c.Set("hello2", String("world2"), nil, nil)
// expect(t, old == nil)
// old, _, _ = c.Set("hello3", String("world3"), nil, nil)
// expect(t, old == nil)
// old, _, _ = c.Set("hello4", String("world4"), nil, nil)
// expect(t, old == nil)
// old, _, _ = c.Set("hello1", String("planet1"), nil, nil)
// expect(t, old == String("world1"))
// old, _, _ = c.Set("hello2", String("planet2"), nil, nil)
// expect(t, old == String("world2"))
// old, _, _ = c.Set("hello3", String("planet3"), nil, nil)
// expect(t, old == String("world3"))
// old, _, _ = c.Set("hello4", String("planet4"), nil, nil)
// expect(t, old == String("world4"))
}
func TestCollectionNewCollection(t *testing.T) { func TestCollectionNewCollection(t *testing.T) {
const numItems = 10000 const numItems = 10000
objs := make(map[string]geojson.Object) objs := make(map[string]geojson.Object)
@ -114,24 +164,36 @@ func TestCollectionSet(t *testing.T) {
t.Run("Fields", func(t *testing.T) { t.Run("Fields", func(t *testing.T) {
c := New() c := New()
str1 := String("hello") str1 := String("hello")
fNames := []string{"a", "b", "c"} str2 := String("jello")
fValues := []float64{1, 2, 3} {
oldObj, oldFlds, newFlds := c.Set("str", str1, fNames, fValues) fNames := []string{"a", "b", "c"}
expect(t, oldObj == nil) fValues := []float64{1, 2, 3}
expect(t, len(oldFlds) == 0) println("A")
expect(t, reflect.DeepEqual(newFlds, fValues)) oldObj, oldFlds, newFlds := c.Set("str", str1, fNames, fValues)
str2 := String("hello") println("B")
fNames = []string{"d", "e", "f"} expect(t, oldObj == nil)
fValues = []float64{4, 5, 6} expect(t, len(oldFlds) == 0)
oldObj, oldFlds, newFlds = c.Set("str", str2, fNames, fValues) expect(t, reflect.DeepEqual(newFlds, fValues))
expect(t, oldObj == str1) }
expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3})) {
expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6})) fNames := []string{"d", "e", "f"}
fValues = []float64{7, 8, 9, 10, 11, 12} fValues := []float64{4, 5, 6}
oldObj, oldFlds, newFlds = c.Set("str", str1, nil, fValues) println("C")
expect(t, oldObj == str2) oldObj, oldFlds, newFlds := c.Set("str", str2, fNames, fValues)
expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3, 4, 5, 6})) println("D")
expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12})) expect(t, oldObj == str1)
expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3}))
expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6}))
}
{
fValues := []float64{7, 8, 9, 10, 11, 12}
println("E")
oldObj, oldFlds, newFlds := c.Set("str", str1, nil, fValues)
println("F")
expect(t, oldObj == str2)
expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3, 4, 5, 6}))
expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12}))
}
}) })
t.Run("Delete", func(t *testing.T) { t.Run("Delete", func(t *testing.T) {
c := New() c := New()

140
internal/collection/item.go Normal file
View File

@ -0,0 +1,140 @@
package collection
import (
"reflect"
"unsafe"
"github.com/tidwall/btree"
"github.com/tidwall/geojson"
)
type itemT struct {
obj geojson.Object
idLen uint32 // id block size in bytes
fieldsLen uint32 // fields block size in bytes, not num of fields
data unsafe.Pointer
}
func (item *itemT) id() string {
return *(*string)((unsafe.Pointer)(&reflect.StringHeader{
Data: uintptr(unsafe.Pointer(item.data)),
Len: int(item.idLen),
}))
}
func (item *itemT) fields() []float64 {
return *(*[]float64)((unsafe.Pointer)(&reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(item.data)) + uintptr(item.idLen),
Len: int(item.fieldsLen) / 8,
Cap: int(item.fieldsLen) / 8,
}))
}
func (item *itemT) dataBytes() []byte {
return *(*[]byte)((unsafe.Pointer)(&reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(item.data)),
Len: int(item.fieldsLen) + int(item.idLen),
Cap: int(item.fieldsLen) + int(item.idLen),
}))
}
func newItem(id string, obj geojson.Object) *itemT {
item := new(itemT)
item.obj = obj
item.idLen = uint32(len(id))
if len(id) > 0 {
data := make([]byte, len(id))
copy(data, id)
item.data = unsafe.Pointer(&data[0])
}
return item
}
func (item *itemT) weightAndPoints() (weight, points int) {
if objIsSpatial(item.obj) {
points = item.obj.NumPoints()
weight = points * 16
} else {
weight = len(item.obj.String())
}
weight += int(item.fieldsLen + item.idLen)
return weight, points
}
func (item *itemT) Less(other btree.Item, ctx interface{}) bool {
value1 := item.obj.String()
value2 := other.(*itemT).obj.String()
if value1 < value2 {
return true
}
if value1 > value2 {
return false
}
// the values match so we'll compare IDs, which are always unique.
return item.id() < other.(*itemT).id()
}
// directSetFields copies fields, overwriting previous fields
func (item *itemT) directSetFields(fields []float64) {
n := int(item.idLen) + len(fields)*8
item.fieldsLen = uint32(len(fields) * 8)
if n > 0 {
newData := make([]byte, int(item.idLen)+len(fields)*8)
item.data = unsafe.Pointer(&newData[0])
copy(newData, item.id())
copy(item.fields(), fields)
} else {
item.data = nil
}
}
func (c *Collection) setField(
item *itemT, fieldName string, fieldValue float64, updateWeight bool,
) (updated bool) {
idx, ok := c.fieldMap[fieldName]
if !ok {
idx = len(c.fieldMap)
c.fieldMap[fieldName] = idx
}
itemFields := item.fields()
if idx >= len(itemFields) {
// make room for new field
oldLen := len(itemFields)
// print(c.weight)
data := make([]byte, int(item.idLen)+(idx+1)*8)
copy(data, item.dataBytes())
item.fieldsLen = uint32((idx + 1) * 8)
item.data = unsafe.Pointer(&data[0])
itemFields := item.fields()
if updateWeight {
c.weight += (len(itemFields) - oldLen) * 8
}
// print(":")
// print(c.weight)
// println()
itemFields[idx] = fieldValue
updated = true
} else if itemFields[idx] != fieldValue {
// existing field needs updating
itemFields[idx] = fieldValue
updated = true
}
return updated
}
func (c *Collection) setFields(
item *itemT, fieldNames []string, fieldValues []float64, updateWeight bool,
) (updatedCount int) {
// TODO: optimize to predict the item data growth.
// TODO: do all sets here, instead of calling setFields in a loop
for i, fieldName := range fieldNames {
var fieldValue float64
if i < len(fieldValues) {
fieldValue = fieldValues[i]
}
if c.setField(item, fieldName, fieldValue, updateWeight) {
updatedCount++
}
}
return updatedCount
}

View File

@ -15,9 +15,9 @@ type btreeItem struct {
// keyedItem must match layout of ../collection/itemT, otherwise // keyedItem must match layout of ../collection/itemT, otherwise
// there's a risk for memory corruption. // there's a risk for memory corruption.
type keyedItem struct { type keyedItem struct {
_ interface{} obj interface{}
_ uint32
keyLen uint32 keyLen uint32
_ uint32
data unsafe.Pointer data unsafe.Pointer
} }

View File

@ -0,0 +1,589 @@
package ptrbtree
import (
"fmt"
"math/rand"
"sort"
"strings"
"testing"
"time"
"unsafe"
)
func makeItem(key string, obj interface{}) unsafe.Pointer {
item := new(keyedItem)
item.obj = obj
if len(key) > 0 {
data := make([]byte, len(key))
copy(data, key)
item.keyLen = uint32(len(key))
item.data = unsafe.Pointer(&data[0])
}
return unsafe.Pointer(item)
}
func itemKey(ptr unsafe.Pointer) string {
return (btreeItem{ptr}).key()
}
func itemValue(ptr unsafe.Pointer) interface{} {
return (*keyedItem)(ptr).obj
}
func init() {
seed := time.Now().UnixNano()
fmt.Printf("seed: %d\n", seed)
rand.Seed(seed)
}
func randKeys(N int) (keys []string) {
format := fmt.Sprintf("%%0%dd", len(fmt.Sprintf("%d", N-1)))
for _, i := range rand.Perm(N) {
keys = append(keys, fmt.Sprintf(format, i))
}
return
}
const flatLeaf = true
func (tr *BTree) print() {
tr.root.print(0, tr.height)
}
func (n *node) print(level, height int) {
if n == nil {
println("NIL")
return
}
if height == 0 && flatLeaf {
fmt.Printf("%s", strings.Repeat(" ", level))
}
for i := 0; i < n.numItems; i++ {
if height > 0 {
n.children[i].print(level+1, height-1)
}
if height > 0 || (height == 0 && !flatLeaf) {
fmt.Printf("%s%v\n", strings.Repeat(" ", level), n.items[i].key())
} else {
if i > 0 {
fmt.Printf(",")
}
fmt.Printf("%s", n.items[i].key())
}
}
if height == 0 && flatLeaf {
fmt.Printf("\n")
}
if height > 0 {
n.children[n.numItems].print(level+1, height-1)
}
}
func (tr *BTree) deepPrint() {
fmt.Printf("%#v\n", tr)
tr.root.deepPrint(0, tr.height)
}
func (n *node) deepPrint(level, height int) {
if n == nil {
fmt.Printf("%s %#v\n", strings.Repeat(" ", level), n)
return
}
fmt.Printf("%s count: %v\n", strings.Repeat(" ", level), n.numItems)
fmt.Printf("%s items: %v\n", strings.Repeat(" ", level), n.items)
if height > 0 {
fmt.Printf("%s child: %v\n", strings.Repeat(" ", level), n.children)
}
if height > 0 {
for i := 0; i < n.numItems; i++ {
n.children[i].deepPrint(level+1, height-1)
}
n.children[n.numItems].deepPrint(level+1, height-1)
}
}
func stringsEquals(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return false
}
}
return true
}
func TestDescend(t *testing.T) {
var tr BTree
var count int
tr.Descend("1", func(ptr unsafe.Pointer) bool {
count++
return true
})
if count > 0 {
t.Fatalf("expected 0, got %v", count)
}
var keys []string
for i := 0; i < 1000; i += 10 {
keys = append(keys, fmt.Sprintf("%03d", i))
tr.Set(makeItem(keys[len(keys)-1], nil))
}
var exp []string
tr.Reverse(func(ptr unsafe.Pointer) bool {
exp = append(exp, itemKey(ptr))
return true
})
for i := 999; i >= 0; i-- {
var key string
key = fmt.Sprintf("%03d", i)
var all []string
tr.Descend(key, func(ptr unsafe.Pointer) bool {
all = append(all, itemKey(ptr))
return true
})
for len(exp) > 0 && key < exp[0] {
exp = exp[1:]
}
var count int
tr.Descend(key, func(ptr unsafe.Pointer) bool {
if count == (i+1)%maxItems {
return false
}
count++
return true
})
if count > len(exp) {
t.Fatalf("expected 1, got %v", count)
}
if !stringsEquals(exp, all) {
fmt.Printf("exp: %v\n", exp)
fmt.Printf("all: %v\n", all)
t.Fatal("mismatch")
}
}
}
func TestAscend(t *testing.T) {
var tr BTree
var count int
tr.Ascend("1", func(ptr unsafe.Pointer) bool {
count++
return true
})
if count > 0 {
t.Fatalf("expected 0, got %v", count)
}
var keys []string
for i := 0; i < 1000; i += 10 {
keys = append(keys, fmt.Sprintf("%03d", i))
tr.Set(makeItem(keys[len(keys)-1], nil))
}
exp := keys
for i := -1; i < 1000; i++ {
var key string
if i == -1 {
key = ""
} else {
key = fmt.Sprintf("%03d", i)
}
var all []string
tr.Ascend(key, func(ptr unsafe.Pointer) bool {
all = append(all, itemKey(ptr))
return true
})
for len(exp) > 0 && key > exp[0] {
exp = exp[1:]
}
var count int
tr.Ascend(key, func(ptr unsafe.Pointer) bool {
if count == (i+1)%maxItems {
return false
}
count++
return true
})
if count > len(exp) {
t.Fatalf("expected 1, got %v", count)
}
if !stringsEquals(exp, all) {
t.Fatal("mismatch")
}
}
}
func TestBTree(t *testing.T) {
N := 10000
var tr BTree
keys := randKeys(N)
// insert all items
for _, key := range keys {
value, replaced := tr.Set(makeItem(key, key))
if replaced {
t.Fatal("expected false")
}
if value != nil {
t.Fatal("expected nil")
}
}
// check length
if tr.Len() != len(keys) {
t.Fatalf("expected %v, got %v", len(keys), tr.Len())
}
// get each value
for _, key := range keys {
value, gotten := tr.Get(key)
if !gotten {
t.Fatal("expected true")
}
if value == nil || itemValue(value) != key {
t.Fatalf("expected '%v', got '%v'", key, value)
}
}
// scan all items
var last string
all := make(map[string]interface{})
tr.Scan(func(ptr unsafe.Pointer) bool {
if itemKey(ptr) <= last {
t.Fatal("out of order")
}
if itemValue(ptr).(string) != itemKey(ptr) {
t.Fatalf("mismatch")
}
last = itemKey(ptr)
all[itemKey(ptr)] = itemValue(ptr)
return true
})
if len(all) != len(keys) {
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
}
// reverse all items
var prev string
all = make(map[string]interface{})
tr.Reverse(func(ptr unsafe.Pointer) bool {
if prev != "" && itemKey(ptr) >= prev {
t.Fatal("out of order")
}
if itemValue(ptr).(string) != itemKey(ptr) {
t.Fatalf("mismatch")
}
prev = itemKey(ptr)
all[itemKey(ptr)] = itemValue(ptr)
return true
})
if len(all) != len(keys) {
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
}
// try to get an invalid item
value, gotten := tr.Get("invalid")
if gotten {
t.Fatal("expected false")
}
if value != nil {
t.Fatal("expected nil")
}
// scan and quit at various steps
for i := 0; i < 100; i++ {
var j int
tr.Scan(func(ptr unsafe.Pointer) bool {
if j == i {
return false
}
j++
return true
})
}
// reverse and quit at various steps
for i := 0; i < 100; i++ {
var j int
tr.Reverse(func(ptr unsafe.Pointer) bool {
if j == i {
return false
}
j++
return true
})
}
// delete half the items
for _, key := range keys[:len(keys)/2] {
value, deleted := tr.Delete(key)
if !deleted {
t.Fatal("expected true")
}
if value == nil || itemValue(value).(string) != key {
t.Fatalf("expected '%v', got '%v'", key, value)
}
}
// check length
if tr.Len() != len(keys)/2 {
t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len())
}
// try delete half again
for _, key := range keys[:len(keys)/2] {
value, deleted := tr.Delete(key)
if deleted {
t.Fatal("expected false")
}
if value != nil {
t.Fatalf("expected nil")
}
}
// try delete half again
for _, key := range keys[:len(keys)/2] {
value, deleted := tr.Delete(key)
if deleted {
t.Fatal("expected false")
}
if value != nil {
t.Fatalf("expected nil")
}
}
// check length
if tr.Len() != len(keys)/2 {
t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len())
}
// scan items
last = ""
all = make(map[string]interface{})
tr.Scan(func(ptr unsafe.Pointer) bool {
if itemKey(ptr) <= last {
t.Fatal("out of order")
}
if itemValue(ptr).(string) != itemKey(ptr) {
t.Fatalf("mismatch")
}
last = itemKey(ptr)
all[itemKey(ptr)] = itemValue(ptr)
return true
})
if len(all) != len(keys)/2 {
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
}
// replace second half
for _, key := range keys[len(keys)/2:] {
value, replaced := tr.Set(makeItem(key, key))
if !replaced {
t.Fatal("expected true")
}
if value == nil || itemValue(value).(string) != key {
t.Fatalf("expected '%v', got '%v'", key, value)
}
}
// delete next half the items
for _, key := range keys[len(keys)/2:] {
value, deleted := tr.Delete(key)
if !deleted {
t.Fatal("expected true")
}
if value == nil || itemValue(value).(string) != key {
t.Fatalf("expected '%v', got '%v'", key, value)
}
}
// check length
if tr.Len() != 0 {
t.Fatalf("expected %v, got %v", 0, tr.Len())
}
// do some stuff on an empty tree
value, gotten = tr.Get(keys[0])
if gotten {
t.Fatal("expected false")
}
if value != nil {
t.Fatal("expected nil")
}
tr.Scan(func(ptr unsafe.Pointer) bool {
t.Fatal("should not be reached")
return true
})
tr.Reverse(func(ptr unsafe.Pointer) bool {
t.Fatal("should not be reached")
return true
})
var deleted bool
value, deleted = tr.Delete("invalid")
if deleted {
t.Fatal("expected false")
}
if value != nil {
t.Fatal("expected nil")
}
}
func BenchmarkTidwallSequentialSet(b *testing.B) {
var tr BTree
keys := randKeys(b.N)
sort.Strings(keys)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tr.Set(makeItem(keys[i], nil))
}
}
func BenchmarkTidwallSequentialGet(b *testing.B) {
var tr BTree
keys := randKeys(b.N)
sort.Strings(keys)
for i := 0; i < b.N; i++ {
tr.Set(makeItem(keys[i], nil))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tr.Get(keys[i])
}
}
func BenchmarkTidwallRandomSet(b *testing.B) {
var tr BTree
keys := randKeys(b.N)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tr.Set(makeItem(keys[i], nil))
}
}
func BenchmarkTidwallRandomGet(b *testing.B) {
var tr BTree
keys := randKeys(b.N)
for i := 0; i < b.N; i++ {
tr.Set(makeItem(keys[i], nil))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tr.Get(keys[i])
}
}
// type googleKind struct {
// key string
// }
// func (a *googleKind) Less(b btree.Item) bool {
// return a.key < b.(*googleKind).key
// }
// func BenchmarkGoogleSequentialSet(b *testing.B) {
// tr := btree.New(32)
// keys := randKeys(b.N)
// sort.Strings(keys)
// gkeys := make([]*googleKind, len(keys))
// for i := 0; i < b.N; i++ {
// gkeys[i] = &googleKind{keys[i]}
// }
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// tr.ReplaceOrInsert(gkeys[i])
// }
// }
// func BenchmarkGoogleSequentialGet(b *testing.B) {
// tr := btree.New(32)
// keys := randKeys(b.N)
// gkeys := make([]*googleKind, len(keys))
// for i := 0; i < b.N; i++ {
// gkeys[i] = &googleKind{keys[i]}
// }
// for i := 0; i < b.N; i++ {
// tr.ReplaceOrInsert(gkeys[i])
// }
// sort.Strings(keys)
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// tr.Get(gkeys[i])
// }
// }
// func BenchmarkGoogleRandomSet(b *testing.B) {
// tr := btree.New(32)
// keys := randKeys(b.N)
// gkeys := make([]*googleKind, len(keys))
// for i := 0; i < b.N; i++ {
// gkeys[i] = &googleKind{keys[i]}
// }
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// tr.ReplaceOrInsert(gkeys[i])
// }
// }
// func BenchmarkGoogleRandomGet(b *testing.B) {
// tr := btree.New(32)
// keys := randKeys(b.N)
// gkeys := make([]*googleKind, len(keys))
// for i := 0; i < b.N; i++ {
// gkeys[i] = &googleKind{keys[i]}
// }
// for i := 0; i < b.N; i++ {
// tr.ReplaceOrInsert(gkeys[i])
// }
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// tr.Get(gkeys[i])
// }
// }
func TestBTreeOne(t *testing.T) {
var tr BTree
tr.Set(makeItem("1", "1"))
tr.Delete("1")
tr.Set(makeItem("1", "1"))
tr.Delete("1")
tr.Set(makeItem("1", "1"))
tr.Delete("1")
}
func TestBTree256(t *testing.T) {
var tr BTree
var n int
for j := 0; j < 2; j++ {
for _, i := range rand.Perm(256) {
tr.Set(makeItem(fmt.Sprintf("%d", i), i))
n++
if tr.Len() != n {
t.Fatalf("expected 256, got %d", n)
}
}
for _, i := range rand.Perm(256) {
v, ok := tr.Get(fmt.Sprintf("%d", i))
if !ok {
t.Fatal("expected true")
}
if itemValue(v).(int) != i {
t.Fatalf("expected %d, got %d", i, itemValue(v).(int))
}
}
for _, i := range rand.Perm(256) {
tr.Delete(fmt.Sprintf("%d", i))
n--
if tr.Len() != n {
t.Fatalf("expected 256, got %d", n)
}
}
for _, i := range rand.Perm(256) {
_, ok := tr.Get(fmt.Sprintf("%d", i))
if ok {
t.Fatal("expected false")
}
}
}
}