mirror of https://github.com/tidwall/tile38.git
Updated tests
This commit is contained in:
parent
d115b40d71
commit
b5dcb18c54
|
@ -1,7 +1,6 @@
|
|||
package collection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tidwall/btree"
|
||||
|
@ -18,57 +17,6 @@ type Cursor interface {
|
|||
Step(count uint64)
|
||||
}
|
||||
|
||||
type itemT struct {
|
||||
obj geojson.Object
|
||||
_ uint32
|
||||
idLen uint32
|
||||
idData unsafe.Pointer
|
||||
fields []float64
|
||||
}
|
||||
|
||||
func (item *itemT) id() string {
|
||||
return *(*string)((unsafe.Pointer)(&reflect.StringHeader{
|
||||
Data: uintptr(unsafe.Pointer(item.idData)),
|
||||
Len: int(item.idLen),
|
||||
}))
|
||||
}
|
||||
|
||||
func newItem(id string, obj geojson.Object) *itemT {
|
||||
item := new(itemT)
|
||||
item.obj = obj
|
||||
item.idLen = uint32(len(id))
|
||||
if len(id) > 0 {
|
||||
idData := make([]byte, len(id))
|
||||
copy(idData, id)
|
||||
item.idData = unsafe.Pointer(&idData[0])
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func (item *itemT) weightAndPoints() (weight, points int) {
|
||||
if objIsSpatial(item.obj) {
|
||||
points = item.obj.NumPoints()
|
||||
weight = points * 16
|
||||
} else {
|
||||
weight = len(item.obj.String())
|
||||
}
|
||||
weight += len(item.fields)*8 + len(item.id())
|
||||
return weight, points
|
||||
}
|
||||
|
||||
func (item *itemT) Less(other btree.Item, ctx interface{}) bool {
|
||||
value1 := item.obj.String()
|
||||
value2 := other.(*itemT).obj.String()
|
||||
if value1 < value2 {
|
||||
return true
|
||||
}
|
||||
if value1 > value2 {
|
||||
return false
|
||||
}
|
||||
// the values match so we'll compare IDs, which are always unique.
|
||||
return item.id() < other.(*itemT).id()
|
||||
}
|
||||
|
||||
// Collection represents a collection of geojson objects.
|
||||
type Collection struct {
|
||||
items ptrbtree.BTree // items sorted by keys
|
||||
|
@ -180,45 +128,40 @@ func (c *Collection) delItem(item *itemT) {
|
|||
func (c *Collection) Set(
|
||||
id string, obj geojson.Object, fields []string, values []float64,
|
||||
) (
|
||||
oldObject geojson.Object, oldFields []float64, newFields []float64,
|
||||
oldObj geojson.Object, oldFields []float64, newFields []float64,
|
||||
) {
|
||||
newItem := newItem(id, obj)
|
||||
// create the new item
|
||||
item := newItem(id, obj)
|
||||
|
||||
// add the new item to main btree and remove the old one if needed
|
||||
oldItemV, ok := c.items.Set(unsafe.Pointer(newItem))
|
||||
oldItemV, ok := c.items.Set(unsafe.Pointer(item))
|
||||
if ok {
|
||||
oldItem := (*itemT)(oldItemV)
|
||||
oldObj = oldItem.obj
|
||||
|
||||
// remove old item from indexes
|
||||
c.delItem(oldItem)
|
||||
|
||||
oldObject = oldItem.obj
|
||||
if len(oldItem.fields) > 0 {
|
||||
if len(oldItem.fields()) > 0 {
|
||||
// merge old and new fields
|
||||
oldFields = oldItem.fields
|
||||
newItem.fields = make([]float64, len(oldFields))
|
||||
copy(newItem.fields, oldFields)
|
||||
oldFields = oldItem.fields()
|
||||
item.directSetFields(oldFields)
|
||||
}
|
||||
}
|
||||
|
||||
if fields == nil && len(values) > 0 {
|
||||
// directly set the field values, from copy
|
||||
newItem.fields = make([]float64, len(values))
|
||||
copy(newItem.fields, values)
|
||||
|
||||
item.directSetFields(values)
|
||||
} else if len(fields) > 0 {
|
||||
// add new field to new item
|
||||
if len(newItem.fields) == 0 {
|
||||
// make exact room
|
||||
newItem.fields = make([]float64, 0, len(fields))
|
||||
}
|
||||
c.setFields(newItem, fields, values, false)
|
||||
c.setFields(item, fields, values, false)
|
||||
}
|
||||
|
||||
// add new item to indexes
|
||||
c.addItem(newItem)
|
||||
c.addItem(item)
|
||||
// fmt.Printf("!!! %#v\n", oldObj)
|
||||
|
||||
return oldObject, oldFields, newItem.fields
|
||||
return oldObj, oldFields, item.fields()
|
||||
}
|
||||
|
||||
// Delete removes an object and returns it.
|
||||
|
@ -234,7 +177,7 @@ func (c *Collection) Delete(id string) (
|
|||
|
||||
c.delItem(oldItem)
|
||||
|
||||
return oldItem.obj, oldItem.fields, true
|
||||
return oldItem.obj, oldItem.fields(), true
|
||||
}
|
||||
|
||||
// Get returns an object.
|
||||
|
@ -248,7 +191,7 @@ func (c *Collection) Get(id string) (
|
|||
}
|
||||
item := (*itemT)(itemV)
|
||||
|
||||
return item.obj, item.fields, true
|
||||
return item.obj, item.fields(), true
|
||||
}
|
||||
|
||||
// SetField set a field value for an object and returns that object.
|
||||
|
@ -262,35 +205,7 @@ func (c *Collection) SetField(id, fieldName string, fieldValue float64) (
|
|||
}
|
||||
item := (*itemT)(itemV)
|
||||
updated = c.setField(item, fieldName, fieldValue, true)
|
||||
return item.obj, item.fields, updated, true
|
||||
}
|
||||
|
||||
func (c *Collection) setField(
|
||||
item *itemT, fieldName string, fieldValue float64, updateWeight bool,
|
||||
) (updated bool) {
|
||||
idx, ok := c.fieldMap[fieldName]
|
||||
if !ok {
|
||||
idx = len(c.fieldMap)
|
||||
c.fieldMap[fieldName] = idx
|
||||
}
|
||||
|
||||
if idx >= len(item.fields) {
|
||||
// grow the fields slice
|
||||
oldLen := len(item.fields)
|
||||
for idx >= len(item.fields) {
|
||||
item.fields = append(item.fields, 0)
|
||||
}
|
||||
if updateWeight {
|
||||
c.weight += (len(item.fields) - oldLen) * 8
|
||||
}
|
||||
item.fields[idx] = fieldValue
|
||||
updated = true
|
||||
} else if item.fields[idx] != fieldValue {
|
||||
// existing field needs updating
|
||||
item.fields[idx] = fieldValue
|
||||
updated = true
|
||||
}
|
||||
return updated
|
||||
return item.obj, item.fields(), updated, true
|
||||
}
|
||||
|
||||
// SetFields is similar to SetField, just setting multiple fields at once
|
||||
|
@ -305,23 +220,7 @@ func (c *Collection) SetFields(
|
|||
|
||||
updatedCount = c.setFields(item, fieldNames, fieldValues, true)
|
||||
|
||||
return item.obj, item.fields, updatedCount, true
|
||||
}
|
||||
|
||||
func (c *Collection) setFields(
|
||||
item *itemT, fieldNames []string, fieldValues []float64, updateWeight bool,
|
||||
) (updatedCount int) {
|
||||
|
||||
for i, fieldName := range fieldNames {
|
||||
var fieldValue float64
|
||||
if i < len(fieldValues) {
|
||||
fieldValue = fieldValues[i]
|
||||
}
|
||||
if c.setField(item, fieldName, fieldValue, updateWeight) {
|
||||
updatedCount++
|
||||
}
|
||||
}
|
||||
return updatedCount
|
||||
return item.obj, item.fields(), updatedCount, true
|
||||
}
|
||||
|
||||
// FieldMap return a maps of the field names.
|
||||
|
@ -358,7 +257,7 @@ func (c *Collection) Scan(desc bool, cursor Cursor,
|
|||
cursor.Step(1)
|
||||
}
|
||||
iitm := (*itemT)(ptr)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
|
||||
return keepon
|
||||
}
|
||||
if desc {
|
||||
|
@ -398,7 +297,7 @@ func (c *Collection) ScanRange(start, end string, desc bool, cursor Cursor,
|
|||
return false
|
||||
}
|
||||
}
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
|
||||
return keepon
|
||||
}
|
||||
|
||||
|
@ -430,7 +329,7 @@ func (c *Collection) SearchValues(desc bool, cursor Cursor,
|
|||
cursor.Step(1)
|
||||
}
|
||||
iitm := item.(*itemT)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
|
||||
return keepon
|
||||
}
|
||||
if desc {
|
||||
|
@ -462,7 +361,7 @@ func (c *Collection) SearchValuesRange(start, end string, desc bool,
|
|||
cursor.Step(1)
|
||||
}
|
||||
iitm := item.(*itemT)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
|
||||
return keepon
|
||||
}
|
||||
if desc {
|
||||
|
@ -498,7 +397,7 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool,
|
|||
cursor.Step(1)
|
||||
}
|
||||
iitm := (*itemT)(ptr)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields)
|
||||
keepon = iterator(iitm.id(), iitm.obj, iitm.fields())
|
||||
return keepon
|
||||
}
|
||||
if desc {
|
||||
|
@ -519,7 +418,7 @@ func (c *Collection) geoSearch(
|
|||
[]float64{rect.Max.X, rect.Max.Y},
|
||||
func(_, _ []float64, itemv unsafe.Pointer) bool {
|
||||
item := (*itemT)(itemv)
|
||||
alive = iter(item.id(), item.obj, item.fields)
|
||||
alive = iter(item.id(), item.obj, item.fields())
|
||||
return alive
|
||||
},
|
||||
)
|
||||
|
@ -744,7 +643,7 @@ func (c *Collection) Nearby(
|
|||
cursor.Step(1)
|
||||
}
|
||||
item := (*itemT)(itemv)
|
||||
alive = iter(item.id(), item.obj, item.fields)
|
||||
alive = iter(item.id(), item.obj, item.fields())
|
||||
return alive
|
||||
},
|
||||
)
|
||||
|
|
|
@ -38,6 +38,56 @@ func bounds(c *Collection) geometry.Rect {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStuff(t *testing.T) {
|
||||
c := New()
|
||||
key := "str"
|
||||
str1 := String("hello")
|
||||
str2 := String("jello")
|
||||
{
|
||||
println("A")
|
||||
oldObj, _, _ := c.Set(key, str1, []string{"a", "b", "c"}, nil)
|
||||
println("B")
|
||||
expect(t, oldObj == nil)
|
||||
}
|
||||
{
|
||||
println("C")
|
||||
oldObj, _, _ := c.Set(key, str2, nil, nil) //[]float64{4, 5, 6})
|
||||
println("D")
|
||||
expect(t, oldObj == str1)
|
||||
// expect(t, reflect.DeepEqual(oldFlds, nil)) //[]float64{1, 2, 3}))
|
||||
// expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6}))
|
||||
}
|
||||
{
|
||||
// fValues := []float64{7, 8, 9, 10, 11, 12}
|
||||
println("E")
|
||||
oldObj, _, _ := c.Set(key, str1, nil, nil)
|
||||
println("F")
|
||||
expect(t, oldObj == str2)
|
||||
// expect(t, reflect.DeepEqual(oldFlds, []float64{4, 5, 6}))
|
||||
// expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12}))
|
||||
}
|
||||
|
||||
// var old geojson.Object
|
||||
// c := New()
|
||||
// old, _, _ = c.Set("hello1", String("world1"), nil, nil)
|
||||
// expect(t, old == nil)
|
||||
// old, _, _ = c.Set("hello2", String("world2"), nil, nil)
|
||||
// expect(t, old == nil)
|
||||
// old, _, _ = c.Set("hello3", String("world3"), nil, nil)
|
||||
// expect(t, old == nil)
|
||||
// old, _, _ = c.Set("hello4", String("world4"), nil, nil)
|
||||
// expect(t, old == nil)
|
||||
|
||||
// old, _, _ = c.Set("hello1", String("planet1"), nil, nil)
|
||||
// expect(t, old == String("world1"))
|
||||
// old, _, _ = c.Set("hello2", String("planet2"), nil, nil)
|
||||
// expect(t, old == String("world2"))
|
||||
// old, _, _ = c.Set("hello3", String("planet3"), nil, nil)
|
||||
// expect(t, old == String("world3"))
|
||||
// old, _, _ = c.Set("hello4", String("planet4"), nil, nil)
|
||||
// expect(t, old == String("world4"))
|
||||
}
|
||||
|
||||
func TestCollectionNewCollection(t *testing.T) {
|
||||
const numItems = 10000
|
||||
objs := make(map[string]geojson.Object)
|
||||
|
@ -114,24 +164,36 @@ func TestCollectionSet(t *testing.T) {
|
|||
t.Run("Fields", func(t *testing.T) {
|
||||
c := New()
|
||||
str1 := String("hello")
|
||||
fNames := []string{"a", "b", "c"}
|
||||
fValues := []float64{1, 2, 3}
|
||||
oldObj, oldFlds, newFlds := c.Set("str", str1, fNames, fValues)
|
||||
expect(t, oldObj == nil)
|
||||
expect(t, len(oldFlds) == 0)
|
||||
expect(t, reflect.DeepEqual(newFlds, fValues))
|
||||
str2 := String("hello")
|
||||
fNames = []string{"d", "e", "f"}
|
||||
fValues = []float64{4, 5, 6}
|
||||
oldObj, oldFlds, newFlds = c.Set("str", str2, fNames, fValues)
|
||||
expect(t, oldObj == str1)
|
||||
expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3}))
|
||||
expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6}))
|
||||
fValues = []float64{7, 8, 9, 10, 11, 12}
|
||||
oldObj, oldFlds, newFlds = c.Set("str", str1, nil, fValues)
|
||||
expect(t, oldObj == str2)
|
||||
expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3, 4, 5, 6}))
|
||||
expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12}))
|
||||
str2 := String("jello")
|
||||
{
|
||||
fNames := []string{"a", "b", "c"}
|
||||
fValues := []float64{1, 2, 3}
|
||||
println("A")
|
||||
oldObj, oldFlds, newFlds := c.Set("str", str1, fNames, fValues)
|
||||
println("B")
|
||||
expect(t, oldObj == nil)
|
||||
expect(t, len(oldFlds) == 0)
|
||||
expect(t, reflect.DeepEqual(newFlds, fValues))
|
||||
}
|
||||
{
|
||||
fNames := []string{"d", "e", "f"}
|
||||
fValues := []float64{4, 5, 6}
|
||||
println("C")
|
||||
oldObj, oldFlds, newFlds := c.Set("str", str2, fNames, fValues)
|
||||
println("D")
|
||||
expect(t, oldObj == str1)
|
||||
expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3}))
|
||||
expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6}))
|
||||
}
|
||||
{
|
||||
fValues := []float64{7, 8, 9, 10, 11, 12}
|
||||
println("E")
|
||||
oldObj, oldFlds, newFlds := c.Set("str", str1, nil, fValues)
|
||||
println("F")
|
||||
expect(t, oldObj == str2)
|
||||
expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3, 4, 5, 6}))
|
||||
expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12}))
|
||||
}
|
||||
})
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
c := New()
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
package collection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tidwall/btree"
|
||||
"github.com/tidwall/geojson"
|
||||
)
|
||||
|
||||
type itemT struct {
|
||||
obj geojson.Object
|
||||
idLen uint32 // id block size in bytes
|
||||
fieldsLen uint32 // fields block size in bytes, not num of fields
|
||||
data unsafe.Pointer
|
||||
}
|
||||
|
||||
func (item *itemT) id() string {
|
||||
return *(*string)((unsafe.Pointer)(&reflect.StringHeader{
|
||||
Data: uintptr(unsafe.Pointer(item.data)),
|
||||
Len: int(item.idLen),
|
||||
}))
|
||||
}
|
||||
|
||||
func (item *itemT) fields() []float64 {
|
||||
return *(*[]float64)((unsafe.Pointer)(&reflect.SliceHeader{
|
||||
Data: uintptr(unsafe.Pointer(item.data)) + uintptr(item.idLen),
|
||||
Len: int(item.fieldsLen) / 8,
|
||||
Cap: int(item.fieldsLen) / 8,
|
||||
}))
|
||||
}
|
||||
|
||||
func (item *itemT) dataBytes() []byte {
|
||||
return *(*[]byte)((unsafe.Pointer)(&reflect.SliceHeader{
|
||||
Data: uintptr(unsafe.Pointer(item.data)),
|
||||
Len: int(item.fieldsLen) + int(item.idLen),
|
||||
Cap: int(item.fieldsLen) + int(item.idLen),
|
||||
}))
|
||||
}
|
||||
|
||||
func newItem(id string, obj geojson.Object) *itemT {
|
||||
item := new(itemT)
|
||||
item.obj = obj
|
||||
item.idLen = uint32(len(id))
|
||||
if len(id) > 0 {
|
||||
data := make([]byte, len(id))
|
||||
copy(data, id)
|
||||
item.data = unsafe.Pointer(&data[0])
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func (item *itemT) weightAndPoints() (weight, points int) {
|
||||
if objIsSpatial(item.obj) {
|
||||
points = item.obj.NumPoints()
|
||||
weight = points * 16
|
||||
} else {
|
||||
weight = len(item.obj.String())
|
||||
}
|
||||
weight += int(item.fieldsLen + item.idLen)
|
||||
return weight, points
|
||||
}
|
||||
|
||||
func (item *itemT) Less(other btree.Item, ctx interface{}) bool {
|
||||
value1 := item.obj.String()
|
||||
value2 := other.(*itemT).obj.String()
|
||||
if value1 < value2 {
|
||||
return true
|
||||
}
|
||||
if value1 > value2 {
|
||||
return false
|
||||
}
|
||||
// the values match so we'll compare IDs, which are always unique.
|
||||
return item.id() < other.(*itemT).id()
|
||||
}
|
||||
|
||||
// directSetFields copies fields, overwriting previous fields
|
||||
func (item *itemT) directSetFields(fields []float64) {
|
||||
n := int(item.idLen) + len(fields)*8
|
||||
item.fieldsLen = uint32(len(fields) * 8)
|
||||
if n > 0 {
|
||||
newData := make([]byte, int(item.idLen)+len(fields)*8)
|
||||
item.data = unsafe.Pointer(&newData[0])
|
||||
copy(newData, item.id())
|
||||
copy(item.fields(), fields)
|
||||
} else {
|
||||
item.data = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Collection) setField(
|
||||
item *itemT, fieldName string, fieldValue float64, updateWeight bool,
|
||||
) (updated bool) {
|
||||
idx, ok := c.fieldMap[fieldName]
|
||||
if !ok {
|
||||
idx = len(c.fieldMap)
|
||||
c.fieldMap[fieldName] = idx
|
||||
}
|
||||
itemFields := item.fields()
|
||||
if idx >= len(itemFields) {
|
||||
// make room for new field
|
||||
|
||||
oldLen := len(itemFields)
|
||||
// print(c.weight)
|
||||
data := make([]byte, int(item.idLen)+(idx+1)*8)
|
||||
copy(data, item.dataBytes())
|
||||
item.fieldsLen = uint32((idx + 1) * 8)
|
||||
item.data = unsafe.Pointer(&data[0])
|
||||
itemFields := item.fields()
|
||||
if updateWeight {
|
||||
c.weight += (len(itemFields) - oldLen) * 8
|
||||
}
|
||||
// print(":")
|
||||
// print(c.weight)
|
||||
// println()
|
||||
itemFields[idx] = fieldValue
|
||||
updated = true
|
||||
} else if itemFields[idx] != fieldValue {
|
||||
// existing field needs updating
|
||||
itemFields[idx] = fieldValue
|
||||
updated = true
|
||||
}
|
||||
return updated
|
||||
}
|
||||
func (c *Collection) setFields(
|
||||
item *itemT, fieldNames []string, fieldValues []float64, updateWeight bool,
|
||||
) (updatedCount int) {
|
||||
// TODO: optimize to predict the item data growth.
|
||||
// TODO: do all sets here, instead of calling setFields in a loop
|
||||
for i, fieldName := range fieldNames {
|
||||
var fieldValue float64
|
||||
if i < len(fieldValues) {
|
||||
fieldValue = fieldValues[i]
|
||||
}
|
||||
if c.setField(item, fieldName, fieldValue, updateWeight) {
|
||||
updatedCount++
|
||||
}
|
||||
}
|
||||
return updatedCount
|
||||
}
|
|
@ -15,9 +15,9 @@ type btreeItem struct {
|
|||
// keyedItem must match layout of ../collection/itemT, otherwise
|
||||
// there's a risk for memory corruption.
|
||||
type keyedItem struct {
|
||||
_ interface{}
|
||||
_ uint32
|
||||
obj interface{}
|
||||
keyLen uint32
|
||||
_ uint32
|
||||
data unsafe.Pointer
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,589 @@
|
|||
package ptrbtree
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func makeItem(key string, obj interface{}) unsafe.Pointer {
|
||||
item := new(keyedItem)
|
||||
item.obj = obj
|
||||
if len(key) > 0 {
|
||||
data := make([]byte, len(key))
|
||||
copy(data, key)
|
||||
item.keyLen = uint32(len(key))
|
||||
item.data = unsafe.Pointer(&data[0])
|
||||
}
|
||||
return unsafe.Pointer(item)
|
||||
}
|
||||
|
||||
func itemKey(ptr unsafe.Pointer) string {
|
||||
return (btreeItem{ptr}).key()
|
||||
}
|
||||
|
||||
func itemValue(ptr unsafe.Pointer) interface{} {
|
||||
return (*keyedItem)(ptr).obj
|
||||
}
|
||||
|
||||
func init() {
|
||||
seed := time.Now().UnixNano()
|
||||
fmt.Printf("seed: %d\n", seed)
|
||||
rand.Seed(seed)
|
||||
}
|
||||
|
||||
func randKeys(N int) (keys []string) {
|
||||
format := fmt.Sprintf("%%0%dd", len(fmt.Sprintf("%d", N-1)))
|
||||
for _, i := range rand.Perm(N) {
|
||||
keys = append(keys, fmt.Sprintf(format, i))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const flatLeaf = true
|
||||
|
||||
func (tr *BTree) print() {
|
||||
tr.root.print(0, tr.height)
|
||||
}
|
||||
|
||||
func (n *node) print(level, height int) {
|
||||
if n == nil {
|
||||
println("NIL")
|
||||
return
|
||||
}
|
||||
if height == 0 && flatLeaf {
|
||||
fmt.Printf("%s", strings.Repeat(" ", level))
|
||||
}
|
||||
for i := 0; i < n.numItems; i++ {
|
||||
if height > 0 {
|
||||
n.children[i].print(level+1, height-1)
|
||||
}
|
||||
if height > 0 || (height == 0 && !flatLeaf) {
|
||||
fmt.Printf("%s%v\n", strings.Repeat(" ", level), n.items[i].key())
|
||||
} else {
|
||||
if i > 0 {
|
||||
fmt.Printf(",")
|
||||
}
|
||||
fmt.Printf("%s", n.items[i].key())
|
||||
}
|
||||
}
|
||||
if height == 0 && flatLeaf {
|
||||
fmt.Printf("\n")
|
||||
}
|
||||
if height > 0 {
|
||||
n.children[n.numItems].print(level+1, height-1)
|
||||
}
|
||||
}
|
||||
|
||||
func (tr *BTree) deepPrint() {
|
||||
fmt.Printf("%#v\n", tr)
|
||||
tr.root.deepPrint(0, tr.height)
|
||||
}
|
||||
|
||||
func (n *node) deepPrint(level, height int) {
|
||||
if n == nil {
|
||||
fmt.Printf("%s %#v\n", strings.Repeat(" ", level), n)
|
||||
return
|
||||
}
|
||||
fmt.Printf("%s count: %v\n", strings.Repeat(" ", level), n.numItems)
|
||||
fmt.Printf("%s items: %v\n", strings.Repeat(" ", level), n.items)
|
||||
if height > 0 {
|
||||
fmt.Printf("%s child: %v\n", strings.Repeat(" ", level), n.children)
|
||||
}
|
||||
if height > 0 {
|
||||
for i := 0; i < n.numItems; i++ {
|
||||
n.children[i].deepPrint(level+1, height-1)
|
||||
}
|
||||
n.children[n.numItems].deepPrint(level+1, height-1)
|
||||
}
|
||||
}
|
||||
|
||||
func stringsEquals(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(a); i++ {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestDescend(t *testing.T) {
|
||||
var tr BTree
|
||||
var count int
|
||||
tr.Descend("1", func(ptr unsafe.Pointer) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
t.Fatalf("expected 0, got %v", count)
|
||||
}
|
||||
var keys []string
|
||||
for i := 0; i < 1000; i += 10 {
|
||||
keys = append(keys, fmt.Sprintf("%03d", i))
|
||||
tr.Set(makeItem(keys[len(keys)-1], nil))
|
||||
}
|
||||
var exp []string
|
||||
tr.Reverse(func(ptr unsafe.Pointer) bool {
|
||||
exp = append(exp, itemKey(ptr))
|
||||
return true
|
||||
})
|
||||
for i := 999; i >= 0; i-- {
|
||||
var key string
|
||||
key = fmt.Sprintf("%03d", i)
|
||||
var all []string
|
||||
tr.Descend(key, func(ptr unsafe.Pointer) bool {
|
||||
all = append(all, itemKey(ptr))
|
||||
return true
|
||||
})
|
||||
for len(exp) > 0 && key < exp[0] {
|
||||
exp = exp[1:]
|
||||
}
|
||||
var count int
|
||||
tr.Descend(key, func(ptr unsafe.Pointer) bool {
|
||||
if count == (i+1)%maxItems {
|
||||
return false
|
||||
}
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > len(exp) {
|
||||
t.Fatalf("expected 1, got %v", count)
|
||||
}
|
||||
|
||||
if !stringsEquals(exp, all) {
|
||||
fmt.Printf("exp: %v\n", exp)
|
||||
fmt.Printf("all: %v\n", all)
|
||||
t.Fatal("mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAscend(t *testing.T) {
|
||||
var tr BTree
|
||||
var count int
|
||||
tr.Ascend("1", func(ptr unsafe.Pointer) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
t.Fatalf("expected 0, got %v", count)
|
||||
}
|
||||
var keys []string
|
||||
for i := 0; i < 1000; i += 10 {
|
||||
keys = append(keys, fmt.Sprintf("%03d", i))
|
||||
tr.Set(makeItem(keys[len(keys)-1], nil))
|
||||
}
|
||||
exp := keys
|
||||
for i := -1; i < 1000; i++ {
|
||||
var key string
|
||||
if i == -1 {
|
||||
key = ""
|
||||
} else {
|
||||
key = fmt.Sprintf("%03d", i)
|
||||
}
|
||||
var all []string
|
||||
tr.Ascend(key, func(ptr unsafe.Pointer) bool {
|
||||
all = append(all, itemKey(ptr))
|
||||
return true
|
||||
})
|
||||
|
||||
for len(exp) > 0 && key > exp[0] {
|
||||
exp = exp[1:]
|
||||
}
|
||||
var count int
|
||||
tr.Ascend(key, func(ptr unsafe.Pointer) bool {
|
||||
if count == (i+1)%maxItems {
|
||||
return false
|
||||
}
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > len(exp) {
|
||||
t.Fatalf("expected 1, got %v", count)
|
||||
}
|
||||
if !stringsEquals(exp, all) {
|
||||
t.Fatal("mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBTree(t *testing.T) {
|
||||
N := 10000
|
||||
var tr BTree
|
||||
keys := randKeys(N)
|
||||
|
||||
// insert all items
|
||||
for _, key := range keys {
|
||||
value, replaced := tr.Set(makeItem(key, key))
|
||||
if replaced {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
if value != nil {
|
||||
t.Fatal("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
// check length
|
||||
if tr.Len() != len(keys) {
|
||||
t.Fatalf("expected %v, got %v", len(keys), tr.Len())
|
||||
}
|
||||
|
||||
// get each value
|
||||
for _, key := range keys {
|
||||
value, gotten := tr.Get(key)
|
||||
if !gotten {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if value == nil || itemValue(value) != key {
|
||||
t.Fatalf("expected '%v', got '%v'", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// scan all items
|
||||
var last string
|
||||
all := make(map[string]interface{})
|
||||
tr.Scan(func(ptr unsafe.Pointer) bool {
|
||||
if itemKey(ptr) <= last {
|
||||
t.Fatal("out of order")
|
||||
}
|
||||
if itemValue(ptr).(string) != itemKey(ptr) {
|
||||
t.Fatalf("mismatch")
|
||||
}
|
||||
last = itemKey(ptr)
|
||||
all[itemKey(ptr)] = itemValue(ptr)
|
||||
return true
|
||||
})
|
||||
if len(all) != len(keys) {
|
||||
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
|
||||
}
|
||||
|
||||
// reverse all items
|
||||
var prev string
|
||||
all = make(map[string]interface{})
|
||||
tr.Reverse(func(ptr unsafe.Pointer) bool {
|
||||
if prev != "" && itemKey(ptr) >= prev {
|
||||
t.Fatal("out of order")
|
||||
}
|
||||
if itemValue(ptr).(string) != itemKey(ptr) {
|
||||
t.Fatalf("mismatch")
|
||||
}
|
||||
prev = itemKey(ptr)
|
||||
all[itemKey(ptr)] = itemValue(ptr)
|
||||
return true
|
||||
})
|
||||
if len(all) != len(keys) {
|
||||
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
|
||||
}
|
||||
|
||||
// try to get an invalid item
|
||||
value, gotten := tr.Get("invalid")
|
||||
if gotten {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
if value != nil {
|
||||
t.Fatal("expected nil")
|
||||
}
|
||||
|
||||
// scan and quit at various steps
|
||||
for i := 0; i < 100; i++ {
|
||||
var j int
|
||||
tr.Scan(func(ptr unsafe.Pointer) bool {
|
||||
if j == i {
|
||||
return false
|
||||
}
|
||||
j++
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// reverse and quit at various steps
|
||||
for i := 0; i < 100; i++ {
|
||||
var j int
|
||||
tr.Reverse(func(ptr unsafe.Pointer) bool {
|
||||
if j == i {
|
||||
return false
|
||||
}
|
||||
j++
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// delete half the items
|
||||
for _, key := range keys[:len(keys)/2] {
|
||||
value, deleted := tr.Delete(key)
|
||||
if !deleted {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if value == nil || itemValue(value).(string) != key {
|
||||
t.Fatalf("expected '%v', got '%v'", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// check length
|
||||
if tr.Len() != len(keys)/2 {
|
||||
t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len())
|
||||
}
|
||||
|
||||
// try delete half again
|
||||
for _, key := range keys[:len(keys)/2] {
|
||||
value, deleted := tr.Delete(key)
|
||||
if deleted {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
if value != nil {
|
||||
t.Fatalf("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
// try delete half again
|
||||
for _, key := range keys[:len(keys)/2] {
|
||||
value, deleted := tr.Delete(key)
|
||||
if deleted {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
if value != nil {
|
||||
t.Fatalf("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
// check length
|
||||
if tr.Len() != len(keys)/2 {
|
||||
t.Fatalf("expected %v, got %v", len(keys)/2, tr.Len())
|
||||
}
|
||||
|
||||
// scan items
|
||||
last = ""
|
||||
all = make(map[string]interface{})
|
||||
tr.Scan(func(ptr unsafe.Pointer) bool {
|
||||
if itemKey(ptr) <= last {
|
||||
t.Fatal("out of order")
|
||||
}
|
||||
if itemValue(ptr).(string) != itemKey(ptr) {
|
||||
t.Fatalf("mismatch")
|
||||
}
|
||||
last = itemKey(ptr)
|
||||
all[itemKey(ptr)] = itemValue(ptr)
|
||||
return true
|
||||
})
|
||||
if len(all) != len(keys)/2 {
|
||||
t.Fatalf("expected '%v', got '%v'", len(keys), len(all))
|
||||
}
|
||||
|
||||
// replace second half
|
||||
for _, key := range keys[len(keys)/2:] {
|
||||
value, replaced := tr.Set(makeItem(key, key))
|
||||
if !replaced {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if value == nil || itemValue(value).(string) != key {
|
||||
t.Fatalf("expected '%v', got '%v'", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// delete next half the items
|
||||
for _, key := range keys[len(keys)/2:] {
|
||||
value, deleted := tr.Delete(key)
|
||||
if !deleted {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if value == nil || itemValue(value).(string) != key {
|
||||
t.Fatalf("expected '%v', got '%v'", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// check length
|
||||
if tr.Len() != 0 {
|
||||
t.Fatalf("expected %v, got %v", 0, tr.Len())
|
||||
}
|
||||
|
||||
// do some stuff on an empty tree
|
||||
value, gotten = tr.Get(keys[0])
|
||||
if gotten {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
if value != nil {
|
||||
t.Fatal("expected nil")
|
||||
}
|
||||
tr.Scan(func(ptr unsafe.Pointer) bool {
|
||||
t.Fatal("should not be reached")
|
||||
return true
|
||||
})
|
||||
tr.Reverse(func(ptr unsafe.Pointer) bool {
|
||||
t.Fatal("should not be reached")
|
||||
return true
|
||||
})
|
||||
|
||||
var deleted bool
|
||||
value, deleted = tr.Delete("invalid")
|
||||
if deleted {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
if value != nil {
|
||||
t.Fatal("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTidwallSequentialSet(b *testing.B) {
|
||||
var tr BTree
|
||||
keys := randKeys(b.N)
|
||||
sort.Strings(keys)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tr.Set(makeItem(keys[i], nil))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTidwallSequentialGet(b *testing.B) {
|
||||
var tr BTree
|
||||
keys := randKeys(b.N)
|
||||
sort.Strings(keys)
|
||||
for i := 0; i < b.N; i++ {
|
||||
tr.Set(makeItem(keys[i], nil))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tr.Get(keys[i])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTidwallRandomSet(b *testing.B) {
|
||||
var tr BTree
|
||||
keys := randKeys(b.N)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tr.Set(makeItem(keys[i], nil))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTidwallRandomGet(b *testing.B) {
|
||||
var tr BTree
|
||||
keys := randKeys(b.N)
|
||||
for i := 0; i < b.N; i++ {
|
||||
tr.Set(makeItem(keys[i], nil))
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tr.Get(keys[i])
|
||||
}
|
||||
}
|
||||
|
||||
// type googleKind struct {
|
||||
// key string
|
||||
// }
|
||||
|
||||
// func (a *googleKind) Less(b btree.Item) bool {
|
||||
// return a.key < b.(*googleKind).key
|
||||
// }
|
||||
|
||||
// func BenchmarkGoogleSequentialSet(b *testing.B) {
|
||||
// tr := btree.New(32)
|
||||
// keys := randKeys(b.N)
|
||||
// sort.Strings(keys)
|
||||
// gkeys := make([]*googleKind, len(keys))
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// gkeys[i] = &googleKind{keys[i]}
|
||||
// }
|
||||
// b.ResetTimer()
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// tr.ReplaceOrInsert(gkeys[i])
|
||||
// }
|
||||
// }
|
||||
|
||||
// func BenchmarkGoogleSequentialGet(b *testing.B) {
|
||||
// tr := btree.New(32)
|
||||
// keys := randKeys(b.N)
|
||||
// gkeys := make([]*googleKind, len(keys))
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// gkeys[i] = &googleKind{keys[i]}
|
||||
// }
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// tr.ReplaceOrInsert(gkeys[i])
|
||||
// }
|
||||
// sort.Strings(keys)
|
||||
// b.ResetTimer()
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// tr.Get(gkeys[i])
|
||||
// }
|
||||
// }
|
||||
|
||||
// func BenchmarkGoogleRandomSet(b *testing.B) {
|
||||
// tr := btree.New(32)
|
||||
// keys := randKeys(b.N)
|
||||
// gkeys := make([]*googleKind, len(keys))
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// gkeys[i] = &googleKind{keys[i]}
|
||||
// }
|
||||
// b.ResetTimer()
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// tr.ReplaceOrInsert(gkeys[i])
|
||||
// }
|
||||
// }
|
||||
|
||||
// func BenchmarkGoogleRandomGet(b *testing.B) {
|
||||
// tr := btree.New(32)
|
||||
// keys := randKeys(b.N)
|
||||
// gkeys := make([]*googleKind, len(keys))
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// gkeys[i] = &googleKind{keys[i]}
|
||||
// }
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// tr.ReplaceOrInsert(gkeys[i])
|
||||
// }
|
||||
// b.ResetTimer()
|
||||
// for i := 0; i < b.N; i++ {
|
||||
// tr.Get(gkeys[i])
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestBTreeOne(t *testing.T) {
|
||||
var tr BTree
|
||||
tr.Set(makeItem("1", "1"))
|
||||
tr.Delete("1")
|
||||
tr.Set(makeItem("1", "1"))
|
||||
tr.Delete("1")
|
||||
tr.Set(makeItem("1", "1"))
|
||||
tr.Delete("1")
|
||||
}
|
||||
|
||||
func TestBTree256(t *testing.T) {
|
||||
var tr BTree
|
||||
var n int
|
||||
for j := 0; j < 2; j++ {
|
||||
for _, i := range rand.Perm(256) {
|
||||
tr.Set(makeItem(fmt.Sprintf("%d", i), i))
|
||||
n++
|
||||
if tr.Len() != n {
|
||||
t.Fatalf("expected 256, got %d", n)
|
||||
}
|
||||
}
|
||||
for _, i := range rand.Perm(256) {
|
||||
v, ok := tr.Get(fmt.Sprintf("%d", i))
|
||||
if !ok {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if itemValue(v).(int) != i {
|
||||
t.Fatalf("expected %d, got %d", i, itemValue(v).(int))
|
||||
}
|
||||
}
|
||||
for _, i := range rand.Perm(256) {
|
||||
tr.Delete(fmt.Sprintf("%d", i))
|
||||
n--
|
||||
if tr.Len() != n {
|
||||
t.Fatalf("expected 256, got %d", n)
|
||||
}
|
||||
}
|
||||
for _, i := range rand.Perm(256) {
|
||||
_, ok := tr.Get(fmt.Sprintf("%d", i))
|
||||
if ok {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue