diff --git a/go.mod b/go.mod index 60a6b2a3..cfc3dd8f 100644 --- a/go.mod +++ b/go.mod @@ -16,15 +16,17 @@ require ( github.com/peterh/liner v1.2.1 github.com/prometheus/client_golang v1.12.1 github.com/streadway/amqp v1.0.0 + github.com/tidwall/assert v0.1.0 github.com/tidwall/btree v1.4.3 github.com/tidwall/buntdb v1.2.9 github.com/tidwall/geojson v1.3.6 - github.com/tidwall/gjson v1.12.1 + github.com/tidwall/gjson v1.14.3 + github.com/tidwall/hashmap v1.6.1 github.com/tidwall/match v1.1.1 github.com/tidwall/pretty v1.2.0 github.com/tidwall/redbench v0.1.0 github.com/tidwall/redcon v1.4.4 - github.com/tidwall/resp v0.1.0 + github.com/tidwall/resp v0.1.1 github.com/tidwall/rtree v1.8.1 github.com/tidwall/sjson v1.2.4 github.com/xdg/scram v1.0.5 @@ -101,7 +103,7 @@ require ( golang.org/x/mod v0.3.0 // indirect golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c // indirect golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect - golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect + golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/tools v0.0.0-20200825202427-b303f430e36d // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect diff --git a/go.sum b/go.sum index d478d03a..c8983911 100644 --- a/go.sum +++ b/go.sum @@ -361,10 +361,13 @@ github.com/tidwall/geoindex v1.7.0 h1:jtk41sfgwIt8MEDyC3xyKSj75iXXf6rjReJGDNPtR5 github.com/tidwall/geoindex v1.7.0/go.mod h1:rvVVNEFfkJVWGUdEfU8QaoOg/9zFX0h9ofWzA60mz1I= github.com/tidwall/geojson v1.3.6 h1:ZbpDNwdhXyDe8XGTplGVaGrcS2ViFaSoo3QBNXe1uhM= github.com/tidwall/geojson v1.3.6/go.mod h1:1cn3UWfSYCJOq53NZoQ9rirdw89+DM0vw+ZOAVvuReg= -github.com/tidwall/gjson v1.12.1 h1:ikuZsLdhr8Ws0IdROXUS1Gi4v9Z4pGqpX/CvJkxvfpo= github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= +github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/grect v0.1.4 h1:dA3oIgNgWdSspFzn1kS4S/RDpZFLrIxAZOdJKjYapOg= github.com/tidwall/grect v0.1.4/go.mod h1:9FBsaYRaR0Tcy4UwefBX/UDcDcDy9V5jUcxHzv2jd5Q= +github.com/tidwall/hashmap v1.6.1 h1:FIAHjKwcyOo1Y3/orsQO08floKhInbEX2VQv7CQRNuw= +github.com/tidwall/hashmap v1.6.1/go.mod h1:hX452N3VtFD8okD3/6q/yOquJvJmYxmZ1H0nLtwkaxM= github.com/tidwall/lotsa v1.0.2 h1:dNVBH5MErdaQ/xd9s769R31/n2dXavsQ0Yf4TMEHHw8= github.com/tidwall/lotsa v1.0.2/go.mod h1:X6NiU+4yHA3fE3Puvpnn1XMDrFZrE9JO2/w+UMuqgR8= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -375,8 +378,8 @@ github.com/tidwall/redbench v0.1.0 h1:UZYUMhwMMObQRq5xU4SA3lmlJRztXzqtushDii+AmP github.com/tidwall/redbench v0.1.0/go.mod h1:zxcRGCq/JcqV48YjK9WxBNJL7JSpMzbLXaHvMcnanKQ= github.com/tidwall/redcon v1.4.4 h1:N3ZwZx6n5dqNxB3cfmj9D/8zNboFia5FAv1wt+azwyU= github.com/tidwall/redcon v1.4.4/go.mod h1:p5Wbsgeyi2VSTBWOcA5vRXrOb9arFTcU2+ZzFjqV75Y= -github.com/tidwall/resp v0.1.0 h1:zZ6Hq+2cY4QqhZ4LqrV05T5yLOSPspj+l+DgAoJ25Ak= -github.com/tidwall/resp v0.1.0/go.mod h1:18xEj855iMY2bK6tNF2A4x+nZy5gWO1iO7OOl3jETKw= +github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= +github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= github.com/tidwall/rtred v0.1.2 h1:exmoQtOLvDoO8ud++6LwVsAMTu0KPzLTUrMln8u1yu8= github.com/tidwall/rtred v0.1.2/go.mod h1:hd69WNXQ5RP9vHd7dqekAz+RIdtfBogmglkZSRxCHFQ= github.com/tidwall/rtree v1.3.1/go.mod h1:S+JSsqPTI8LfWA4xHBo5eXzie8WJLVFeppAutSegl6M= @@ -551,8 +554,9 @@ golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 h1:UiNENfZ8gDvpiWw7IpOMQ27spWmThO1RwwdQVbJahJM= +golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/internal/collection/collection.go b/internal/collection/collection.go index 60b96fa4..8eff93e2 100644 --- a/internal/collection/collection.go +++ b/internal/collection/collection.go @@ -9,6 +9,7 @@ import ( "github.com/tidwall/geojson/geometry" "github.com/tidwall/rtree" "github.com/tidwall/tile38/internal/deadline" + "github.com/tidwall/tile38/internal/field" ) // yieldStep forces the iterator to yield goroutine every 256 steps. @@ -21,10 +22,10 @@ type Cursor interface { } type itemT struct { - id string - obj geojson.Object - expires int64 // unix nano expiration - fieldValuesSlot fieldValuesSlot + id string + obj geojson.Object + expires int64 // unix nano expiration + fields field.List } func byID(a, b *itemT) bool { @@ -57,17 +58,14 @@ func byExpires(a, b *itemT) bool { // Collection represents a collection of geojson objects. type Collection struct { - items *btree.BTreeG[*itemT] // items sorted by id - spatial *rtree.RTreeG[*itemT] // items geospatially indexed - values *btree.BTreeG[*itemT] // items sorted by value+id - expires *btree.BTreeG[*itemT] // items sorted by ex+id - fieldMap map[string]int - fieldArr []string - fieldValues *fieldValues - weight int - points int - objects int // geometry count - nobjects int // non-geometry count + items *btree.BTreeG[*itemT] // items sorted by id + spatial *rtree.RTreeG[*itemT] // items geospatially indexed + values *btree.BTreeG[*itemT] // items sorted by value+id + expires *btree.BTreeG[*itemT] // items sorted by ex+id + weight int + points int + objects int // geometry count + nobjects int // non-geometry count } var optsNoLock = btree.Options{NoLocks: true} @@ -75,13 +73,10 @@ var optsNoLock = btree.Options{NoLocks: true} // New creates an empty collection func New() *Collection { col := &Collection{ - items: btree.NewBTreeGOptions(byID, optsNoLock), - values: btree.NewBTreeGOptions(byValue, optsNoLock), - expires: btree.NewBTreeGOptions(byExpires, optsNoLock), - spatial: &rtree.RTreeG[*itemT]{}, - fieldMap: make(map[string]int), - fieldArr: make([]string, 0), - fieldValues: &fieldValues{}, + items: btree.NewBTreeGOptions(byID, optsNoLock), + values: btree.NewBTreeGOptions(byValue, optsNoLock), + expires: btree.NewBTreeGOptions(byExpires, optsNoLock), + spatial: &rtree.RTreeG[*itemT]{}, } return col } @@ -122,12 +117,14 @@ func objIsSpatial(obj geojson.Object) bool { func (c *Collection) objWeight(item *itemT) int { var weight int + weight += len(item.id) if objIsSpatial(item.obj) { - weight = item.obj.NumPoints() * 16 + weight += item.obj.NumPoints() * 16 } else { - weight = len(item.obj.String()) + weight += len(item.obj.String()) } - return weight + len(c.fieldValues.get(item.fieldValuesSlot))*8 + len(item.id) + weight += item.fields.Weight() + return weight } func (c *Collection) indexDelete(item *itemT) { @@ -151,16 +148,16 @@ func (c *Collection) indexInsert(item *itemT) { } // Set adds or replaces an object in the collection and returns the fields -// array. If an item with the same id is already in the collection then the -// new item will adopt the old item's fields. -// The fields argument is optional. -// The return values are the old object, the old fields, and the new fields -func (c *Collection) Set( - id string, obj geojson.Object, fields []string, values []float64, ex int64, -) ( - oldObject geojson.Object, oldFieldValues []float64, newFieldValues []float64, +// array. +func (c *Collection) Set(id string, obj geojson.Object, fields field.List, ex int64) ( + oldObject geojson.Object, oldFields, newFields field.List, ) { - newItem := &itemT{id: id, obj: obj, fieldValuesSlot: nilValuesSlot, expires: ex} + newItem := &itemT{ + id: id, + obj: obj, + expires: ex, + fields: fields, + } // add the new item to main btree and remove the old one if needed oldItem, ok := c.items.Set(newItem) @@ -183,24 +180,6 @@ func (c *Collection) Set( // decrement the weights c.weight -= c.objWeight(oldItem) - - // references - oldObject = oldItem.obj - oldFieldValues = c.fieldValues.get(oldItem.fieldValuesSlot) - newFieldValues = oldFieldValues - newItem.fieldValuesSlot = oldItem.fieldValuesSlot - if len(oldFieldValues) > 0 { - oldFieldValues = append([]float64{}, oldFieldValues...) - } - } - if fields == nil { - if len(values) > 0 { - newFieldValues = values - newFieldValuesSlot := c.fieldValues.set(newItem.fieldValuesSlot, newFieldValues) - newItem.fieldValuesSlot = newFieldValuesSlot - } - } else { - newFieldValues, _, _ = c.setFieldValues(newItem, fields, values) } // insert the new item into the rtree or strings tree. @@ -222,17 +201,20 @@ func (c *Collection) Set( // add the new weights c.weight += c.objWeight(newItem) - return oldObject, oldFieldValues, newFieldValues + if oldItem != nil { + return oldItem.obj, oldItem.fields, newItem.fields + } + return nil, field.List{}, newItem.fields } // Delete removes an object and returns it. // If the object does not exist then the 'ok' return value will be false. func (c *Collection) Delete(id string) ( - obj geojson.Object, fields []float64, ok bool, + obj geojson.Object, fields field.List, ok bool, ) { oldItem, ok := c.items.Delete(&itemT{id: id}) if !ok { - return nil, nil, false + return nil, field.List{}, false } if objIsSpatial(oldItem.obj) { if !oldItem.obj.Empty() { @@ -250,127 +232,22 @@ func (c *Collection) Delete(id string) ( c.weight -= c.objWeight(oldItem) c.points -= oldItem.obj.NumPoints() - fields = c.fieldValues.get(oldItem.fieldValuesSlot) - c.fieldValues.remove(oldItem.fieldValuesSlot) - return oldItem.obj, fields, true + return oldItem.obj, oldItem.fields, true } // Get returns an object. // If the object does not exist then the 'ok' return value will be false. func (c *Collection) Get(id string) ( - obj geojson.Object, fields []float64, ex int64, ok bool, + obj geojson.Object, + fields field.List, + ex int64, + ok bool, ) { item, ok := c.items.Get(&itemT{id: id}) if !ok { - return nil, nil, 0, false - } - return item.obj, c.fieldValues.get(item.fieldValuesSlot), item.expires, true -} - -func (c *Collection) SetExpires(id string, ex int64) bool { - item, ok := c.items.Get(&itemT{id: id}) - if !ok { - return false - } - if item.expires != 0 { - c.expires.Delete(item) - } - item.expires = ex - if item.expires != 0 { - c.expires.Set(item) - } - return true -} - -// SetField set a field value for an object and returns that object. -// If the object does not exist then the 'ok' return value will be false. -func (c *Collection) SetField(id, field string, value float64) ( - obj geojson.Object, fields []float64, updated bool, ok bool, -) { - item, ok := c.items.Get(&itemT{id: id}) - if !ok { - return nil, nil, false, false - } - _, updateCount, weightDelta := c.setFieldValues(item, []string{field}, []float64{value}) - c.weight += weightDelta - return item.obj, c.fieldValues.get(item.fieldValuesSlot), updateCount > 0, true -} - -// SetFields is similar to SetField, just setting multiple fields at once -func (c *Collection) SetFields( - id string, inFields []string, inValues []float64, -) (obj geojson.Object, fields []float64, updatedCount int, ok bool) { - item, ok := c.items.Get(&itemT{id: id}) - if !ok { - return nil, nil, 0, false - } - newFieldValues, updateCount, weightDelta := c.setFieldValues(item, inFields, inValues) - c.weight += weightDelta - return item.obj, newFieldValues, updateCount, true -} - -func (c *Collection) setFieldValues(item *itemT, fields []string, updateValues []float64) ( - newValues []float64, - updated int, - weightDelta int, -) { - newValues = c.fieldValues.get(item.fieldValuesSlot) - for i, field := range fields { - fieldIdx, ok := c.fieldMap[field] - if !ok { - fieldIdx = len(c.fieldMap) - c.fieldMap[field] = fieldIdx - c.addToFieldArr(field) - } - for fieldIdx >= len(newValues) { - newValues = append(newValues, 0) - weightDelta += 8 - } - ovalue := newValues[fieldIdx] - nvalue := updateValues[i] - newValues[fieldIdx] = nvalue - if ovalue != nvalue { - updated++ - } - } - newSlot := c.fieldValues.set(item.fieldValuesSlot, newValues) - item.fieldValuesSlot = newSlot - return newValues, updated, weightDelta -} - -// FieldMap return a maps of the field names. -func (c *Collection) FieldMap() map[string]int { - return c.fieldMap -} - -// FieldArr return an array representation of the field names. -func (c *Collection) FieldArr() []string { - return c.fieldArr -} - -// bsearch searches array for value. -func bsearch(arr []string, val string) (index int, found bool) { - i, j := 0, len(arr) - for i < j { - h := i + (j-i)/2 - if val >= arr[h] { - i = h + 1 - } else { - j = h - } - } - if i > 0 && arr[i-1] >= val { - return i - 1, true - } - return i, false -} - -func (c *Collection) addToFieldArr(field string) { - if index, found := bsearch(c.fieldArr, field); !found { - c.fieldArr = append(c.fieldArr, "") - copy(c.fieldArr[index+1:], c.fieldArr[index:len(c.fieldArr)-1]) - c.fieldArr[index] = field + return nil, field.List{}, 0, false } + return item.obj, item.fields, item.expires, true } // Scan iterates though the collection ids. @@ -378,7 +255,7 @@ func (c *Collection) Scan( desc bool, cursor Cursor, deadline *deadline.Deadline, - iterator func(id string, obj geojson.Object, fields []float64) bool, + iterator func(id string, obj geojson.Object, fields field.List) bool, ) bool { var keepon = true var count uint64 @@ -393,7 +270,7 @@ func (c *Collection) Scan( return true } nextStep(count, cursor, deadline) - keepon = iterator(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot)) + keepon = iterator(item.id, item.obj, item.fields) return keepon } if desc { @@ -410,7 +287,7 @@ func (c *Collection) ScanRange( desc bool, cursor Cursor, deadline *deadline.Deadline, - iterator func(id string, obj geojson.Object, fields []float64) bool, + iterator func(id string, obj geojson.Object, fields field.List) bool, ) bool { var keepon = true var count uint64 @@ -434,7 +311,7 @@ func (c *Collection) ScanRange( return false } } - keepon = iterator(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot)) + keepon = iterator(item.id, item.obj, item.fields) return keepon } @@ -451,7 +328,7 @@ func (c *Collection) SearchValues( desc bool, cursor Cursor, deadline *deadline.Deadline, - iterator func(id string, obj geojson.Object, fields []float64) bool, + iterator func(id string, obj geojson.Object, fields field.List) bool, ) bool { var keepon = true var count uint64 @@ -466,7 +343,7 @@ func (c *Collection) SearchValues( return true } nextStep(count, cursor, deadline) - keepon = iterator(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot)) + keepon = iterator(item.id, item.obj, item.fields) return keepon } if desc { @@ -481,7 +358,7 @@ func (c *Collection) SearchValues( func (c *Collection) SearchValuesRange(start, end string, desc bool, cursor Cursor, deadline *deadline.Deadline, - iterator func(id string, obj geojson.Object, fields []float64) bool, + iterator func(id string, obj geojson.Object, fields field.List) bool, ) bool { var keepon = true var count uint64 @@ -496,7 +373,7 @@ func (c *Collection) SearchValuesRange(start, end string, desc bool, return true } nextStep(count, cursor, deadline) - keepon = iterator(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot)) + keepon = iterator(item.id, item.obj, item.fields) return keepon } pstart := &itemT{obj: String(start)} @@ -521,7 +398,7 @@ func bGT(tr *btree.BTreeG[*itemT], a, b *itemT) bool { return tr.Less(b, a) } func (c *Collection) ScanGreaterOrEqual(id string, desc bool, cursor Cursor, deadline *deadline.Deadline, - iterator func(id string, obj geojson.Object, fields []float64, ex int64) bool, + iterator func(id string, obj geojson.Object, fields field.List, ex int64) bool, ) bool { var keepon = true var count uint64 @@ -536,7 +413,7 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool, return true } nextStep(count, cursor, deadline) - keepon = iterator(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot), item.expires) + keepon = iterator(item.id, item.obj, item.fields, item.expires) return keepon } if desc { @@ -549,14 +426,14 @@ func (c *Collection) ScanGreaterOrEqual(id string, desc bool, func (c *Collection) geoSearch( rect geometry.Rect, - iter func(id string, obj geojson.Object, fields []float64) bool, + iter func(id string, obj geojson.Object, fields field.List) bool, ) bool { alive := true c.spatial.Search( [2]float64{rect.Min.X, rect.Min.Y}, [2]float64{rect.Max.X, rect.Max.Y}, func(_, _ [2]float64, item *itemT) bool { - alive = iter(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot)) + alive = iter(item.id, item.obj, item.fields) return alive }, ) @@ -565,12 +442,12 @@ func (c *Collection) geoSearch( func (c *Collection) geoSparse( obj geojson.Object, sparse uint8, - iter func(id string, obj geojson.Object, fields []float64) (match, ok bool), + iter func(id string, obj geojson.Object, fields field.List) (match, ok bool), ) bool { matches := make(map[string]bool) alive := true c.geoSparseInner(obj.Rect(), sparse, - func(id string, o geojson.Object, fields []float64) ( + func(id string, o geojson.Object, fields field.List) ( match, ok bool, ) { ok = true @@ -587,7 +464,7 @@ func (c *Collection) geoSparse( } func (c *Collection) geoSparseInner( rect geometry.Rect, sparse uint8, - iter func(id string, obj geojson.Object, fields []float64) (match, ok bool), + iter func(id string, obj geojson.Object, fields field.List) (match, ok bool), ) bool { if sparse > 0 { w := rect.Max.X - rect.Min.X @@ -619,7 +496,7 @@ func (c *Collection) geoSparseInner( } alive := true c.geoSearch(rect, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { match, ok := iter(id, obj, fields) if !ok { alive = false @@ -638,7 +515,7 @@ func (c *Collection) Within( sparse uint8, cursor Cursor, deadline *deadline.Deadline, - iter func(id string, obj geojson.Object, fields []float64) bool, + iter func(id string, obj geojson.Object, fields field.List) bool, ) bool { var count uint64 var offset uint64 @@ -648,7 +525,7 @@ func (c *Collection) Within( } if sparse > 0 { return c.geoSparse(obj, sparse, - func(id string, o geojson.Object, fields []float64) ( + func(id string, o geojson.Object, fields field.List) ( match, ok bool, ) { count++ @@ -664,7 +541,7 @@ func (c *Collection) Within( ) } return c.geoSearch(obj.Rect(), - func(id string, o geojson.Object, fields []float64) bool { + func(id string, o geojson.Object, fields field.List) bool { count++ if count <= offset { return true @@ -685,7 +562,7 @@ func (c *Collection) Intersects( sparse uint8, cursor Cursor, deadline *deadline.Deadline, - iter func(id string, obj geojson.Object, fields []float64) bool, + iter func(id string, obj geojson.Object, fields field.List) bool, ) bool { var count uint64 var offset uint64 @@ -695,7 +572,7 @@ func (c *Collection) Intersects( } if sparse > 0 { return c.geoSparse(obj, sparse, - func(id string, o geojson.Object, fields []float64) ( + func(id string, o geojson.Object, fields field.List) ( match, ok bool, ) { count++ @@ -711,7 +588,7 @@ func (c *Collection) Intersects( ) } return c.geoSearch(obj.Rect(), - func(id string, o geojson.Object, fields []float64) bool { + func(id string, o geojson.Object, fields field.List) bool { count++ if count <= offset { return true @@ -730,7 +607,7 @@ func (c *Collection) Nearby( target geojson.Object, cursor Cursor, deadline *deadline.Deadline, - iter func(id string, obj geojson.Object, fields []float64, dist float64) bool, + iter func(id string, obj geojson.Object, fields field.List, dist float64) bool, ) bool { // First look to see if there's at least one candidate in the circle's // outer rectangle. This is a fast-fail operation. @@ -772,7 +649,7 @@ func (c *Collection) Nearby( return true } nextStep(count, cursor, deadline) - alive = iter(item.id, item.obj, c.fieldValues.get(item.fieldValuesSlot), dist) + alive = iter(item.id, item.obj, item.fields, dist) return alive }, ) diff --git a/internal/collection/collection_test.go b/internal/collection/collection_test.go index 3f32c658..6a4e0592 100644 --- a/internal/collection/collection_test.go +++ b/internal/collection/collection_test.go @@ -11,6 +11,7 @@ import ( "github.com/tidwall/geojson" "github.com/tidwall/geojson/geometry" "github.com/tidwall/gjson" + "github.com/tidwall/tile38/internal/field" ) func PO(x, y float64) *geojson.Point { @@ -46,14 +47,14 @@ func TestCollectionNewCollection(t *testing.T) { id := strconv.FormatInt(int64(i), 10) obj := PO(rand.Float64()*360-180, rand.Float64()*180-90) objs[id] = obj - c.Set(id, obj, nil, nil, 0) + c.Set(id, obj, field.List{}, 0) } count := 0 bbox := geometry.Rect{ Min: geometry.Point{X: -180, Y: -90}, Max: geometry.Point{X: 180, Y: 90}, } - c.geoSearch(bbox, func(id string, obj geojson.Object, field []float64) bool { + c.geoSearch(bbox, func(id string, obj geojson.Object, _ field.List) bool { count++ return true }) @@ -67,77 +68,95 @@ func TestCollectionNewCollection(t *testing.T) { testCollectionVerifyContents(t, c, objs) } +func toFields(fNames, fValues []string) field.List { + var fields field.List + for i := 0; i < len(fNames); i++ { + fields = fields.Set(field.Make(fNames[i], fValues[i])) + } + return fields +} + func TestCollectionSet(t *testing.T) { t.Run("AddString", func(t *testing.T) { c := New() str1 := String("hello") - oldObject, oldFields, newFields := c.Set("str", str1, nil, nil, 0) + oldObject, oldFields, newFields := c.Set("str", str1, field.List{}, 0) expect(t, oldObject == nil) - expect(t, len(oldFields) == 0) - expect(t, len(newFields) == 0) + expect(t, oldFields.Len() == 0) + expect(t, newFields.Len() == 0) }) t.Run("UpdateString", func(t *testing.T) { c := New() str1 := String("hello") str2 := String("world") - oldObject, oldFields, newFields := c.Set("str", str1, nil, nil, 0) + oldObject, oldFields, newFields := c.Set("str", str1, field.List{}, 0) expect(t, oldObject == nil) - expect(t, len(oldFields) == 0) - expect(t, len(newFields) == 0) - oldObject, oldFields, newFields = c.Set("str", str2, nil, nil, 0) + expect(t, oldFields.Len() == 0) + expect(t, newFields.Len() == 0) + oldObject, oldFields, newFields = c.Set("str", str2, field.List{}, 0) expect(t, oldObject == str1) - expect(t, len(oldFields) == 0) - expect(t, len(newFields) == 0) + expect(t, oldFields.Len() == 0) + expect(t, newFields.Len() == 0) }) t.Run("AddPoint", func(t *testing.T) { c := New() point1 := PO(-112.1, 33.1) - oldObject, oldFields, newFields := c.Set("point", point1, nil, nil, 0) + oldObject, oldFields, newFields := c.Set("point", point1, field.List{}, 0) expect(t, oldObject == nil) - expect(t, len(oldFields) == 0) - expect(t, len(newFields) == 0) + expect(t, oldFields.Len() == 0) + expect(t, newFields.Len() == 0) }) t.Run("UpdatePoint", func(t *testing.T) { c := New() point1 := PO(-112.1, 33.1) point2 := PO(-112.2, 33.2) - oldObject, oldFields, newFields := c.Set("point", point1, nil, nil, 0) + oldObject, oldFields, newFields := c.Set("point", point1, field.List{}, 0) expect(t, oldObject == nil) - expect(t, len(oldFields) == 0) - expect(t, len(newFields) == 0) - oldObject, oldFields, newFields = c.Set("point", point2, nil, nil, 0) + expect(t, oldFields.Len() == 0) + expect(t, newFields.Len() == 0) + oldObject, oldFields, newFields = c.Set("point", point2, field.List{}, 0) expect(t, oldObject == point1) - expect(t, len(oldFields) == 0) - expect(t, len(newFields) == 0) + expect(t, oldFields.Len() == 0) + expect(t, newFields.Len() == 0) }) t.Run("Fields", func(t *testing.T) { c := New() str1 := String("hello") + fNames := []string{"a", "b", "c"} - fValues := []float64{1, 2, 3} - oldObj, oldFlds, newFlds := c.Set("str", str1, fNames, fValues, 0) + fValues := []string{"1", "2", "3"} + fields1 := toFields(fNames, fValues) + oldObj, oldFlds, newFlds := c.Set("str", str1, fields1, 0) + expect(t, oldObj == nil) - expect(t, len(oldFlds) == 0) - expect(t, reflect.DeepEqual(newFlds, fValues)) + expect(t, oldFlds.Len() == 0) + expect(t, reflect.DeepEqual(newFlds, fields1)) + str2 := String("hello") + fNames = []string{"d", "e", "f"} - fValues = []float64{4, 5, 6} - oldObj, oldFlds, newFlds = c.Set("str", str2, fNames, fValues, 0) + fValues = []string{"4", "5", "6"} + fields2 := toFields(fNames, fValues) + + oldObj, oldFlds, newFlds = c.Set("str", str2, fields2, 0) expect(t, oldObj == str1) - expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3})) - expect(t, reflect.DeepEqual(newFlds, []float64{1, 2, 3, 4, 5, 6})) - fValues = []float64{7, 8, 9, 10, 11, 12} - oldObj, oldFlds, newFlds = c.Set("str", str1, nil, fValues, 0) + expect(t, reflect.DeepEqual(oldFlds, fields1)) + expect(t, reflect.DeepEqual(newFlds, fields2)) + + fNames = []string{"a", "b", "c", "d", "e", "f"} + fValues = []string{"7", "8", "9", "10", "11", "12"} + fields3 := toFields(fNames, fValues) + oldObj, oldFlds, newFlds = c.Set("str", str1, fields3, 0) expect(t, oldObj == str2) - expect(t, reflect.DeepEqual(oldFlds, []float64{1, 2, 3, 4, 5, 6})) - expect(t, reflect.DeepEqual(newFlds, []float64{7, 8, 9, 10, 11, 12})) + expect(t, reflect.DeepEqual(oldFlds, fields2)) + expect(t, reflect.DeepEqual(newFlds, fields3)) }) t.Run("Delete", func(t *testing.T) { c := New() - c.Set("1", String("1"), nil, nil, 0) - c.Set("2", String("2"), nil, nil, 0) - c.Set("3", PO(1, 2), nil, nil, 0) + c.Set("1", String("1"), field.List{}, 0) + c.Set("2", String("2"), field.List{}, 0) + c.Set("3", PO(1, 2), field.List{}, 0) expect(t, c.Count() == 3) expect(t, c.StringCount() == 2) @@ -147,9 +166,9 @@ func TestCollectionSet(t *testing.T) { Max: geometry.Point{X: 1, Y: 2}}) var v geojson.Object var ok bool - var flds []float64 - var updated bool - var updateCount int + // var flds []float64 + // var updated bool + // var updateCount int v, _, ok = c.Delete("2") expect(t, v.String() == "2") @@ -165,32 +184,32 @@ func TestCollectionSet(t *testing.T) { expect(t, c.StringCount() == 0) expect(t, c.PointCount() == 1) - expect(t, len(c.FieldMap()) == 0) + // expect(t, len(c.FieldMap()) == 0) - _, flds, updated, ok = c.SetField("3", "hello", 123) - expect(t, ok) - expect(t, reflect.DeepEqual(flds, []float64{123})) - expect(t, updated) - expect(t, c.FieldMap()["hello"] == 0) + // _, flds, updated, ok = c.SetField("3", "hello", 123) + // expect(t, ok) + // expect(t, reflect.DeepEqual(flds, []float64{123})) + // expect(t, updated) + // expect(t, c.FieldMap()["hello"] == 0) - _, flds, updated, ok = c.SetField("3", "hello", 1234) - expect(t, ok) - expect(t, reflect.DeepEqual(flds, []float64{1234})) - expect(t, updated) + // _, flds, updated, ok = c.SetField("3", "hello", 1234) + // expect(t, ok) + // expect(t, reflect.DeepEqual(flds, []float64{1234})) + // expect(t, updated) - _, flds, updated, ok = c.SetField("3", "hello", 1234) - expect(t, ok) - expect(t, reflect.DeepEqual(flds, []float64{1234})) - expect(t, !updated) + // _, flds, updated, ok = c.SetField("3", "hello", 1234) + // expect(t, ok) + // expect(t, reflect.DeepEqual(flds, []float64{1234})) + // expect(t, !updated) - _, flds, updateCount, ok = c.SetFields("3", - []string{"planet", "world"}, []float64{55, 66}) - expect(t, ok) - expect(t, reflect.DeepEqual(flds, []float64{1234, 55, 66})) - expect(t, updateCount == 2) - expect(t, c.FieldMap()["hello"] == 0) - expect(t, c.FieldMap()["planet"] == 1) - expect(t, c.FieldMap()["world"] == 2) + // _, flds, updateCount, ok = c.SetFields("3", + // []string{"planet", "world"}, []float64{55, 66}) + // expect(t, ok) + // expect(t, reflect.DeepEqual(flds, []float64{1234, 55, 66})) + // expect(t, updateCount == 2) + // expect(t, c.FieldMap()["hello"] == 0) + // expect(t, c.FieldMap()["planet"] == 1) + // expect(t, c.FieldMap()["world"] == 2) v, _, ok = c.Delete("3") expect(t, v.String() == `{"type":"Point","coordinates":[1,2]}`) @@ -206,45 +225,63 @@ func TestCollectionSet(t *testing.T) { v, _, _, ok = c.Get("3") expect(t, v == nil) expect(t, !ok) - _, _, _, ok = c.SetField("3", "hello", 123) - expect(t, !ok) - _, _, _, ok = c.SetFields("3", []string{"hello"}, []float64{123}) - expect(t, !ok) - expect(t, c.TotalWeight() == 0) - expect(t, c.FieldMap()["hello"] == 0) - expect(t, c.FieldMap()["planet"] == 1) - expect(t, c.FieldMap()["world"] == 2) - expect(t, reflect.DeepEqual( - c.FieldArr(), []string{"hello", "planet", "world"}), - ) + // _, _, _, ok = c.SetField("3", "hello", 123) + // expect(t, !ok) + // _, _, _, ok = c.SetFields("3", []string{"hello"}, []float64{123}) + // expect(t, !ok) + // expect(t, c.TotalWeight() == 0) + // expect(t, c.FieldMap()["hello"] == 0) + // expect(t, c.FieldMap()["planet"] == 1) + // expect(t, c.FieldMap()["world"] == 2) + // expect(t, reflect.DeepEqual( + // c.FieldArr(), []string{"hello", "planet", "world"}), + // ) }) } +func fieldValueAt(fields field.List, index int) string { + if index < 0 || index >= fields.Len() { + panic("out of bounds") + } + var retval string + var i int + fields.Scan(func(f field.Field) bool { + if i == index { + retval = f.Value().Data() + } + i++ + return true + }) + return retval +} + func TestCollectionScan(t *testing.T) { N := 256 c := New() for _, i := range rand.Perm(N) { id := fmt.Sprintf("%04d", i) - c.Set(id, String(id), []string{"ex"}, []float64{float64(i)}, 0) + c.Set(id, String(id), makeFields( + field.Make("ex", id), + ), 0) } var n int var prevID string - c.Scan(false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.Scan(false, nil, nil, func(id string, obj geojson.Object, fields field.List) bool { if n > 0 { expect(t, id > prevID) } - expect(t, id == fmt.Sprintf("%04d", int(fields[0]))) + expect(t, id == fieldValueAt(fields, 0)) n++ prevID = id return true }) expect(t, n == c.Count()) n = 0 - c.Scan(true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.Scan(true, nil, nil, func(id string, obj geojson.Object, fields field.List) bool { if n > 0 { expect(t, id < prevID) } - expect(t, id == fmt.Sprintf("%04d", int(fields[0]))) + expect(t, id == fieldValueAt(fields, 0)) n++ prevID = id return true @@ -253,11 +290,11 @@ func TestCollectionScan(t *testing.T) { n = 0 c.ScanRange("0060", "0070", false, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { if n > 0 { expect(t, id > prevID) } - expect(t, id == fmt.Sprintf("%04d", int(fields[0]))) + expect(t, id == fieldValueAt(fields, 0)) n++ prevID = id return true @@ -266,11 +303,11 @@ func TestCollectionScan(t *testing.T) { n = 0 c.ScanRange("0070", "0060", true, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { if n > 0 { expect(t, id < prevID) } - expect(t, id == fmt.Sprintf("%04d", int(fields[0]))) + expect(t, id == fieldValueAt(fields, 0)) n++ prevID = id return true @@ -279,11 +316,11 @@ func TestCollectionScan(t *testing.T) { n = 0 c.ScanGreaterOrEqual("0070", true, nil, nil, - func(id string, obj geojson.Object, fields []float64, ex int64) bool { + func(id string, obj geojson.Object, fields field.List, ex int64) bool { if n > 0 { expect(t, id < prevID) } - expect(t, id == fmt.Sprintf("%04d", int(fields[0]))) + expect(t, id == fieldValueAt(fields, 0)) n++ prevID = id return true @@ -292,11 +329,11 @@ func TestCollectionScan(t *testing.T) { n = 0 c.ScanGreaterOrEqual("0070", false, nil, nil, - func(id string, obj geojson.Object, fields []float64, ex int64) bool { + func(id string, obj geojson.Object, fields field.List, ex int64) bool { if n > 0 { expect(t, id > prevID) } - expect(t, id == fmt.Sprintf("%04d", int(fields[0]))) + expect(t, id == fieldValueAt(fields, 0)) n++ prevID = id return true @@ -305,33 +342,44 @@ func TestCollectionScan(t *testing.T) { } +func makeFields(entries ...field.Field) field.List { + var fields field.List + for _, f := range entries { + fields = fields.Set(f) + } + return fields +} + func TestCollectionSearch(t *testing.T) { N := 256 c := New() for i, j := range rand.Perm(N) { id := fmt.Sprintf("%04d", j) ex := fmt.Sprintf("%04d", i) - c.Set(id, String(ex), []string{"i", "j"}, - []float64{float64(i), float64(j)}, 0) + c.Set(id, String(ex), + makeFields( + field.Make("i", ex), + field.Make("j", id), + ), 0) } var n int var prevValue string - c.SearchValues(false, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.SearchValues(false, nil, nil, func(id string, obj geojson.Object, fields field.List) bool { if n > 0 { expect(t, obj.String() > prevValue) } - expect(t, id == fmt.Sprintf("%04d", int(fields[1]))) + expect(t, id == fieldValueAt(fields, 1)) n++ prevValue = obj.String() return true }) expect(t, n == c.Count()) n = 0 - c.SearchValues(true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { + c.SearchValues(true, nil, nil, func(id string, obj geojson.Object, fields field.List) bool { if n > 0 { expect(t, obj.String() < prevValue) } - expect(t, id == fmt.Sprintf("%04d", int(fields[1]))) + expect(t, id == fieldValueAt(fields, 1)) n++ prevValue = obj.String() return true @@ -340,11 +388,11 @@ func TestCollectionSearch(t *testing.T) { n = 0 c.SearchValuesRange("0060", "0070", false, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { if n > 0 { expect(t, obj.String() > prevValue) } - expect(t, id == fmt.Sprintf("%04d", int(fields[1]))) + expect(t, id == fieldValueAt(fields, 1)) n++ prevValue = obj.String() return true @@ -353,11 +401,11 @@ func TestCollectionSearch(t *testing.T) { n = 0 c.SearchValuesRange("0070", "0060", true, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { if n > 0 { expect(t, obj.String() < prevValue) } - expect(t, id == fmt.Sprintf("%04d", int(fields[1]))) + expect(t, id == fieldValueAt(fields, 1)) n++ prevValue = obj.String() return true @@ -367,31 +415,39 @@ func TestCollectionSearch(t *testing.T) { func TestCollectionWeight(t *testing.T) { c := New() - c.Set("1", String("1"), nil, nil, 0) + c.Set("1", String("1"), field.List{}, 0) expect(t, c.TotalWeight() > 0) c.Delete("1") expect(t, c.TotalWeight() == 0) c.Set("1", String("1"), - []string{"a", "b", "c"}, - []float64{1, 2, 3}, + toFields( + []string{"a", "b", "c"}, + []string{"1", "2", "3"}, + ), 0, ) expect(t, c.TotalWeight() > 0) c.Delete("1") expect(t, c.TotalWeight() == 0) c.Set("1", String("1"), - []string{"a", "b", "c"}, - []float64{1, 2, 3}, + toFields( + []string{"a", "b", "c"}, + []string{"1", "2", "3"}, + ), 0, ) c.Set("2", String("2"), - []string{"d", "e", "f"}, - []float64{4, 5, 6}, + toFields( + []string{"d", "e", "f"}, + []string{"4", "5", "6"}, + ), 0, ) c.Set("1", String("1"), - []string{"d", "e", "f"}, - []float64{4, 5, 6}, + toFields( + []string{"d", "e", "f"}, + []string{"4", "5", "6"}, + ), 0, ) c.Delete("1") @@ -428,19 +484,19 @@ func TestSpatialSearch(t *testing.T) { q4, _ := geojson.Parse(gjson.Get(json, `features.#[id=="q4"]`).Raw, nil) c := New() - c.Set("p1", p1, nil, nil, 0) - c.Set("p2", p2, nil, nil, 0) - c.Set("p3", p3, nil, nil, 0) - c.Set("p4", p4, nil, nil, 0) - c.Set("r1", r1, nil, nil, 0) - c.Set("r2", r2, nil, nil, 0) - c.Set("r3", r3, nil, nil, 0) + c.Set("p1", p1, field.List{}, 0) + c.Set("p2", p2, field.List{}, 0) + c.Set("p3", p3, field.List{}, 0) + c.Set("p4", p4, field.List{}, 0) + c.Set("r1", r1, field.List{}, 0) + c.Set("r2", r2, field.List{}, 0) + c.Set("r3", r3, field.List{}, 0) var n int n = 0 c.Within(q1, 0, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, _ field.List) bool { n++ return true }, @@ -449,7 +505,7 @@ func TestSpatialSearch(t *testing.T) { n = 0 c.Within(q2, 0, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, _ field.List) bool { n++ return true }, @@ -458,7 +514,7 @@ func TestSpatialSearch(t *testing.T) { n = 0 c.Within(q3, 0, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, _ field.List) bool { n++ return true }, @@ -467,7 +523,7 @@ func TestSpatialSearch(t *testing.T) { n = 0 c.Intersects(q1, 0, nil, nil, - func(_ string, _ geojson.Object, _ []float64) bool { + func(_ string, _ geojson.Object, _ field.List) bool { n++ return true }, @@ -476,7 +532,7 @@ func TestSpatialSearch(t *testing.T) { n = 0 c.Intersects(q2, 0, nil, nil, - func(_ string, _ geojson.Object, _ []float64) bool { + func(_ string, _ geojson.Object, _ field.List) bool { n++ return true }, @@ -485,7 +541,7 @@ func TestSpatialSearch(t *testing.T) { n = 0 c.Intersects(q3, 0, nil, nil, - func(_ string, _ geojson.Object, _ []float64) bool { + func(_ string, _ geojson.Object, _ field.List) bool { n++ return true }, @@ -494,7 +550,7 @@ func TestSpatialSearch(t *testing.T) { n = 0 c.Intersects(q3, 0, nil, nil, - func(_ string, _ geojson.Object, _ []float64) bool { + func(_ string, _ geojson.Object, _ field.List) bool { n++ return n <= 1 }, @@ -509,7 +565,7 @@ func TestSpatialSearch(t *testing.T) { lastDist := float64(-1) distsMonotonic := true c.Nearby(q4, nil, nil, - func(id string, obj geojson.Object, fields []float64, dist float64) bool { + func(id string, obj geojson.Object, fields field.List, dist float64) bool { if dist < lastDist { distsMonotonic = false } @@ -534,12 +590,12 @@ func TestCollectionSparse(t *testing.T) { x := (r.Max.X-r.Min.X)*rand.Float64() + r.Min.X y := (r.Max.Y-r.Min.Y)*rand.Float64() + r.Min.Y point := PO(x, y) - c.Set(fmt.Sprintf("%d", i), point, nil, nil, 0) + c.Set(fmt.Sprintf("%d", i), point, field.List{}, 0) } var n int n = 0 c.Within(rect, 1, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { n++ return true }, @@ -548,7 +604,7 @@ func TestCollectionSparse(t *testing.T) { n = 0 c.Within(rect, 2, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { n++ return true }, @@ -557,7 +613,7 @@ func TestCollectionSparse(t *testing.T) { n = 0 c.Within(rect, 3, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { n++ return true }, @@ -566,7 +622,7 @@ func TestCollectionSparse(t *testing.T) { n = 0 c.Within(rect, 3, nil, nil, - func(id string, obj geojson.Object, fields []float64) bool { + func(id string, obj geojson.Object, fields field.List) bool { n++ return n <= 30 }, @@ -575,7 +631,7 @@ func TestCollectionSparse(t *testing.T) { n = 0 c.Intersects(rect, 3, nil, nil, - func(id string, _ geojson.Object, _ []float64) bool { + func(id string, _ geojson.Object, _ field.List) bool { n++ return true }, @@ -584,7 +640,7 @@ func TestCollectionSparse(t *testing.T) { n = 0 c.Intersects(rect, 3, nil, nil, - func(id string, _ geojson.Object, _ []float64) bool { + func(id string, _ geojson.Object, _ field.List) bool { n++ return n <= 30 }, @@ -626,7 +682,7 @@ func TestManyCollections(t *testing.T) { col = New() colsM[key] = col } - col.Set(id, obj, nil, nil, 0) + col.Set(id, obj, field.List{}, 0) k++ } } @@ -637,7 +693,7 @@ func TestManyCollections(t *testing.T) { Min: geometry.Point{X: -180, Y: 30}, Max: geometry.Point{X: 34, Y: 100}, } - col.geoSearch(bbox, func(id string, obj geojson.Object, fields []float64) bool { + col.geoSearch(bbox, func(id string, obj geojson.Object, fields field.List) bool { //println(id) return true }) @@ -646,15 +702,17 @@ func TestManyCollections(t *testing.T) { type testPointItem struct { id string object geojson.Object - fields []float64 + fields field.List } -func makeBenchFields(nFields int) []float64 { - if nFields == 0 { - return nil +func makeBenchFields(nFields int) field.List { + var fields field.List + for i := 0; i < nFields; i++ { + key := fmt.Sprintf("%d", i) + val := key + fields = fields.Set(field.Make(key, val)) } - - return make([]float64, nFields) + return fields } func BenchmarkInsert_Fields(t *testing.B) { @@ -678,7 +736,7 @@ func benchmarkInsert(t *testing.B, nFields int) { col := New() t.ResetTimer() for i := 0; i < t.N; i++ { - col.Set(items[i].id, items[i].object, nil, items[i].fields, 0) + col.Set(items[i].id, items[i].object, items[i].fields, 0) } } @@ -702,11 +760,11 @@ func benchmarkReplace(t *testing.B, nFields int) { } col := New() for i := 0; i < t.N; i++ { - col.Set(items[i].id, items[i].object, nil, items[i].fields, 0) + col.Set(items[i].id, items[i].object, items[i].fields, 0) } t.ResetTimer() for _, i := range rand.Perm(t.N) { - o, _, _ := col.Set(items[i].id, items[i].object, nil, nil, 0) + o, _, _ := col.Set(items[i].id, items[i].object, field.List{}, 0) if o != items[i].object { t.Fatal("shoot!") } @@ -733,7 +791,7 @@ func benchmarkGet(t *testing.B, nFields int) { } col := New() for i := 0; i < t.N; i++ { - col.Set(items[i].id, items[i].object, nil, items[i].fields, 0) + col.Set(items[i].id, items[i].object, items[i].fields, 0) } t.ResetTimer() for _, i := range rand.Perm(t.N) { @@ -764,7 +822,7 @@ func benchmarkRemove(t *testing.B, nFields int) { } col := New() for i := 0; i < t.N; i++ { - col.Set(items[i].id, items[i].object, nil, items[i].fields, 0) + col.Set(items[i].id, items[i].object, items[i].fields, 0) } t.ResetTimer() for _, i := range rand.Perm(t.N) { @@ -795,12 +853,12 @@ func benchmarkScan(t *testing.B, nFields int) { } col := New() for i := 0; i < t.N; i++ { - col.Set(items[i].id, items[i].object, nil, items[i].fields, 0) + col.Set(items[i].id, items[i].object, items[i].fields, 0) } t.ResetTimer() for i := 0; i < t.N; i++ { var scanIteration int - col.Scan(true, nil, nil, func(id string, obj geojson.Object, fields []float64) bool { + col.Scan(true, nil, nil, func(id string, obj geojson.Object, fields field.List) bool { scanIteration++ return scanIteration <= 500 }) diff --git a/internal/collection/fieldvalues.go b/internal/collection/fieldvalues.go deleted file mode 100644 index cfb1b4bb..00000000 --- a/internal/collection/fieldvalues.go +++ /dev/null @@ -1,53 +0,0 @@ -package collection - -type fieldValues struct { - freelist []fieldValuesSlot - data [][]float64 -} - -type fieldValuesSlot int - -const nilValuesSlot fieldValuesSlot = -1 - -func (f *fieldValues) get(k fieldValuesSlot) []float64 { - if k == nilValuesSlot { - return nil - } - return f.data[int(k)] -} - -func (f *fieldValues) set(k fieldValuesSlot, itemData []float64) fieldValuesSlot { - // if we're asked to store into the nil values slot, it means one of two things: - // - we are doing a replace on an item that previously had nil fields - // - we are inserting a new item - // in either case, check if the new values are not nil, and if so allocate a - // new slot - if k == nilValuesSlot { - if itemData == nil { - return nilValuesSlot - } - - // first check if there is a slot on the freelist to reuse - if len(f.freelist) > 0 { - var slot fieldValuesSlot - slot, f.freelist = f.freelist[len(f.freelist)-1], f.freelist[:len(f.freelist)-1] - f.data[slot] = itemData - return slot - } - - // no reusable slot, append - f.data = append(f.data, itemData) - return fieldValuesSlot(len(f.data) - 1) - - } - f.data[int(k)] = itemData - return k -} - -func (f *fieldValues) remove(k fieldValuesSlot) { - if k == nilValuesSlot { - return - } - f.data[int(k)] = nil - f.freelist = append(f.freelist, k) -} diff --git a/internal/field/field.go b/internal/field/field.go new file mode 100644 index 00000000..1e0a2789 --- /dev/null +++ b/internal/field/field.go @@ -0,0 +1,232 @@ +package field + +import ( + "math" + "strconv" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/pretty" +) + +var ZeroValue = Value{kind: Number, data: "0", num: 0} +var ZeroField = Field{name: "", value: ZeroValue} + +type Kind byte + +const ( + Null = Kind(gjson.Null) + False = Kind(gjson.False) + Number = Kind(gjson.Number) + String = Kind(gjson.String) + True = Kind(gjson.True) + JSON = Kind(gjson.JSON) +) + +type Value struct { + kind Kind + data string + num float64 +} + +func (v Value) IsZero() bool { + return (v.kind == Number && v.data == "0" && v.num == 0) || v == (Value{}) +} + +func (v Value) Equals(b Value) bool { + return v.kind == b.kind && v.data == b.data +} + +func (v Value) Kind() Kind { + return v.kind +} + +func (v Value) Data() string { + return v.data +} + +func (v Value) Num() float64 { + return v.num +} + +func (v Value) JSON() string { + switch v.Kind() { + case Number: + switch v.Data() { + case "NaN": + return `"NaN"` + case "+Inf": + return `"+Inf"` + case "-Inf": + return `"-Inf"` + default: + return v.Data() + } + case String: + return string(gjson.AppendJSONString(nil, v.Data())) + case True: + return "true" + case False: + return "false" + case Null: + if v != (Value{}) { + return "null" + } + case JSON: + return v.Data() + } + return "0" +} + +func stringLessInsensitive(a, b string) bool { + for i := 0; i < len(a) && i < len(b); i++ { + if a[i] >= 'A' && a[i] <= 'Z' { + if b[i] >= 'A' && b[i] <= 'Z' { + // both are uppercase, do nothing + if a[i] < b[i] { + return true + } else if a[i] > b[i] { + return false + } + } else { + // a is uppercase, convert a to lowercase + if a[i]+32 < b[i] { + return true + } else if a[i]+32 > b[i] { + return false + } + } + } else if b[i] >= 'A' && b[i] <= 'Z' { + // b is uppercase, convert b to lowercase + if a[i] < b[i]+32 { + return true + } else if a[i] > b[i]+32 { + return false + } + } else { + // neither are uppercase + if a[i] < b[i] { + return true + } else if a[i] > b[i] { + return false + } + } + } + return len(a) < len(b) +} + +// Less return true if a value is less than another value. +// The caseSensitive paramater is used when the value are Strings. +// The order when comparing two different kinds is: +// +// Null < False < Number < String < True < JSON +// +// Pulled from github.com/tidwall/gjson +func (v Value) LessCase(b Value, caseSensitive bool) bool { + if v.kind < b.kind { + return true + } + if v.kind > b.kind { + return false + } + if v.kind == Number { + return v.num < b.num + } + if v.kind == String { + if caseSensitive { + return v.data < b.data + } + return stringLessInsensitive(v.data, b.data) + } + return v.data < b.data +} + +// Less return true if a value is less than another value. +// +// Null < False < Number < String < True < JSON +// +// Pulled from github.com/tidwall/gjson +func (v Value) Less(b Value) bool { + return v.LessCase(b, false) +} + +type Field struct { + name string + value Value +} + +func (f Field) Name() string { + return f.name +} + +func (f Field) Value() Value { + return f.value +} + +func (f Field) Weight() int { + return len(f.name) + 8 + len(f.value.data) +} + +var nan = math.NaN() +var pinf = math.Inf(+1) +var ninf = math.Inf(-1) + +func ValueOf(data string) Value { + data = strings.TrimSpace(data) + num, err := strconv.ParseFloat(data, 64) + if err == nil { + if math.IsInf(num, 0) { + if math.IsInf(num, +1) { + return Value{kind: Number, data: "+Inf", num: pinf} + } else { + return Value{kind: Number, data: "-Inf", num: ninf} + } + } else if math.IsNaN(num) { + return Value{kind: Number, data: "NaN", num: nan} + } + return Value{kind: Number, data: data, num: num} + } + if gjson.Valid(data) { + data = strings.TrimSpace(data) + r := gjson.Parse(data) + switch r.Type { + case gjson.Null: + return Value{kind: Null, data: "null"} + case gjson.JSON: + return Value{kind: JSON, data: string(pretty.Ugly([]byte(data)))} + case gjson.True: + return Value{kind: True, data: "true"} + case gjson.False: + return Value{kind: False, data: "false"} + case gjson.Number: + // Ignore. Numbers will always be picked up by the ParseFloat above. + case gjson.String: + // Ignore. Strings fallthrough by default + } + // Extract String from JSON + data = r.String() + } + // Check if string is NaN, Inf(inity), +Inf(inity), -Inf(inity) + if len(data) >= 3 && len(data) <= 9 { + switch data[0] { + case '-', '+', 'I', 'i', 'N', 'n': + switch strings.ToLower(data) { + case "nan": + return Value{kind: Number, data: "NaN", num: nan} + case "inf", "+inf", "infinity", "+infinity": + return Value{kind: Number, data: "+Inf", num: pinf} + case "-inf", "-infinity": + return Value{kind: Number, data: "-Inf", num: ninf} + } + } + } + + return Value{kind: String, data: data} +} + +func Make(name, data string) Field { + return Field{ + strings.ToLower(strings.TrimSpace(name)), + ValueOf(data), + } +} diff --git a/internal/field/field_test.go b/internal/field/field_test.go new file mode 100644 index 00000000..84b25ef1 --- /dev/null +++ b/internal/field/field_test.go @@ -0,0 +1,134 @@ +package field + +import ( + "testing" + + "github.com/tidwall/assert" +) + +func mLT(a, b Value) bool { return a.Less(b) } +func mLTE(a, b Value) bool { return !mLT(b, a) } +func mGT(a, b Value) bool { return mLT(b, a) } +func mGTE(a, b Value) bool { return !mLT(a, b) } +func mEQ(a, b Value) bool { return !mLT(a, b) && !mLT(b, a) } + +func TestOrder(t *testing.T) { + assert.Assert(mLT(ValueOf("hello"), ValueOf("jello"))) + assert.Assert(mLT(ValueOf("hello"), ValueOf("JELLO"))) + assert.Assert(mLT(ValueOf("HELLO"), ValueOf("JELLO"))) + assert.Assert(mLT(ValueOf("HELLO"), ValueOf("jello"))) + assert.Assert(!mLT(ValueOf("hello"), ValueOf("hello"))) + assert.Assert(!mLT(ValueOf("jello"), ValueOf("hello"))) + assert.Assert(!mLT(ValueOf("Jello"), ValueOf("Hello"))) + assert.Assert(!mLT(ValueOf("Jello"), ValueOf("hello"))) + assert.Assert(!mLT(ValueOf("jello"), ValueOf("Hello"))) + assert.Assert(mGT(ValueOf("jello"), ValueOf("hello"))) + assert.Assert(!mGT(ValueOf("jello"), ValueOf("jello"))) + assert.Assert(!mGT(ValueOf("hello"), ValueOf("jello"))) + assert.Assert(mLTE(ValueOf("hello"), ValueOf("jello"))) + assert.Assert(mLTE(ValueOf("hello"), ValueOf("hello"))) + assert.Assert(mLTE(ValueOf("hello"), ValueOf("HELLO"))) + assert.Assert(!mLTE(ValueOf("jello"), ValueOf("hello"))) + assert.Assert(mGTE(ValueOf("jello"), ValueOf("jello"))) + assert.Assert(mGTE(ValueOf("jello"), ValueOf("hello"))) + assert.Assert(mGTE(ValueOf("jello"), ValueOf("JELLO"))) + assert.Assert(!mGTE(ValueOf("hello"), ValueOf("jello"))) + assert.Assert(mEQ(ValueOf("jello"), ValueOf("jello"))) + assert.Assert(mEQ(ValueOf("jello"), ValueOf("JELLO"))) + assert.Assert(!mEQ(ValueOf("jello"), ValueOf("hello"))) +} + +func TestLess(t *testing.T) { + assert.Assert(mLT(ValueOf("null"), ValueOf("false"))) + assert.Assert(mLT(ValueOf("null"), ValueOf("123"))) + assert.Assert(mLT(ValueOf("null"), ValueOf("hello"))) + assert.Assert(mLT(ValueOf("null"), ValueOf("true"))) + assert.Assert(mLT(ValueOf("null"), ValueOf("[]"))) + assert.Assert(mLT(ValueOf("false"), ValueOf("123"))) + assert.Assert(mLT(ValueOf("false"), ValueOf("hello"))) + assert.Assert(mLT(ValueOf("false"), ValueOf("true"))) + assert.Assert(mLT(ValueOf("false"), ValueOf("[]"))) + assert.Assert(mLT(ValueOf("123"), ValueOf("hello"))) + assert.Assert(mLT(ValueOf("123"), ValueOf("true"))) + assert.Assert(mLT(ValueOf("123"), ValueOf("[]"))) + assert.Assert(mLT(ValueOf("hello"), ValueOf("true"))) + assert.Assert(mLT(ValueOf("hello"), ValueOf("[]"))) + assert.Assert(mLT(ValueOf("true"), ValueOf("[]"))) + assert.Assert(!mLT(ValueOf("false"), ValueOf("null"))) + assert.Assert(!mLT(ValueOf("123"), ValueOf("null"))) + assert.Assert(!mLT(ValueOf("hello"), ValueOf("null"))) + assert.Assert(!mLT(ValueOf("true"), ValueOf("null"))) + assert.Assert(!mLT(ValueOf("[]"), ValueOf("null"))) + assert.Assert(!mLT(ValueOf("123"), ValueOf("false"))) + assert.Assert(!mLT(ValueOf("hello"), ValueOf("false"))) + assert.Assert(!mLT(ValueOf("true"), ValueOf("false"))) + assert.Assert(!mLT(ValueOf("[]"), ValueOf("false"))) + assert.Assert(!mLT(ValueOf("hello"), ValueOf("123"))) + assert.Assert(!mLT(ValueOf("true"), ValueOf("123"))) + assert.Assert(!mLT(ValueOf("[]"), ValueOf("123"))) + assert.Assert(!mLT(ValueOf("true"), ValueOf("hello"))) + assert.Assert(!mLT(ValueOf("[]"), ValueOf("hello"))) + assert.Assert(!mLT(ValueOf("[]"), ValueOf("true"))) + assert.Assert(mLT(ValueOf("123"), ValueOf("456"))) + assert.Assert(mLT(ValueOf("[1]"), ValueOf("[2]"))) +} + +func TestLessCase(t *testing.T) { + assert.Assert(ValueOf("A").LessCase(ValueOf("B"), true)) + assert.Assert(!ValueOf("A").LessCase(ValueOf("A"), true)) + assert.Assert(!ValueOf("B").LessCase(ValueOf("A"), true)) +} + +func TestVarious(t *testing.T) { + assert.Assert(!ValueOf("A").IsZero()) + assert.Assert(ValueOf("0").IsZero()) + assert.Assert(Value{}.IsZero()) + assert.Assert(ZeroValue.IsZero()) + assert.Assert(ZeroValue.Equals(ZeroValue)) + assert.Assert(ZeroValue.Kind() == Number) + assert.Assert(ValueOf("0").Kind() == Number) + assert.Assert(ValueOf("hello").Kind() == String) + assert.Assert(ValueOf(`"hello"`).Kind() == String) + assert.Assert(ValueOf(`"123"`).Kind() == String) + assert.Assert(ValueOf(`"123"`).Data() == `123`) + assert.Assert(ValueOf(`"123"`).Num() == 0) +} + +func TestJSON(t *testing.T) { + assert.Assert(ValueOf(`A`).JSON() == `"A"`) + assert.Assert(ValueOf(`"A"`).JSON() == `"A"`) + assert.Assert(ValueOf(`123`).JSON() == `123`) + assert.Assert(ValueOf(`{}`).JSON() == `{}`) + assert.Assert(ValueOf(`{ }`).JSON() == `{}`) + assert.Assert(ValueOf(` -Inf `).JSON() == `"-Inf"`) + assert.Assert(ValueOf(` "-Inf" `).JSON() == `"-Inf"`) + assert.Assert(ValueOf(`+Inf`).JSON() == `"+Inf"`) + assert.Assert(ValueOf(`"+Inf"`).JSON() == `"+Inf"`) + assert.Assert(ValueOf(`Inf`).JSON() == `"+Inf"`) + assert.Assert(ValueOf(`"Inf"`).JSON() == `"+Inf"`) + assert.Assert(ValueOf(`NaN`).JSON() == `"NaN"`) + assert.Assert(ValueOf(`"NaN"`).JSON() == `"NaN"`) + assert.Assert(ValueOf(`nan`).JSON() == `"NaN"`) + assert.Assert(ValueOf(`infinity`).JSON() == `"+Inf"`) + assert.Assert(ValueOf(` true `).JSON() == `true`) + assert.Assert(ValueOf(` false `).JSON() == `false`) + assert.Assert(ValueOf(` null `).JSON() == `null`) + assert.Assert(Value{}.JSON() == `0`) + assert.Assert(Value{}.JSON() == `0`) +} + +func TestField(t *testing.T) { + assert.Assert(Make("hello", "123").Name() == "hello") + assert.Assert(Make("HELLO", "123").Name() == "hello") + assert.Assert(Make("HELLO", "123").Value().Num() == 123) + assert.Assert(Make("HELLO", "123").Value().JSON() == "123") + assert.Assert(Make("HELLO", "123").Value().Num() == 123) +} + +func TestWeight(t *testing.T) { + assert.Assert(Make("hello", "123").Weight() == 16) +} + +func TestNumber(t *testing.T) { + assert.Assert(ValueOf("012").Num() == 12) +} diff --git a/internal/field/list_array.go b/internal/field/list_array.go new file mode 100644 index 00000000..f911db3e --- /dev/null +++ b/internal/field/list_array.go @@ -0,0 +1,84 @@ +//go:build exclude + +package field + +type List struct { + entries []Field +} + +// bsearch searches array for value. +func (fields List) bsearch(name string) (index int, found bool) { + i, j := 0, len(fields.entries) + for i < j { + h := i + (j-i)/2 + if name >= fields.entries[h].name { + i = h + 1 + } else { + j = h + } + } + if i > 0 && fields.entries[i-1].name >= name { + return i - 1, true + } + return i, false +} + +func (fields List) Set(field Field) List { + var updated List + index, found := fields.bsearch(field.name) + if found { + if field.value.IsZero() { + // delete + if len(fields.entries) > 1 { + updated.entries = make([]Field, len(fields.entries)-1) + copy(updated.entries, fields.entries[:index]) + copy(updated.entries[index:], fields.entries[index+1:]) + } + } else if !fields.entries[index].value.Equals(field.value) { + // update + updated.entries = make([]Field, len(fields.entries)) + copy(updated.entries, fields.entries) + updated.entries[index].value = field.value + } else { + // nothing changes + updated = fields + } + return updated + } + if field.Value().IsZero() { + return fields + } + updated.entries = make([]Field, len(fields.entries)+1) + copy(updated.entries, fields.entries[:index]) + copy(updated.entries[index+1:], fields.entries[index:]) + updated.entries[index] = field + return updated +} + +func (fields List) Get(name string) Field { + index, found := fields.bsearch(name) + if !found { + return ZeroField + } + return fields.entries[index] +} + +func (fields List) Scan(iter func(field Field) bool) { + for _, f := range fields.entries { + if !iter(f) { + return + } + } +} + +func (fields List) Len() int { + return len(fields.entries) +} + +func (fields List) Weight() int { + var weight int + for _, f := range fields.entries { + weight += f.Weight() + } + return weight +} diff --git a/internal/field/list_binary.go b/internal/field/list_binary.go new file mode 100644 index 00000000..f41bcf42 --- /dev/null +++ b/internal/field/list_binary.go @@ -0,0 +1,349 @@ +package field + +import ( + "encoding/binary" + "strconv" + "unsafe" + + "github.com/tidwall/tile38/internal/sstring" +) + +// binary format +// (size,entry,[entry...]) +// size: uvarint -- size of the full byte slice, excluding itself. +// entry: (name,value) -- one field entry +// name: shared string num -- field name, string data, uses the shared library +// size: uvarint -- number of bytes in data +// value: (kind,vdata) -- field value +// kind: byte -- value kind +// vdata: (size,data) -- value data, string data + +// useSharedNames will results in smaller memory usage by sharing the names +// of fields using the sstring package. Otherwise the names are embeded with +// the list. +const useSharedNames = true + +// List of fields, ordered by Name. +type List struct { + p *byte +} + +type bytes struct { + p *byte + l int + c int +} + +func ptob(p *byte) []byte { + if p == nil { + return nil + } + // Get the size of the bytes (excluding the header) + x, n := uvarint(*(*[]byte)(unsafe.Pointer(&bytes{p, 10, 10}))) + // Return the byte slice (excluding the header) + return (*(*[]byte)(unsafe.Pointer(&bytes{p, n + x, n + x})))[n:] +} + +func btoa(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// uvarint is a slightly modified version of binary.Uvarint, and it's a little +// faster. But it lacks overflow checks which are not needed for our use. +func uvarint(buf []byte) (int, int) { + var x uint64 + for i := 0; i < len(buf); i++ { + b := buf[i] + if b < 0x80 { + return int(x | uint64(b)<<(i*7)), i + 1 + } + x |= uint64(b&0x7f) << (i * 7) + } + return 0, 0 +} + +func datakind(kind Kind) bool { + switch kind { + case Number, String, JSON: + return true + } + return false +} + +func bfield(name string, kind Kind, data string) Field { + var num float64 + switch kind { + case Number: + num, _ = strconv.ParseFloat(data, 64) + case Null: + data = "null" + case False: + data = "false" + case True: + data = "true" + } + return Field{ + name: name, + value: Value{ + kind: Kind(kind), + data: data, + num: num, + }, + } +} + +// Set a field in the list. +// If the input field value is zero `f.Value().IsZero()` then the field is +// deleted or removed from the list since lists cannot have Zero values. +// Returns a newly allocated list the updated field. +// The original (receiver) list is not modified. +func (fields List) Set(field Field) List { + b := ptob(fields.p) + var i int + for { + s := i + // read the name + var name string + x, n := uvarint(b[i:]) + if n == 0 { + break + } + if useSharedNames { + name = sstring.Load(x) + i += n + } else { + name = btoa(b[i+n : i+n+x]) + i += n + x + } + kind := Kind(b[i]) + i++ + var data string + if datakind(kind) { + x, n = uvarint(b[i:]) + data = btoa(b[i+n : i+n+x]) + i += n + x + } + if field.name < name { + // insert before + i = s + break + } + if name == field.name { + if field.Value().IsZero() { + // delete + return List{delfield(b, s, i)} + } + prev := bfield(name, kind, data) + if prev.Value().Equals(field.Value()) { + // no change + return fields + } + // replace + return List{putfield(b, field, s, i)} + } + } + if field.Value().IsZero() { + return fields + } + // insert after + return List{putfield(b, field, i, i)} +} + +func delfield(b []byte, s, e int) *byte { + totallen := s + (len(b) - e) + var psz [10]byte + pn := binary.PutUvarint(psz[:], uint64(totallen)) + plen := pn + totallen + p := make([]byte, plen) + // copy each component + i := 0 + + // -- header size + copy(p[i:], psz[:pn]) + i += pn + + // -- head entries + copy(p[i:], b[:s]) + i += s + + // -- tail entries + copy(p[i:], b[e:]) + + return &p[0] +} + +func putfield(b []byte, f Field, s, e int) *byte { + name := f.Name() + var namesz [10]byte + var namen int + if useSharedNames { + num := sstring.Store(name) + namen = binary.PutUvarint(namesz[:], uint64(num)) + } else { + namen = binary.PutUvarint(namesz[:], uint64(len(name))) + } + value := f.Value() + kind := value.Kind() + isdatakind := datakind(kind) + var data string + var datasz [10]byte + var datan int + if isdatakind { + data = value.Data() + datan = binary.PutUvarint(datasz[:], uint64(len(data))) + } + var totallen int + if useSharedNames { + totallen = s + namen + 1 + (len(b) - e) + } else { + totallen = s + namen + len(name) + 1 + +(len(b) - e) + } + if isdatakind { + totallen += datan + len(data) + } + var psz [10]byte + pn := binary.PutUvarint(psz[:], uint64(totallen)) + plen := pn + totallen + p := make([]byte, plen) + + // copy each component + i := 0 + + // -- header size + copy(p[i:], psz[:pn]) + i += pn + + // -- head entries + copy(p[i:], b[:s]) + i += s + + // -- name + copy(p[i:], namesz[:namen]) + i += namen + + if !useSharedNames { + copy(p[i:], name) + i += len(name) + } + + // -- kind + p[i] = byte(kind) + i++ + + if isdatakind { + // -- data + copy(p[i:], datasz[:datan]) + i += datan + + copy(p[i:], data) + i += len(data) + } + + // -- tail entries + copy(p[i:], b[e:]) + + return &p[0] +} + +// Get a field from the list. Or returns ZeroField if not found. +func (fields List) Get(name string) Field { + b := ptob(fields.p) + var i int + for { + // read the fname + var fname string + x, n := uvarint(b[i:]) + if n == 0 { + break + } + if useSharedNames { + fname = sstring.Load(x) + i += n + } else { + fname = btoa(b[i+n : i+n+x]) + i += n + x + } + kind := Kind(b[i]) + i++ + var data string + if datakind(kind) { + x, n = uvarint(b[i:]) + data = btoa(b[i+n : i+n+x]) + i += n + x + } + if name < fname { + break + } + if fname == name { + return bfield(fname, kind, data) + } + } + return ZeroField +} + +// Scan each field in list +func (fields List) Scan(iter func(field Field) bool) { + b := ptob(fields.p) + var i int + for { + // read the fname + var fname string + x, n := uvarint(b[i:]) + if n == 0 { + break + } + if useSharedNames { + fname = sstring.Load(x) + i += n + } else { + fname = btoa(b[i+n : i+n+x]) + i += n + x + } + kind := Kind(b[i]) + i++ + var data string + if datakind(kind) { + x, n = uvarint(b[i:]) + data = btoa(b[i+n : i+n+x]) + i += n + x + } + if !iter(bfield(fname, kind, data)) { + return + } + } +} + +// Len return the number of fields in list. +func (fields List) Len() int { + var count int + b := ptob(fields.p) + var i int + for { + x, n := uvarint(b[i:]) + if n == 0 { + break + } + if useSharedNames { + i += n + } else { + i += n + x + } + isdatakind := datakind(Kind(b[i])) + i++ + if isdatakind { + x, n = uvarint(b[i:]) + i += n + x + } + count++ + } + return count +} + +// Weight is the number of bytes of the list. +func (fields List) Weight() int { + if fields.p == nil { + return 0 + } + x, n := uvarint(*(*[]byte)(unsafe.Pointer(&bytes{fields.p, 10, 10}))) + return x + n +} diff --git a/internal/field/list_test.go b/internal/field/list_test.go new file mode 100644 index 00000000..749005dc --- /dev/null +++ b/internal/field/list_test.go @@ -0,0 +1,181 @@ +package field + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/tidwall/assert" + "github.com/tidwall/btree" +) + +func TestList(t *testing.T) { + var fields List + + fields = fields.Set(Make("hello", "123")) + assert.Assert(fields.Len() == 1) + // println(fields.Weight()) + // assert.Assert(fields.Weight() == 16) + + fields = fields.Set(Make("jello", "456")) + assert.Assert(fields.Len() == 2) + // assert.Assert(fields.Weight() == 32) + + value := fields.Get("jello") + assert.Assert(value.Value().Data() == "456") + assert.Assert(value.Value().JSON() == "456") + assert.Assert(value.Value().Num() == 456) + + value = fields.Get("nello") + assert.Assert(value.Name() == "") + assert.Assert(value.Value().IsZero()) + + fields = fields.Set(Make("jello", "789")) + assert.Assert(fields.Len() == 2) + // assert.Assert(fields.Weight() == 32) + + fields = fields.Set(Make("nello", "0")) + assert.Assert(fields.Len() == 2) + // assert.Assert(fields.Weight() == 32) + + fields = fields.Set(Make("jello", "789")) + assert.Assert(fields.Len() == 2) + // assert.Assert(fields.Weight() == 32) + + fields = fields.Set(Make("jello", "0")) + assert.Assert(fields.Len() == 1) + // assert.Assert(fields.Weight() == 16) + + fields = fields.Set(Make("nello", "012")) + fields = fields.Set(Make("hello", "456")) + fields = fields.Set(Make("fello", "123")) + fields = fields.Set(Make("jello", "789")) + + var names string + var datas string + var nums float64 + fields.Scan(func(f Field) bool { + names += f.Name() + datas += f.Value().Data() + nums += f.Value().Num() + return true + }) + assert.Assert(names == "fellohellojellonello") + assert.Assert(datas == "123456789012") + assert.Assert(nums == 1380) + + names = "" + datas = "" + nums = 0 + fields.Scan(func(f Field) bool { + names += f.Name() + datas += f.Value().Data() + nums += f.Value().Num() + return false + }) + assert.Assert(names == "fello") + assert.Assert(datas == "123") + assert.Assert(nums == 123) + +} + +func randStr(n int) string { + b := make([]byte, n) + rand.Read(b) + for i := 0; i < n; i++ { + b[i] = 'a' + b[i]%26 + } + return string(b) +} + +func randVal(n int) string { + switch rand.Intn(10) { + case 0: + return "null" + case 1: + return "true" + case 2: + return "false" + case 3: + return `{"a":"` + randStr(n) + `"}` + case 4: + return `["` + randStr(n) + `"]` + case 5: + return `"` + randStr(n) + `"` + case 6: + return randStr(n) + default: + return fmt.Sprintf("%f", rand.Float64()*360) + } +} + +func TestRandom(t *testing.T) { + seed := time.Now().UnixNano() + // seed = 1663607868546669000 + rand.Seed(seed) + start := time.Now() + var total int + for time.Since(start) < time.Second*2 { + N := rand.Intn(500) + var org []Field + var tr btree.Map[string, Field] + var fields List + for i := 0; i < N; i++ { + name := randStr(rand.Intn(10)) + value := randVal(rand.Intn(10)) + field := Make(name, value) + org = append(org, field) + fields = fields.Set(field) + v := fields.Get(name) + // println(name, v.Value().Data(), field.Value().Data()) + if !v.Value().Equals(field.Value()) { + t.Fatalf("seed: %d, expected true", seed) + } + tr.Set(name, field) + if fields.Len() != tr.Len() { + t.Fatalf("seed: %d, expected %d, got %d", + seed, tr.Len(), fields.Len()) + } + } + comp := func() { + var all []Field + fields.Scan(func(f Field) bool { + all = append(all, f) + return true + }) + if len(all) != fields.Len() { + t.Fatalf("seed: %d, expected %d, got %d", + seed, fields.Len(), len(all)) + } + if fields.Len() != tr.Len() { + t.Fatalf("seed: %d, expected %d, got %d", + seed, tr.Len(), fields.Len()) + } + var i int + tr.Scan(func(name string, f Field) bool { + if name != f.Name() || all[i].Name() != f.Name() { + t.Fatalf("seed: %d, out of order", seed) + } + i++ + return true + }) + } + comp() + rand.Shuffle(len(org), func(i, j int) { + org[i], org[j] = org[j], org[i] + }) + for _, f := range org { + comp() + tr.Delete(f.Name()) + fields = fields.Set(Make(f.Name(), "0")) + if fields.Len() != tr.Len() { + t.Fatalf("seed: %d, expected %d, got %d", + seed, tr.Len(), fields.Len()) + } + comp() + } + total++ + } + +} diff --git a/internal/server/aof.go b/internal/server/aof.go index 47e92f40..f90ab6e6 100644 --- a/internal/server/aof.go +++ b/internal/server/aof.go @@ -294,7 +294,6 @@ func (s *Server) queueHooks(d *commandDetails) error { for _, hook := range candidates { // Calculate all matching fence messages for all candidates and append // them to the appropriate message slice - hook.ScanWriter.loadWheres() msgs := FenceMatch(hook.Name, hook.ScanWriter, hook.Fence, hook.Metas, d) if len(msgs) > 0 { if hook.channel { diff --git a/internal/server/aofshrink.go b/internal/server/aofshrink.go index f467563e..980f0dc4 100644 --- a/internal/server/aofshrink.go +++ b/internal/server/aofshrink.go @@ -11,6 +11,7 @@ import ( "github.com/tidwall/geojson" "github.com/tidwall/tile38/core" "github.com/tidwall/tile38/internal/collection" + "github.com/tidwall/tile38/internal/field" "github.com/tidwall/tile38/internal/log" ) @@ -93,12 +94,10 @@ func (s *Server) aofshrink() { if !ok { return } - var fnames = col.FieldArr() // reload an array of field names to match each object - var fmap = col.FieldMap() // var now = time.Now().UnixNano() // used for expiration var count = 0 // the object count col.ScanGreaterOrEqual(nextid, false, nil, nil, - func(id string, obj geojson.Object, fields []float64, ex int64) bool { + func(id string, obj geojson.Object, fields field.List, ex int64) bool { if count == maxids { // we reached the max number of ids for one batch nextid = id @@ -110,16 +109,14 @@ func (s *Server) aofshrink() { values = append(values, "set") values = append(values, keys[0]) values = append(values, id) - if len(fields) > 0 { - fvs := orderFields(fmap, fnames, fields) - for _, fv := range fvs { - if fv.value != 0 { - values = append(values, "field") - values = append(values, fv.field) - values = append(values, strconv.FormatFloat(fv.value, 'f', -1, 64)) - } + fields.Scan(func(f field.Field) bool { + if !f.Value().IsZero() { + values = append(values, "field") + values = append(values, f.Name()) + values = append(values, f.Value().JSON()) } - } + return true + }) if ex != 0 { ttl := math.Floor(float64(ex-now)/float64(time.Second)*10) / 10 if ttl < 0.1 { diff --git a/internal/server/crud.go b/internal/server/crud.go index e1fb64fa..4dd33b9e 100644 --- a/internal/server/crud.go +++ b/internal/server/crud.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "errors" "strconv" "strings" "time" @@ -11,30 +12,32 @@ import ( "github.com/tidwall/geojson/geometry" "github.com/tidwall/resp" "github.com/tidwall/tile38/internal/collection" + "github.com/tidwall/tile38/internal/field" "github.com/tidwall/tile38/internal/glob" ) -type fvt struct { - field string - value float64 -} +// type fvt struct { +// field string +// value float64 +// } + +// func orderFields(fmap map[string]int, farr []string, fields []float64) []fvt { +// var fv fvt +// var idx int +// fvs := make([]fvt, 0, len(fmap)) +// for _, field := range farr { +// idx = fmap[field] +// if idx < len(fields) { +// fv.field = field +// fv.value = fields[idx] +// if fv.value != 0 { +// fvs = append(fvs, fv) +// } +// } +// } +// return fvs +// } -func orderFields(fmap map[string]int, farr []string, fields []float64) []fvt { - var fv fvt - var idx int - fvs := make([]fvt, 0, len(fmap)) - for _, field := range farr { - idx = fmap[field] - if idx < len(fields) { - fv.field = field - fv.value = fields[idx] - if fv.value != 0 { - fvs = append(fvs, fv) - } - } - } - return fvs -} func (s *Server) cmdBounds(msg *Message) (resp.Value, error) { start := time.Now() vs := msg.Args[1:] @@ -236,23 +239,26 @@ func (s *Server) cmdGet(msg *Message) (resp.Value, error) { return NOMessage, errInvalidNumberOfArguments } if withfields { - fvs := orderFields(col.FieldMap(), col.FieldArr(), fields) - if len(fvs) > 0 { - fvals := make([]resp.Value, 0, len(fvs)*2) + nfields := fields.Len() + if nfields > 0 { + fvals := make([]resp.Value, 0, nfields*2) if msg.OutputType == JSON { buf.WriteString(`,"fields":{`) } - for i, fv := range fvs { + var i int + fields.Scan(func(f field.Field) bool { if msg.OutputType == JSON { if i > 0 { buf.WriteString(`,`) } - buf.WriteString(jsonString(fv.field) + ":" + strconv.FormatFloat(fv.value, 'f', -1, 64)) + buf.WriteString(jsonString(f.Name()) + ":" + f.Value().JSON()) } else { - fvals = append(fvals, resp.StringValue(fv.field), resp.StringValue(strconv.FormatFloat(fv.value, 'f', -1, 64))) + fvals = append(fvals, + resp.StringValue(f.Name()), resp.StringValue(f.Value().Data())) } i++ - } + return true + }) if msg.OutputType == JSON { buf.WriteString(`}`) } else { @@ -354,7 +360,7 @@ func (s *Server) cmdPdel(msg *Message) (res resp.Value, d commandDetails, err er return } now := time.Now() - iter := func(id string, o geojson.Object, fields []float64) bool { + iter := func(id string, o geojson.Object, fields field.List) bool { if match, _ := glob.Match(d.pattern, id); match { d.children = append(d.children, &commandDetails{ command: "del", @@ -513,7 +519,7 @@ func (s *Server) cmdRename(msg *Message) (res resp.Value, d commandDetails, err return } -func (s *Server) cmdFlushDB(msg *Message) (res resp.Value, d commandDetails, err error) { +func (s *Server) cmdFLUSHDB(msg *Message) (res resp.Value, d commandDetails, err error) { start := time.Now() vs := msg.Args[1:] if len(vs) != 0 { @@ -543,424 +549,347 @@ func (s *Server) cmdFlushDB(msg *Message) (res resp.Value, d commandDetails, err return } -func (s *Server) parseSetArgs(vs []string) ( - d commandDetails, fields []string, values []float64, - xx, nx bool, - ex int64, etype []byte, evs []string, err error, -) { - var ok bool - var typ []byte - if vs, d.key, ok = tokenval(vs); !ok || d.key == "" { - err = errInvalidNumberOfArguments - return +// SET key id [FIELD name value ...] [EX seconds] [NX|XX] +// (OBJECT geojson)|(POINT lat lon z)|(BOUNDS minlat minlon maxlat maxlon)|(HASH geohash)|(STRING value) +func (s *Server) cmdSET(msg *Message) (resp.Value, commandDetails, error) { + start := time.Now() + if s.config.maxMemory() > 0 && s.outOfMemory.on() { + return retwerr(errOOM) } - if vs, d.id, ok = tokenval(vs); !ok || d.id == "" { - err = errInvalidNumberOfArguments - return + + // >> Args + + var key string + var id string + var fields []field.Field + var ex int64 + var xx bool + var nx bool + var obj geojson.Object + + args := msg.Args + if len(args) < 3 { + return retwerr(errInvalidNumberOfArguments) } - var arg []byte - var nvs []string - for { - if nvs, arg, ok = tokenvalbytes(vs); !ok || len(arg) == 0 { - err = errInvalidNumberOfArguments - return - } - if lcb(arg, "field") { - vs = nvs - var name string - var svalue string - var value float64 - if vs, name, ok = tokenval(vs); !ok || name == "" { - err = errInvalidNumberOfArguments - return + + key, id = args[1], args[2] + + for i := 3; i < len(args); i++ { + switch strings.ToLower(args[i]) { + case "field": + if i+2 >= len(args) { + return retwerr(errInvalidNumberOfArguments) } - if isReservedFieldName(name) { - err = errInvalidArgument(name) - return + fkey := strings.ToLower(args[i+1]) + fval := args[i+2] + i += 2 + if isReservedFieldName(fkey) { + return retwerr(errInvalidArgument(fkey)) } - if vs, svalue, ok = tokenval(vs); !ok || svalue == "" { - err = errInvalidNumberOfArguments - return + fields = append(fields, field.Make(fkey, fval)) + case "ex": + if i+1 >= len(args) { + return retwerr(errInvalidNumberOfArguments) } - value, err = strconv.ParseFloat(svalue, 64) + exval := args[i+1] + i += 1 + x, err := strconv.ParseFloat(exval, 64) if err != nil { - err = errInvalidArgument(svalue) - return + return retwerr(errInvalidArgument(exval)) } - fields = append(fields, name) - values = append(values, value) - continue - } - if lcb(arg, "ex") { - vs = nvs - if ex != 0 { - err = errInvalidArgument(string(arg)) - return - } - var s string - var v float64 - if vs, s, ok = tokenval(vs); !ok || s == "" { - err = errInvalidNumberOfArguments - return - } - v, err = strconv.ParseFloat(s, 64) - if err != nil { - err = errInvalidArgument(s) - return - } - ex = time.Now().UnixNano() + int64(float64(time.Second)*v) - continue - } - if lcb(arg, "xx") { - vs = nvs - if nx { - err = errInvalidArgument(string(arg)) - return - } - xx = true - continue - } - if lcb(arg, "nx") { - vs = nvs + ex = time.Now().UnixNano() + int64(float64(time.Second)*x) + case "nx": if xx { - err = errInvalidArgument(string(arg)) - return + return retwerr(errInvalidArgument(args[i])) } nx = true - continue - } - break - } - if vs, typ, ok = tokenvalbytes(vs); !ok || len(typ) == 0 { - err = errInvalidNumberOfArguments - return - } - if len(vs) == 0 { - err = errInvalidNumberOfArguments - return - } - etype = typ - evs = vs - switch { - default: - err = errInvalidArgument(string(typ)) - return - case lcb(typ, "string"): - var str string - if vs, str, ok = tokenval(vs); !ok { - err = errInvalidNumberOfArguments - return - } - d.obj = collection.String(str) - case lcb(typ, "point"): - var slat, slon, sz string - if vs, slat, ok = tokenval(vs); !ok || slat == "" { - err = errInvalidNumberOfArguments - return - } - if vs, slon, ok = tokenval(vs); !ok || slon == "" { - err = errInvalidNumberOfArguments - return - } - vs, sz, ok = tokenval(vs) - if !ok || sz == "" { - var x, y float64 - y, err = strconv.ParseFloat(slat, 64) - if err != nil { - err = errInvalidArgument(slat) - return + case "xx": + if nx { + return retwerr(errInvalidArgument(args[i])) } - x, err = strconv.ParseFloat(slon, 64) - if err != nil { - err = errInvalidArgument(slon) - return + xx = true + case "string": + if i+1 >= len(args) { + return retwerr(errInvalidNumberOfArguments) } - d.obj = geojson.NewPoint(geometry.Point{X: x, Y: y}) - } else { - var x, y, z float64 - y, err = strconv.ParseFloat(slat, 64) - if err != nil { - err = errInvalidArgument(slat) - return + str := args[i+1] + i += 1 + obj = collection.String(str) + case "point": + if i+2 >= len(args) { + return retwerr(errInvalidNumberOfArguments) } - x, err = strconv.ParseFloat(slon, 64) - if err != nil { - err = errInvalidArgument(slon) - return + slat := args[i+1] + slon := args[i+2] + i += 2 + var z float64 + var hasZ bool + if i+1 < len(args) { + // probe for possible z coordinate + var err error + z, err = strconv.ParseFloat(args[i+1], 64) + if err == nil { + hasZ = true + i++ + } } - z, err = strconv.ParseFloat(sz, 64) + y, err := strconv.ParseFloat(slat, 64) if err != nil { - err = errInvalidArgument(sz) - return + return retwerr(errInvalidArgument(slat)) } - d.obj = geojson.NewPointZ(geometry.Point{X: x, Y: y}, z) - } - case lcb(typ, "bounds"): - var sminlat, sminlon, smaxlat, smaxlon string - if vs, sminlat, ok = tokenval(vs); !ok || sminlat == "" { - err = errInvalidNumberOfArguments - return - } - if vs, sminlon, ok = tokenval(vs); !ok || sminlon == "" { - err = errInvalidNumberOfArguments - return - } - if vs, smaxlat, ok = tokenval(vs); !ok || smaxlat == "" { - err = errInvalidNumberOfArguments - return - } - if vs, smaxlon, ok = tokenval(vs); !ok || smaxlon == "" { - err = errInvalidNumberOfArguments - return - } - var minlat, minlon, maxlat, maxlon float64 - minlat, err = strconv.ParseFloat(sminlat, 64) - if err != nil { - err = errInvalidArgument(sminlat) - return - } - minlon, err = strconv.ParseFloat(sminlon, 64) - if err != nil { - err = errInvalidArgument(sminlon) - return - } - maxlat, err = strconv.ParseFloat(smaxlat, 64) - if err != nil { - err = errInvalidArgument(smaxlat) - return - } - maxlon, err = strconv.ParseFloat(smaxlon, 64) - if err != nil { - err = errInvalidArgument(smaxlon) - return - } - d.obj = geojson.NewRect(geometry.Rect{ - Min: geometry.Point{X: minlon, Y: minlat}, - Max: geometry.Point{X: maxlon, Y: maxlat}, - }) - case lcb(typ, "hash"): - var shash string - if vs, shash, ok = tokenval(vs); !ok || shash == "" { - err = errInvalidNumberOfArguments - return - } - lat, lon := geohash.Decode(shash) - d.obj = geojson.NewPoint(geometry.Point{X: lon, Y: lat}) - case lcb(typ, "object"): - var object string - if vs, object, ok = tokenval(vs); !ok || object == "" { - err = errInvalidNumberOfArguments - return - } - d.obj, err = geojson.Parse(object, &s.geomParseOpts) - if err != nil { - return + x, err := strconv.ParseFloat(slon, 64) + if err != nil { + return retwerr(errInvalidArgument(slon)) + } + if !hasZ { + obj = geojson.NewPoint(geometry.Point{X: x, Y: y}) + } else { + obj = geojson.NewPointZ(geometry.Point{X: x, Y: y}, z) + } + case "bounds": + if i+4 >= len(args) { + return retwerr(errInvalidNumberOfArguments) + } + var vals [4]float64 + for j := 0; j < 4; j++ { + var err error + vals[j], err = strconv.ParseFloat(args[i+1+j], 64) + if err != nil { + return retwerr(errInvalidArgument(args[i+1+j])) + } + } + i += 4 + obj = geojson.NewRect(geometry.Rect{ + Min: geometry.Point{X: vals[1], Y: vals[0]}, + Max: geometry.Point{X: vals[3], Y: vals[2]}, + }) + case "hash": + if i+1 >= len(args) { + return retwerr(errInvalidNumberOfArguments) + } + shash := args[i+1] + i += 1 + lat, lon := geohash.Decode(shash) + obj = geojson.NewPoint(geometry.Point{X: lon, Y: lat}) + case "object": + if i+1 >= len(args) { + return retwerr(errInvalidNumberOfArguments) + } + json := args[i+1] + i += 1 + var err error + obj, err = geojson.Parse(json, &s.geomParseOpts) + if err != nil { + return retwerr(err) + } + default: + return retwerr(errInvalidArgument(args[i])) } } - if len(vs) != 0 { - err = errInvalidNumberOfArguments - } - return -} -func (s *Server) cmdSet(msg *Message) (res resp.Value, d commandDetails, err error) { - if s.config.maxMemory() > 0 && s.outOfMemory.on() { - err = errOOM - return - } - start := time.Now() - vs := msg.Args[1:] - var fmap map[string]int - var fields []string - var values []float64 - var xx, nx bool - var ex int64 - d, fields, values, xx, nx, ex, _, _, err = s.parseSetArgs(vs) - if err != nil { - return - } - col, _ := s.cols.Get(d.key) - if col == nil { + // >> Operation + + var nada bool + col, ok := s.cols.Get(key) + if !ok { if xx { - goto notok - } - col = collection.New() - s.cols.Set(d.key, col) - } - if xx || nx { - _, _, _, ok := col.Get(d.id) - if (nx && ok) || (xx && !ok) { - goto notok + nada = true + } else { + col = collection.New() + s.cols.Set(key, col) } } - d.oldObj, d.oldFields, d.fields = col.Set(d.id, d.obj, fields, values, ex) + + var ofields field.List + if !nada { + _, ofields, _, ok = col.Get(id) + if xx || nx { + if (nx && ok) || (xx && !ok) { + nada = true + } + } + } + + if nada { + // exclude operation due to 'xx' or 'nx' match + switch msg.OutputType { + default: + case JSON: + if nx { + return retwerr(errIDAlreadyExists) + } else { + return retwerr(errIDNotFound) + } + case RESP: + return resp.NullValue(), commandDetails{}, nil + } + return retwerr(errors.New("nada unknown output")) + } + + for _, f := range fields { + ofields = ofields.Set(f) + } + + oldObj, oldFields, newFields := col.Set(id, obj, ofields, ex) + + // >> Response + + var d commandDetails d.command = "set" + d.key = key + d.id = id + d.obj = obj + d.oldObj = oldObj + d.oldFields = oldFields + d.fields = newFields d.updated = true // perhaps we should do a diff on the previous object? d.timestamp = time.Now() - if msg.ConnType != Null || msg.OutputType != Null { - // likely loaded from aof at server startup, ignore field remapping. - fmap = col.FieldMap() - d.fmap = make(map[string]int) - for key, idx := range fmap { - d.fmap[key] = idx - } - } - // if ex != nil { - // server.expireAt(d.key, d.id, d.timestamp.Add(time.Duration(float64(time.Second)*(*ex)))) - // } + + var res resp.Value switch msg.OutputType { default: case JSON: - res = resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}") + res = resp.StringValue(`{"ok":true,"elapsed":"` + + time.Since(start).String() + "\"}") case RESP: res = resp.SimpleStringValue("OK") } - return -notok: - switch msg.OutputType { - default: - case JSON: - if nx { - err = errIDAlreadyExists - } else { - err = errIDNotFound - } - return - case RESP: - res = resp.NullValue() - } - return + return res, d, nil } -func (s *Server) parseFSetArgs(vs []string) ( - d commandDetails, fields []string, values []float64, xx bool, err error, -) { - var ok bool - if vs, d.key, ok = tokenval(vs); !ok || d.key == "" { - err = errInvalidNumberOfArguments - return - } - if vs, d.id, ok = tokenval(vs); !ok || d.id == "" { - err = errInvalidNumberOfArguments - return - } - for len(vs) > 0 { - var name string - if vs, name, ok = tokenval(vs); !ok || name == "" { - err = errInvalidNumberOfArguments - return - } - if lc(name, "xx") { - xx = true - continue - } - if isReservedFieldName(name) { - err = errInvalidArgument(name) - return - } - var svalue string - var value float64 - if vs, svalue, ok = tokenval(vs); !ok || svalue == "" { - err = errInvalidNumberOfArguments - return - } - value, err = strconv.ParseFloat(svalue, 64) - if err != nil { - err = errInvalidArgument(svalue) - return - } - fields = append(fields, name) - values = append(values, value) - } - return +func retwerr(err error) (resp.Value, commandDetails, error) { + return resp.Value{}, commandDetails{}, err +} +func retrerr(err error) (resp.Value, error) { + return resp.Value{}, err } -func (s *Server) cmdFset(msg *Message) (res resp.Value, d commandDetails, err error) { - if s.config.maxMemory() > 0 && s.outOfMemory.on() { - err = errOOM - return - } +// FSET key id [XX] field value [field value...] +func (s *Server) cmdFSET(msg *Message) (resp.Value, commandDetails, error) { start := time.Now() - vs := msg.Args[1:] - var fields []string - var values []float64 - var xx bool - var updateCount int - d, fields, values, xx, err = s.parseFSetArgs(vs) + if s.config.maxMemory() > 0 && s.outOfMemory.on() { + return retwerr(errOOM) + } - col, _ := s.cols.Get(d.key) - if col == nil { - err = errKeyNotFound - return + // >> Args + + var id string + var key string + var xx bool + var fields []field.Field // raw fields + + args := msg.Args + if len(args) < 5 { + return retwerr(errInvalidNumberOfArguments) } - var ok bool - d.obj, d.fields, updateCount, ok = col.SetFields(d.id, fields, values) + key, id = args[1], args[2] + for i := 3; i < len(args); i++ { + arg := strings.ToLower(args[i]) + switch arg { + case "xx": + xx = true + default: + fkey := arg + i++ + if i == len(args) { + return retwerr(errInvalidNumberOfArguments) + } + if isReservedFieldName(fkey) { + return retwerr(errInvalidArgument(fkey)) + } + fval := args[i] + fields = append(fields, field.Make(fkey, fval)) + } + } + + // >> Operation + + var d commandDetails + var updateCount int + + col, ok := s.cols.Get(key) + if !ok { + return retwerr(errKeyNotFound) + } + obj, ofields, ex, ok := col.Get(id) if !(ok || xx) { - err = errIDNotFound - return + return retwerr(errIDNotFound) } + if ok { + for _, f := range fields { + prev := ofields.Get(f.Name()) + if !prev.Value().Equals(f.Value()) { + ofields = ofields.Set(f) + updateCount++ + } + } + col.Set(id, obj, ofields, ex) + d.obj = obj d.command = "fset" + d.key = key + d.id = id d.timestamp = time.Now() d.updated = updateCount > 0 - fmap := col.FieldMap() - d.fmap = make(map[string]int) - for key, idx := range fmap { - d.fmap[key] = idx - } } + // >> Response + + var res resp.Value + switch msg.OutputType { case JSON: - res = resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}") + res = resp.StringValue(`{"ok":true,"elapsed":"` + + time.Since(start).String() + "\"}") case RESP: res = resp.IntegerValue(updateCount) } - return + + return res, d, nil } -func (s *Server) cmdExpire(msg *Message) (res resp.Value, d commandDetails, err error) { +// EXPIRE key id seconds +func (s *Server) cmdEXPIRE(msg *Message) (resp.Value, commandDetails, error) { start := time.Now() - vs := msg.Args[1:] - var key, id, svalue string - var ok bool - if vs, key, ok = tokenval(vs); !ok || key == "" { - err = errInvalidNumberOfArguments - return + args := msg.Args + if len(args) != 4 { + return retwerr(errInvalidNumberOfArguments) } - if vs, id, ok = tokenval(vs); !ok || id == "" { - err = errInvalidNumberOfArguments - return - } - if vs, svalue, ok = tokenval(vs); !ok || svalue == "" { - err = errInvalidNumberOfArguments - return - } - if len(vs) != 0 { - err = errInvalidNumberOfArguments - return - } - var value float64 - value, err = strconv.ParseFloat(svalue, 64) + key, id, svalue := args[1], args[2], args[3] + value, err := strconv.ParseFloat(svalue, 64) if err != nil { - err = errInvalidArgument(svalue) - return + return retwerr(errInvalidArgument(svalue)) } - ok = false + var ok bool col, _ := s.cols.Get(key) if col != nil { + // replace the expiration by getting the old objec ex := time.Now().Add(time.Duration(float64(time.Second) * value)).UnixNano() - ok = col.SetExpires(id, ex) + var obj geojson.Object + var fields field.List + obj, fields, _, ok = col.Get(id) + if ok { + col.Set(id, obj, fields, ex) + } } + var d commandDetails if ok { + d.key = key + d.id = id + d.command = "expire" d.updated = true + d.timestamp = time.Now() } + var res resp.Value switch msg.OutputType { case JSON: if ok { - res = resp.StringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}") + res = resp.StringValue(`{"ok":true,"elapsed":"` + + time.Since(start).String() + "\"}") + } else if col == nil { + return retwerr(errKeyNotFound) } else { - return resp.SimpleStringValue(""), d, errIDNotFound + return retwerr(errIDNotFound) } case RESP: if ok { @@ -969,48 +898,55 @@ func (s *Server) cmdExpire(msg *Message) (res resp.Value, d commandDetails, err res = resp.IntegerValue(0) } } - return + return res, d, nil } -func (s *Server) cmdPersist(msg *Message) (res resp.Value, d commandDetails, err error) { +// PERSIST key id +func (s *Server) cmdPERSIST(msg *Message) (resp.Value, commandDetails, error) { start := time.Now() - vs := msg.Args[1:] - var key, id string - var ok bool - if vs, key, ok = tokenval(vs); !ok || key == "" { - err = errInvalidNumberOfArguments - return - } - if vs, id, ok = tokenval(vs); !ok || id == "" { - err = errInvalidNumberOfArguments - return - } - if len(vs) != 0 { - err = errInvalidNumberOfArguments - return + args := msg.Args + if len(args) != 3 { + return retwerr(errInvalidNumberOfArguments) } + key, id := args[1], args[2] var cleared bool - ok = false + var ok bool col, _ := s.cols.Get(key) if col != nil { var ex int64 _, _, ex, ok = col.Get(id) if ok && ex != 0 { - ok = col.SetExpires(id, 0) + var obj geojson.Object + var fields field.List + obj, fields, _, ok = col.Get(id) + if ok { + col.Set(id, obj, fields, 0) + } if ok { cleared = true } } } + if !ok { if msg.OutputType == RESP { - return resp.IntegerValue(0), d, nil + return resp.IntegerValue(0), commandDetails{}, nil } - return resp.SimpleStringValue(""), d, errIDNotFound + if col == nil { + return retwerr(errKeyNotFound) + } + return retwerr(errIDNotFound) } + + var res resp.Value + + var d commandDetails + d.key = key + d.id = id d.command = "persist" d.updated = cleared d.timestamp = time.Now() + switch msg.OutputType { case JSON: res = resp.SimpleStringValue(`{"ok":true,"elapsed":"` + time.Since(start).String() + "\"}") @@ -1021,28 +957,19 @@ func (s *Server) cmdPersist(msg *Message) (res resp.Value, d commandDetails, err res = resp.IntegerValue(0) } } - return + return res, d, nil } -func (s *Server) cmdTTL(msg *Message) (res resp.Value, err error) { +// TTL key id +func (s *Server) cmdTTL(msg *Message) (resp.Value, error) { start := time.Now() - vs := msg.Args[1:] - var key, id string - var ok bool - if vs, key, ok = tokenval(vs); !ok || key == "" { - err = errInvalidNumberOfArguments - return - } - if vs, id, ok = tokenval(vs); !ok || id == "" { - err = errInvalidNumberOfArguments - return - } - if len(vs) != 0 { - err = errInvalidNumberOfArguments - return + args := msg.Args + if len(args) != 3 { + return retrerr(errInvalidNumberOfArguments) } + key, id := args[1], args[2] var v float64 - ok = false + var ok bool var ok2 bool col, _ := s.cols.Get(key) if col != nil { @@ -1063,6 +990,7 @@ func (s *Server) cmdTTL(msg *Message) (res resp.Value, err error) { } } } + var res resp.Value switch msg.OutputType { case JSON: if ok { @@ -1073,9 +1001,13 @@ func (s *Server) cmdTTL(msg *Message) (res resp.Value, err error) { ttl = "-1" } res = resp.SimpleStringValue( - `{"ok":true,"ttl":` + ttl + `,"elapsed":"` + time.Since(start).String() + "\"}") + `{"ok":true,"ttl":` + ttl + `,"elapsed":"` + + time.Since(start).String() + "\"}") } else { - return resp.SimpleStringValue(""), errIDNotFound + if col == nil { + return retrerr(errKeyNotFound) + } + return retrerr(errIDNotFound) } case RESP: if ok { @@ -1088,5 +1020,5 @@ func (s *Server) cmdTTL(msg *Message) (res resp.Value, err error) { res = resp.IntegerValue(-2) } } - return + return res, nil } diff --git a/internal/server/fence.go b/internal/server/fence.go index 2eb17dec..58d2d9b1 100644 --- a/internal/server/fence.go +++ b/internal/server/fence.go @@ -10,6 +10,7 @@ import ( "github.com/tidwall/geojson/geo" "github.com/tidwall/geojson/geometry" "github.com/tidwall/gjson" + "github.com/tidwall/tile38/internal/field" "github.com/tidwall/tile38/internal/glob" ) @@ -87,9 +88,7 @@ func fenceMatch( return nil } if details.command == "fset" { - sw.mu.Lock() nofields := sw.nofields - sw.mu.Unlock() if nofields { return nil } @@ -166,9 +165,10 @@ func fenceMatch( } } - if details.fmap == nil { - return nil - } + // TODO: fields + // if details.fmap == nil { + // return nil + // } for { if fence.detect != nil && !fence.detect[detect] { if detect == "enter" { @@ -183,26 +183,24 @@ func fenceMatch( } break } - sw.mu.Lock() var distance float64 if fence.distance && fence.obj != nil { distance = details.obj.Distance(fence.obj) } - sw.fmap = details.fmap + // TODO: fields + // sw.fmap = details.fmap sw.fullFields = true sw.msg.OutputType = JSON sw.writeObject(ScanWriterParams{ id: details.id, o: details.obj, fields: details.fields, - noLock: true, noTest: true, distance: distance, distOutput: fence.distance, }) if sw.wr.Len() == 0 { - sw.mu.Unlock() return nil } @@ -214,7 +212,6 @@ func fenceMatch( if sw.output == outputIDs { res = `{"id":` + string(res) + `}` } - sw.mu.Unlock() var group string if detect == "enter" { @@ -300,7 +297,7 @@ func extendRoamMessage( } pattern := match.id + fence.roam.scan iterator := func( - oid string, o geojson.Object, fields []float64, + oid string, o geojson.Object, fields field.List, ) bool { if oid == match.id { return true @@ -387,7 +384,7 @@ func fenceMatchNearbys( Max: geometry.Point{X: maxLon, Y: maxLat}, } col.Intersects(geojson.NewRect(rect), 0, nil, nil, func( - id2 string, obj2 geojson.Object, fields []float64, + id2 string, obj2 geojson.Object, fields field.List, ) bool { var idMatch bool if id2 == id { diff --git a/internal/server/json.go b/internal/server/json.go index b3a2b9b6..b9181ef9 100644 --- a/internal/server/json.go +++ b/internal/server/json.go @@ -270,7 +270,7 @@ func (s *Server) cmdJset(msg *Message) (res resp.Value, d commandDetails, err er } var json string var geoobj bool - o, _, _, ok := col.Get(id) + o, fields, _, ok := col.Get(id) if ok { geoobj = objIsSpatial(o) json = o.String() @@ -290,7 +290,7 @@ func (s *Server) cmdJset(msg *Message) (res resp.Value, d commandDetails, err er nmsg := *msg nmsg.Args = []string{"SET", key, id, "OBJECT", json} // SET key id OBJECT json - return s.cmdSet(&nmsg) + return s.cmdSET(&nmsg) } if createcol { s.cols.Set(key, col) @@ -302,7 +302,7 @@ func (s *Server) cmdJset(msg *Message) (res resp.Value, d commandDetails, err er d.timestamp = time.Now() d.updated = true - col.Set(d.id, d.obj, nil, nil, 0) + col.Set(d.id, d.obj, fields, 0) switch msg.OutputType { case JSON: var buf bytes.Buffer @@ -335,7 +335,7 @@ func (s *Server) cmdJdel(msg *Message) (res resp.Value, d commandDetails, err er var json string var geoobj bool - o, _, _, ok := col.Get(id) + o, fields, _, ok := col.Get(id) if ok { geoobj = objIsSpatial(o) json = o.String() @@ -358,7 +358,7 @@ func (s *Server) cmdJdel(msg *Message) (res resp.Value, d commandDetails, err er nmsg := *msg nmsg.Args = []string{"SET", key, id, "OBJECT", json} // SET key id OBJECT json - return s.cmdSet(&nmsg) + return s.cmdSET(&nmsg) } d.key = key @@ -366,8 +366,7 @@ func (s *Server) cmdJdel(msg *Message) (res resp.Value, d commandDetails, err er d.obj = collection.String(json) d.timestamp = time.Now() d.updated = true - - col.Set(d.id, d.obj, nil, nil, 0) + col.Set(d.id, d.obj, fields, 0) switch msg.OutputType { case JSON: var buf bytes.Buffer diff --git a/internal/server/scan.go b/internal/server/scan.go index 0151ae90..02c36c0a 100644 --- a/internal/server/scan.go +++ b/internal/server/scan.go @@ -7,6 +7,7 @@ import ( "github.com/tidwall/geojson" "github.com/tidwall/resp" + "github.com/tidwall/tile38/internal/field" ) func (s *Server) cmdScanArgs(vs []string) ( @@ -54,10 +55,11 @@ func (s *Server) cmdScan(msg *Message) (res resp.Value, err error) { if msg.OutputType == JSON { wr.WriteString(`{"ok":true`) } - sw.writeHead() + var ierr error if sw.col != nil { if sw.output == outputCount && len(sw.wheres) == 0 && - len(sw.whereins) == 0 && sw.globEverything { + len(sw.whereins) == 0 && len(sw.whereevals) == 0 && + sw.globEverything { count := sw.col.Count() - int(args.cursor) if count < 0 { count = 0 @@ -68,28 +70,41 @@ func (s *Server) cmdScan(msg *Message) (res resp.Value, err error) { if limits[0] == "" && limits[1] == "" { sw.col.Scan(args.desc, sw, msg.Deadline, - func(id string, o geojson.Object, fields []float64) bool { - return sw.writeObject(ScanWriterParams{ + func(id string, o geojson.Object, fields field.List) bool { + keepGoing, err := sw.pushObject(ScanWriterParams{ id: id, o: o, fields: fields, }) + if err != nil { + ierr = err + return false + } + return keepGoing }, ) } else { sw.col.ScanRange(limits[0], limits[1], args.desc, sw, msg.Deadline, - func(id string, o geojson.Object, fields []float64) bool { - return sw.writeObject(ScanWriterParams{ + func(id string, o geojson.Object, fields field.List) bool { + keepGoing, err := sw.pushObject(ScanWriterParams{ id: id, o: o, fields: fields, }) + if err != nil { + ierr = err + return false + } + return keepGoing }, ) } } } + if ierr != nil { + return retrerr(ierr) + } sw.writeFoot() if msg.OutputType == JSON { wr.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}") diff --git a/internal/server/scanner.go b/internal/server/scanner.go index bfe46af1..ef5733dd 100644 --- a/internal/server/scanner.go +++ b/internal/server/scanner.go @@ -5,13 +5,14 @@ import ( "errors" "math" "strconv" - "sync" "github.com/mmcloughlin/geohash" + "github.com/tidwall/btree" "github.com/tidwall/geojson" "github.com/tidwall/resp" "github.com/tidwall/tile38/internal/clip" "github.com/tidwall/tile38/internal/collection" + "github.com/tidwall/tile38/internal/field" "github.com/tidwall/tile38/internal/glob" ) @@ -30,15 +31,12 @@ const ( ) type scanWriter struct { - mu sync.Mutex s *Server wr *bytes.Buffer - key string + name string msg *Message col *collection.Collection - fmap map[string]int - farr []string - fvals []float64 + fkeys btree.Set[string] output outputT wheres []whereT whereins []whereinT @@ -58,18 +56,15 @@ type scanWriter struct { values []resp.Value matchValues bool respOut resp.Value - orgWheres []whereT - orgWhereins []whereinT + filled []ScanWriterParams } -// ScanWriterParams ... type ScanWriterParams struct { id string o geojson.Object - fields []float64 + fields field.List distance float64 distOutput bool // query or fence requested distance output - noLock bool noTest bool ignoreGlobMatch bool clip geojson.Object @@ -77,7 +72,7 @@ type ScanWriterParams struct { } func (s *Server) newScanWriter( - wr *bytes.Buffer, msg *Message, key string, output outputT, + wr *bytes.Buffer, msg *Message, name string, output outputT, precision uint64, globs []string, matchValues bool, cursor, limit uint64, wheres []whereT, whereins []whereinT, whereevals []whereevalT, nofields bool, @@ -99,7 +94,7 @@ func (s *Server) newScanWriter( sw := &scanWriter{ s: s, wr: wr, - key: key, + name: name, msg: msg, globs: globs, limit: limit, @@ -114,50 +109,12 @@ func (s *Server) newScanWriter( if len(globs) == 0 || (len(globs) == 1 && globs[0] == "*") { sw.globEverything = true } - sw.orgWheres = wheres - sw.orgWhereins = whereins - sw.loadWheres() + sw.wheres = wheres + sw.whereins = whereins + sw.col, _ = sw.s.cols.Get(sw.name) return sw, nil } -func (sw *scanWriter) loadWheres() { - sw.fmap = nil - sw.farr = nil - sw.wheres = nil - sw.whereins = nil - sw.fvals = nil - sw.col, _ = sw.s.cols.Get(sw.key) - if sw.col != nil { - sw.fmap = sw.col.FieldMap() - sw.farr = sw.col.FieldArr() - // This fills index value in wheres/whereins - // so we don't have to map string field names for each tested object - var ok bool - if len(sw.orgWheres) > 0 { - sw.wheres = make([]whereT, len(sw.orgWheres)) - for i, where := range sw.orgWheres { - if where.index, ok = sw.fmap[where.field]; !ok { - where.index = math.MaxInt32 - } - sw.wheres[i] = where - } - } - if len(sw.orgWhereins) > 0 { - sw.whereins = make([]whereinT, len(sw.orgWhereins)) - for i, wherein := range sw.orgWhereins { - if wherein.index, ok = sw.fmap[wherein.field]; !ok { - wherein.index = math.MaxInt32 - } - sw.whereins[i] = wherein - } - } - if len(sw.farr) > 0 { - sw.fvals = make([]float64, len(sw.farr)) - } - } - -} - func (sw *scanWriter) hasFieldsOutput() bool { switch sw.output { default: @@ -167,19 +124,20 @@ func (sw *scanWriter) hasFieldsOutput() bool { } } -func (sw *scanWriter) writeHead() { - sw.mu.Lock() - defer sw.mu.Unlock() +func (sw *scanWriter) writeFoot() { switch sw.msg.OutputType { case JSON: - if len(sw.farr) > 0 && sw.hasFieldsOutput() { + if sw.fkeys.Len() > 0 && sw.hasFieldsOutput() { sw.wr.WriteString(`,"fields":[`) - for i, field := range sw.farr { + var i int + sw.fkeys.Scan(func(name string) bool { if i > 0 { sw.wr.WriteByte(',') } - sw.wr.WriteString(jsonString(field)) - } + sw.wr.WriteString(jsonString(name)) + i++ + return true + }) sw.wr.WriteByte(']') } switch sw.output { @@ -198,11 +156,11 @@ func (sw *scanWriter) writeHead() { } case RESP: } -} -func (sw *scanWriter) writeFoot() { - sw.mu.Lock() - defer sw.mu.Unlock() + for _, opts := range sw.filled { + sw.writeFilled(opts) + } + cursor := sw.numberIters if !sw.hitLimit { cursor = 0 @@ -243,100 +201,43 @@ func extractZCoordinate(o geojson.Object) float64 { } } -func (sw *scanWriter) fieldMatch(fields []float64, o geojson.Object) (fvals []float64, match bool) { - var z float64 - var gotz bool - fvals = sw.fvals - if !sw.hasFieldsOutput() || sw.fullFields { - for _, where := range sw.wheres { - if where.field == "z" { - if !gotz { - z = extractZCoordinate(o) - } - if !where.match(z) { - return - } - continue - } - var value float64 - if where.index < len(fields) { - value = fields[where.index] - } - if !where.match(value) { - return - } +func getFieldValue(o geojson.Object, fields field.List, name string) field.Value { + if name == "z" { + return field.ValueOf(strconv.FormatFloat(extractZCoordinate(o), 'f', -1, 64)) + } + f := fields.Get(name) + return f.Value() +} + +func (sw *scanWriter) fieldMatch(o geojson.Object, fields field.List) (bool, error) { + for _, where := range sw.wheres { + if !where.match(getFieldValue(o, fields, where.name)) { + return false, nil } - for _, wherein := range sw.whereins { - var value float64 - if wherein.index < len(fields) { - value = fields[wherein.index] - } - if !wherein.match(value) { - return - } + } + for _, wherein := range sw.whereins { + if !wherein.match(getFieldValue(o, fields, wherein.name)) { + return false, nil } + } + if len(sw.whereevals) > 0 { + fieldsWithNames := make(map[string]field.Value) + fieldsWithNames["z"] = field.ValueOf(strconv.FormatFloat(extractZCoordinate(o), 'f', -1, 64)) + fields.Scan(func(f field.Field) bool { + fieldsWithNames[f.Name()] = f.Value() + return true + }) for _, whereval := range sw.whereevals { - fieldsWithNames := make(map[string]float64) - for field, idx := range sw.fmap { - if idx < len(fields) { - fieldsWithNames[field] = fields[idx] - } else { - fieldsWithNames[field] = 0 - } + match, err := whereval.match(fieldsWithNames) + if err != nil { + return false, err } - if !whereval.match(fieldsWithNames) { - return - } - } - } else { - copy(sw.fvals, fields) - // fields might be shorter for this item, need to pad sw.fvals with zeros - for i := len(fields); i < len(sw.fvals); i++ { - sw.fvals[i] = 0 - } - for _, where := range sw.wheres { - if where.field == "z" { - if !gotz { - z = extractZCoordinate(o) - } - if !where.match(z) { - return - } - continue - } - var value float64 - if where.index < len(sw.fvals) { - value = sw.fvals[where.index] - } - if !where.match(value) { - return - } - } - for _, wherein := range sw.whereins { - var value float64 - if wherein.index < len(sw.fvals) { - value = sw.fvals[wherein.index] - } - if !wherein.match(value) { - return - } - } - for _, whereval := range sw.whereevals { - fieldsWithNames := make(map[string]float64) - for field, idx := range sw.fmap { - if idx < len(fields) { - fieldsWithNames[field] = fields[idx] - } else { - fieldsWithNames[field] = 0 - } - } - if !whereval.match(fieldsWithNames) { - return + if !match { + return false, nil } } } - match = true - return + return true, nil } func (sw *scanWriter) globMatch(id string, o geojson.Object) (ok, keepGoing bool) { @@ -356,7 +257,6 @@ func (sw *scanWriter) globMatch(id string, o geojson.Object) (ok, keepGoing bool } } return false, true - } // Increment cursor @@ -370,38 +270,64 @@ func (sw *scanWriter) Step(n uint64) { // ok is whether the object passes the test and should be written // keepGoing is whether there could be more objects to test -func (sw *scanWriter) testObject(id string, o geojson.Object, fields []float64) ( - ok, keepGoing bool, fieldVals []float64) { +func (sw *scanWriter) testObject(id string, o geojson.Object, fields field.List, +) (ok, keepGoing bool, err error) { match, kg := sw.globMatch(id, o) if !match { - return false, kg, fieldVals + return false, kg, nil } - nf, ok := sw.fieldMatch(fields, o) - return ok, true, nf + ok, err = sw.fieldMatch(o, fields) + if err != nil { + return false, false, err + } + return ok, true, nil } -// id string, o geojson.Object, fields []float64, noLock bool -func (sw *scanWriter) writeObject(opts ScanWriterParams) bool { - if !opts.noLock { - sw.mu.Lock() - defer sw.mu.Unlock() - } - - keepGoing := true +func (sw *scanWriter) pushObject(opts ScanWriterParams) (keepGoing bool, err error) { + keepGoing = true if !opts.noTest { var ok bool - ok, keepGoing, _ = sw.testObject(opts.id, opts.o, opts.fields) + var err error + ok, keepGoing, err = sw.testObject(opts.id, opts.o, opts.fields) + if err != nil { + return false, err + } if !ok { - return keepGoing + return keepGoing, nil } } sw.count++ if sw.output == outputCount { - return sw.count < sw.limit + return sw.count < sw.limit, nil } if opts.clip != nil { opts.o = clip.Clip(opts.o, opts.clip, &sw.s.geomIndexOpts) } + if !sw.fullFields { + opts.fields.Scan(func(f field.Field) bool { + sw.fkeys.Insert(f.Name()) + return true + }) + } + sw.filled = append(sw.filled, opts) + sw.numberItems++ + if sw.numberItems == sw.limit { + sw.hitLimit = true + return false, nil + } + return keepGoing, nil +} + +func (sw *scanWriter) writeObject(opts ScanWriterParams) { + n := len(sw.filled) + sw.pushObject(opts) + if len(sw.filled) > n { + sw.writeFilled(sw.filled[len(sw.filled)-1]) + sw.filled = sw.filled[:n] + } +} + +func (sw *scanWriter) writeFilled(opts ScanWriterParams) { switch sw.msg.OutputType { case JSON: var wr bytes.Buffer @@ -411,40 +337,36 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool { } else { sw.once = true } - if sw.hasFieldsOutput() { - if sw.fullFields { - if len(sw.fmap) > 0 { - jsfields = `,"fields":{` - var i int - for field, idx := range sw.fmap { - if len(opts.fields) > idx { - if opts.fields[idx] != 0 { - if i > 0 { - jsfields += `,` - } - jsfields += jsonString(field) + ":" + strconv.FormatFloat(opts.fields[idx], 'f', -1, 64) - i++ - } + fieldsOutput := sw.hasFieldsOutput() + if fieldsOutput && sw.fullFields { + if opts.fields.Len() > 0 { + jsfields = `,"fields":{` + var i int + opts.fields.Scan(func(f field.Field) bool { + if !f.Value().IsZero() { + if i > 0 { + jsfields += `,` } + jsfields += jsonString(f.Name()) + ":" + f.Value().JSON() + i++ } - jsfields += `}` - } - - } else if len(sw.farr) > 0 { - jsfields = `,"fields":[` - for i, name := range sw.farr { - if i > 0 { - jsfields += `,` - } - j := sw.fmap[name] - if j < len(opts.fields) { - jsfields += strconv.FormatFloat(opts.fields[j], 'f', -1, 64) - } else { - jsfields += "0" - } - } - jsfields += `]` + return true + }) + jsfields += `}` } + } else if fieldsOutput && sw.fkeys.Len() > 0 && !sw.fullFields { + jsfields = `,"fields":[` + var i int + sw.fkeys.Scan(func(name string) bool { + if i > 0 { + jsfields += `,` + } + f := opts.fields.Get(name) + jsfields += f.Value().JSON() + i++ + return true + }) + jsfields += `]` } if sw.output == outputIDs { if opts.distOutput || opts.distance > 0 { @@ -467,9 +389,7 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool { case outputBounds: wr.WriteString(`,"bounds":` + string(appendJSONSimpleBounds(nil, opts.o))) } - wr.WriteString(jsfields) - if opts.distOutput || opts.distance > 0 { wr.WriteString(`,"distance":` + strconv.FormatFloat(opts.distance, 'f', -1, 64)) } @@ -523,15 +443,17 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool { }), })) } - if sw.hasFieldsOutput() { - fvs := orderFields(sw.fmap, sw.farr, opts.fields) - if len(fvs) > 0 { - fvals := make([]resp.Value, 0, len(fvs)*2) - for i, fv := range fvs { - fvals = append(fvals, resp.StringValue(fv.field), resp.StringValue(strconv.FormatFloat(fv.value, 'f', -1, 64))) - i++ - } + if opts.fields.Len() > 0 { + var fvals []resp.Value + var i int + opts.fields.Scan(func(f field.Field) bool { + if !f.Value().IsZero() { + fvals = append(fvals, resp.StringValue(f.Name()), resp.StringValue(f.Value().Data())) + i++ + } + return true + }) vals = append(vals, resp.ArrayValue(fvals)) } } @@ -542,10 +464,4 @@ func (sw *scanWriter) writeObject(opts ScanWriterParams) bool { sw.values = append(sw.values, resp.ArrayValue(vals)) } } - sw.numberItems++ - if sw.numberItems == sw.limit { - sw.hitLimit = true - return false - } - return keepGoing } diff --git a/internal/server/scanner_test.go b/internal/server/scanner_test.go index e092cd8a..aecb51f0 100644 --- a/internal/server/scanner_test.go +++ b/internal/server/scanner_test.go @@ -1,6 +1,7 @@ package server import ( + "fmt" "math" "math/rand" "testing" @@ -8,11 +9,12 @@ import ( "github.com/tidwall/geojson" "github.com/tidwall/geojson/geometry" + "github.com/tidwall/tile38/internal/field" ) type testPointItem struct { object geojson.Object - fields []float64 + fields field.List } func PO(x, y float64) *geojson.Point { @@ -23,29 +25,29 @@ func BenchmarkFieldMatch(t *testing.B) { rand.Seed(time.Now().UnixNano()) items := make([]testPointItem, t.N) for i := 0; i < t.N; i++ { + var fields field.List + fields = fields.Set(field.Make("foo", fmt.Sprintf("%f", rand.Float64()*9+1))) + fields = fields.Set(field.Make("bar", fmt.Sprintf("%f", math.Round(rand.Float64()*30)+1))) items[i] = testPointItem{ PO(rand.Float64()*360-180, rand.Float64()*180-90), - []float64{rand.Float64()*9 + 1, math.Round(rand.Float64()*30) + 1}, + fields, } } sw := &scanWriter{ wheres: []whereT{ - {"foo", 0, false, 1, false, 3}, - {"bar", 1, false, 10, false, 30}, + {"foo", false, field.ValueOf("1"), false, field.ValueOf("3")}, + {"bar", false, field.ValueOf("10"), false, field.ValueOf("30")}, }, whereins: []whereinT{ - {"foo", 0, []float64{1, 2}}, - {"bar", 1, []float64{11, 25}}, + {"foo", []field.Value{field.ValueOf("1"), field.ValueOf("2")}}, + {"bar", []field.Value{field.ValueOf("11"), field.ValueOf("25")}}, }, - fmap: map[string]int{"foo": 0, "bar": 1}, - farr: []string{"bar", "foo"}, } - sw.fvals = make([]float64, len(sw.farr)) t.ResetTimer() for i := 0; i < t.N; i++ { // one call is super fast, measurements are not reliable, let's do 100 for ix := 0; ix < 100; ix++ { - sw.fieldMatch(items[i].fields, items[i].object) + sw.fieldMatch(items[i].object, items[i].fields) } } } diff --git a/internal/server/scripts.go b/internal/server/scripts.go index 7e0b739e..8b8b5fca 100644 --- a/internal/server/scripts.go +++ b/internal/server/scripts.go @@ -592,9 +592,9 @@ func (s *Server) commandInScript(msg *Message) ( default: err = fmt.Errorf("unknown command '%s'", msg.Args[0]) case "set": - res, d, err = s.cmdSet(msg) + res, d, err = s.cmdSET(msg) case "fset": - res, d, err = s.cmdFset(msg) + res, d, err = s.cmdFSET(msg) case "del": res, d, err = s.cmdDel(msg) case "pdel": @@ -602,13 +602,13 @@ func (s *Server) commandInScript(msg *Message) ( case "drop": res, d, err = s.cmdDrop(msg) case "expire": - res, d, err = s.cmdExpire(msg) + res, d, err = s.cmdEXPIRE(msg) case "rename": res, d, err = s.cmdRename(msg) case "renamenx": res, d, err = s.cmdRename(msg) case "persist": - res, d, err = s.cmdPersist(msg) + res, d, err = s.cmdPERSIST(msg) case "ttl": res, err = s.cmdTTL(msg) case "stats": @@ -618,9 +618,9 @@ func (s *Server) commandInScript(msg *Message) ( case "nearby": res, err = s.cmdNearby(msg) case "within": - res, err = s.cmdWithin(msg) + res, err = s.cmdWITHIN(msg) case "intersects": - res, err = s.cmdIntersects(msg) + res, err = s.cmdINTERSECTS(msg) case "search": res, err = s.cmdSearch(msg) case "bounds": diff --git a/internal/server/search.go b/internal/server/search.go index 7b1b6646..4af4d774 100644 --- a/internal/server/search.go +++ b/internal/server/search.go @@ -16,6 +16,7 @@ import ( "github.com/tidwall/tile38/internal/bing" "github.com/tidwall/tile38/internal/buffer" "github.com/tidwall/tile38/internal/clip" + "github.com/tidwall/tile38/internal/field" "github.com/tidwall/tile38/internal/glob" ) @@ -496,19 +497,23 @@ func (s *Server) cmdNearby(msg *Message) (res resp.Value, err error) { if msg.OutputType == JSON { wr.WriteString(`{"ok":true`) } - sw.writeHead() + var ierr error if sw.col != nil { - iterStep := func(id string, o geojson.Object, fields []float64, meters float64) bool { - return sw.writeObject(ScanWriterParams{ + iterStep := func(id string, o geojson.Object, fields field.List, meters float64) bool { + keepGoing, err := sw.pushObject(ScanWriterParams{ id: id, o: o, fields: fields, distance: meters, distOutput: sargs.distance, - noLock: true, ignoreGlobMatch: true, skipTesting: true, }) + if err != nil { + ierr = err + return false + } + return keepGoing } maxDist := sargs.obj.(*geojson.Circle).Meters() if sargs.sparse > 0 { @@ -518,7 +523,7 @@ func (s *Server) cmdNearby(msg *Message) (res resp.Value, err error) { errors.New("cannot use SPARSE without a point distance") } // An intersects operation is required for SPARSE - iter := func(id string, o geojson.Object, fields []float64) bool { + iter := func(id string, o geojson.Object, fields field.List) bool { var meters float64 if sargs.distance { meters = o.Distance(sargs.obj) @@ -527,7 +532,7 @@ func (s *Server) cmdNearby(msg *Message) (res resp.Value, err error) { } sw.col.Intersects(sargs.obj, sargs.sparse, sw, msg.Deadline, iter) } else { - iter := func(id string, o geojson.Object, fields []float64, dist float64) bool { + iter := func(id string, o geojson.Object, fields field.List, dist float64) bool { if maxDist > 0 && dist > maxDist { return false } @@ -540,6 +545,9 @@ func (s *Server) cmdNearby(msg *Message) (res resp.Value, err error) { sw.col.Nearby(sargs.obj, sw, msg.Deadline, iter) } } + if ierr != nil { + return retrerr(ierr) + } sw.writeFoot() if msg.OutputType == JSON { wr.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}") @@ -548,15 +556,15 @@ func (s *Server) cmdNearby(msg *Message) (res resp.Value, err error) { return sw.respOut, nil } -func (s *Server) cmdWithin(msg *Message) (res resp.Value, err error) { - return s.cmdWithinOrIntersects("within", msg) +func (s *Server) cmdWITHIN(msg *Message) (res resp.Value, err error) { + return s.cmdWITHINorINTERSECTS("within", msg) } -func (s *Server) cmdIntersects(msg *Message) (res resp.Value, err error) { - return s.cmdWithinOrIntersects("intersects", msg) +func (s *Server) cmdINTERSECTS(msg *Message) (res resp.Value, err error) { + return s.cmdWITHINorINTERSECTS("intersects", msg) } -func (s *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp.Value, err error) { +func (s *Server) cmdWITHINorINTERSECTS(cmd string, msg *Message) (res resp.Value, err error) { start := time.Now() vs := msg.Args[1:] @@ -588,38 +596,49 @@ func (s *Server) cmdWithinOrIntersects(cmd string, msg *Message) (res resp.Value if msg.OutputType == JSON { wr.WriteString(`{"ok":true`) } - sw.writeHead() + var ierr error if sw.col != nil { if cmd == "within" { sw.col.Within(sargs.obj, sargs.sparse, sw, msg.Deadline, func( - id string, o geojson.Object, fields []float64, + id string, o geojson.Object, fields field.List, ) bool { - return sw.writeObject(ScanWriterParams{ + keepGoing, err := sw.pushObject(ScanWriterParams{ id: id, o: o, fields: fields, - noLock: true, }) + if err != nil { + ierr = err + return false + } + return keepGoing }) } else if cmd == "intersects" { sw.col.Intersects(sargs.obj, sargs.sparse, sw, msg.Deadline, func( id string, o geojson.Object, - fields []float64, + fields field.List, ) bool { params := ScanWriterParams{ id: id, o: o, fields: fields, - noLock: true, } if sargs.clip { params.clip = sargs.obj } - return sw.writeObject(params) + keepGoing, err := sw.pushObject(params) + if err != nil { + ierr = err + return false + } + return keepGoing }) } } + if ierr != nil { + return retrerr(ierr) + } sw.writeFoot() if msg.OutputType == JSON { wr.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}") @@ -701,7 +720,7 @@ func (s *Server) cmdSearch(msg *Message) (res resp.Value, err error) { if msg.OutputType == JSON { wr.WriteString(`{"ok":true`) } - sw.writeHead() + var ierr error if sw.col != nil { if sw.output == outputCount && len(sw.wheres) == 0 && sw.globEverything { count := sw.col.Count() - int(sargs.cursor) @@ -713,13 +732,17 @@ func (s *Server) cmdSearch(msg *Message) (res resp.Value, err error) { limits := multiGlobParse(sw.globs, sargs.desc) if limits[0] == "" && limits[1] == "" { sw.col.SearchValues(sargs.desc, sw, msg.Deadline, - func(id string, o geojson.Object, fields []float64) bool { - return sw.writeObject(ScanWriterParams{ + func(id string, o geojson.Object, fields field.List) bool { + keepGoing, err := sw.pushObject(ScanWriterParams{ id: id, o: o, fields: fields, - noLock: true, }) + if err != nil { + ierr = err + return false + } + return keepGoing }, ) } else { @@ -727,18 +750,25 @@ func (s *Server) cmdSearch(msg *Message) (res resp.Value, err error) { // globSingle is only for ID matches, not values. sw.col.SearchValuesRange(limits[0], limits[1], sargs.desc, sw, msg.Deadline, - func(id string, o geojson.Object, fields []float64) bool { - return sw.writeObject(ScanWriterParams{ + func(id string, o geojson.Object, fields field.List) bool { + keepGoing, err := sw.pushObject(ScanWriterParams{ id: id, o: o, fields: fields, - noLock: true, }) + if err != nil { + ierr = err + return false + } + return keepGoing }, ) } } } + if ierr != nil { + return retrerr(ierr) + } sw.writeFoot() if msg.OutputType == JSON { wr.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}") diff --git a/internal/server/server.go b/internal/server/server.go index 68ad06c9..b7239248 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -36,6 +36,7 @@ import ( "github.com/tidwall/tile38/internal/collection" "github.com/tidwall/tile38/internal/deadline" "github.com/tidwall/tile38/internal/endpoint" + "github.com/tidwall/tile38/internal/field" "github.com/tidwall/tile38/internal/log" ) @@ -53,14 +54,16 @@ const ( // commandDetails is detailed information about a mutable command. It's used // for geofence formulas. type commandDetails struct { - command string // client command, like "SET" or "DEL" - key, id string // collection key and object id of object - newKey string // new key, for RENAME command - fmap map[string]int // map of field names to value indexes - obj geojson.Object // new object - fields []float64 // array of field values - oldObj geojson.Object // previous object, if any - oldFields []float64 // previous object field values + command string // client command, like "SET" or "DEL" + key, id string // collection key and object id of object + newKey string // new key, for RENAME command + + obj geojson.Object // new object + fields field.List // array of field values + + oldObj geojson.Object // previous object, if any + oldFields field.List // previous object field values + updated bool // object was updated timestamp time.Time // timestamp when the update occured parent bool // when true, only children are forwarded @@ -1016,9 +1019,9 @@ func (s *Server) command(msg *Message, client *Client) ( default: err = fmt.Errorf("unknown command '%s'", msg.Args[0]) case "set": - res, d, err = s.cmdSet(msg) + res, d, err = s.cmdSET(msg) case "fset": - res, d, err = s.cmdFset(msg) + res, d, err = s.cmdFSET(msg) case "del": res, d, err = s.cmdDel(msg) case "pdel": @@ -1026,7 +1029,7 @@ func (s *Server) command(msg *Message, client *Client) ( case "drop": res, d, err = s.cmdDrop(msg) case "flushdb": - res, d, err = s.cmdFlushDB(msg) + res, d, err = s.cmdFLUSHDB(msg) case "rename": res, d, err = s.cmdRename(msg) case "renamenx": @@ -1048,9 +1051,9 @@ func (s *Server) command(msg *Message, client *Client) ( case "chans": res, err = s.cmdHooks(msg) case "expire": - res, d, err = s.cmdExpire(msg) + res, d, err = s.cmdEXPIRE(msg) case "persist": - res, d, err = s.cmdPersist(msg) + res, d, err = s.cmdPERSIST(msg) case "ttl": res, err = s.cmdTTL(msg) case "shutdown": @@ -1090,9 +1093,9 @@ func (s *Server) command(msg *Message, client *Client) ( case "nearby": res, err = s.cmdNearby(msg) case "within": - res, err = s.cmdWithin(msg) + res, err = s.cmdWITHIN(msg) case "intersects": - res, err = s.cmdIntersects(msg) + res, err = s.cmdINTERSECTS(msg) case "search": res, err = s.cmdSearch(msg) case "bounds": diff --git a/internal/server/token.go b/internal/server/token.go index 178d9a1f..a4200a5e 100644 --- a/internal/server/token.go +++ b/internal/server/token.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" + "github.com/tidwall/tile38/internal/field" lua "github.com/yuin/gopher-lua" ) @@ -44,31 +45,6 @@ func tokenval(vs []string) (nvs []string, token string, ok bool) { return } -func tokenvalbytes(vs []string) (nvs []string, token []byte, ok bool) { - if len(vs) > 0 { - token = []byte(vs[0]) - nvs = vs[1:] - ok = true - } - return -} - -func lcb(s1 []byte, s2 string) bool { - if len(s1) != len(s2) { - return false - } - for i := 0; i < len(s1); i++ { - ch := s1[i] - if ch >= 'A' && ch <= 'Z' { - if ch+32 != s2[i] { - return false - } - } else if ch != s2[i] { - return false - } - } - return true -} func lc(s1, s2 string) bool { if len(s1) != len(s2) { return false @@ -87,30 +63,35 @@ func lc(s1, s2 string) bool { } type whereT struct { - field string - index int - minx bool - min float64 - maxx bool - max float64 + name string + minx bool + min field.Value + maxx bool + max field.Value } -func (where whereT) match(value float64) bool { +func mLT(a, b field.Value) bool { return a.Less(b) } +func mLTE(a, b field.Value) bool { return !mLT(b, a) } +func mGT(a, b field.Value) bool { return mLT(b, a) } +func mGTE(a, b field.Value) bool { return !mLT(a, b) } +func mEQ(a, b field.Value) bool { return a.Equals(b) } + +func (where whereT) match(value field.Value) bool { if !where.minx { - if value < where.min { + if mLT(value, where.min) { // if value < where.min { return false } } else { - if value <= where.min { + if mLTE(value, where.min) { // if value <= where.min { return false } } if !where.maxx { - if value > where.max { + if mGT(value, where.max) { // if value > where.max { return false } } else { - if value >= where.max { + if mGTE(value, where.max) { // if value >= where.max { return false } } @@ -118,14 +99,13 @@ func (where whereT) match(value float64) bool { } type whereinT struct { - field string - index int - valArr []float64 + name string + valArr []field.Value } -func (wherein whereinT) match(value float64) bool { +func (wherein whereinT) match(value field.Value) bool { for _, val := range wherein.valArr { - if val == value { + if mEQ(val, value) { return true } } @@ -146,12 +126,28 @@ func (whereeval whereevalT) Close() { whereeval.c.luapool.Put(whereeval.luaState) } -func (whereeval whereevalT) match(fieldsWithNames map[string]float64) bool { - fieldsTbl := whereeval.luaState.CreateTable(0, len(fieldsWithNames)) - for field, val := range fieldsWithNames { - fieldsTbl.RawSetString(field, lua.LNumber(val)) +func luaSetField(tbl *lua.LTable, name string, val field.Value) { + var lval lua.LValue + switch val.Kind() { + case field.Null: + lval = lua.LNil + case field.False: + lval = lua.LFalse + case field.True: + lval = lua.LTrue + case field.Number: + lval = lua.LNumber(val.Num()) + default: + lval = lua.LString(val.Data()) } + tbl.RawSetString(name, lval) +} +func (whereeval whereevalT) match(fieldsWithNames map[string]field.Value) (bool, error) { + fieldsTbl := whereeval.luaState.CreateTable(0, len(fieldsWithNames)) + for name, val := range fieldsWithNames { + luaSetField(fieldsTbl, name, val) + } luaSetRawGlobals( whereeval.luaState, map[string]lua.LValue{ "FIELDS": fieldsTbl, @@ -163,7 +159,7 @@ func (whereeval whereevalT) match(fieldsWithNames map[string]float64) bool { whereeval.luaState.Push(whereeval.fn) if err := whereeval.luaState.PCall(0, 1, nil); err != nil { - panic(err.Error()) + return false, err } ret := whereeval.luaState.Get(-1) whereeval.luaState.Pop(1) @@ -171,23 +167,23 @@ func (whereeval whereevalT) match(fieldsWithNames map[string]float64) bool { // Make bool out of returned lua value switch ret.Type() { case lua.LTNil: - return false + return false, nil case lua.LTBool: - return ret == lua.LTrue + return ret == lua.LTrue, nil case lua.LTNumber: - return float64(ret.(lua.LNumber)) != 0 + return float64(ret.(lua.LNumber)) != 0, nil case lua.LTString: - return ret.String() != "" + return ret.String() != "", nil case lua.LTTable: tbl := ret.(*lua.LTable) if tbl.Len() != 0 { - return true + return true, nil } var match bool tbl.ForEach(func(lk lua.LValue, lv lua.LValue) { match = true }) - return match + return match, nil } - panic(fmt.Sprintf("Script returned value of type %s", ret.Type())) + return false, fmt.Errorf("script returned value of type %s", ret.Type()) } type searchScanBaseTokens struct { @@ -265,57 +261,54 @@ func (s *Server) parseSearchScanBaseTokens( continue case "where": vs = nvs - var field, smin, smax string - if vs, field, ok = tokenval(vs); !ok || field == "" { + var name, smin, smax string + if vs, name, ok = tokenval(vs); !ok { err = errInvalidNumberOfArguments return } - if vs, smin, ok = tokenval(vs); !ok || smin == "" { + if vs, smin, ok = tokenval(vs); !ok { err = errInvalidNumberOfArguments return } - if vs, smax, ok = tokenval(vs); !ok || smax == "" { + if vs, smax, ok = tokenval(vs); !ok { err = errInvalidNumberOfArguments return } var minx, maxx bool - var min, max float64 - if strings.ToLower(smin) == "-inf" { - min = math.Inf(-1) + smin = strings.ToLower(smin) + if smin == "-inf" { + smin = "-inf" } else { if strings.HasPrefix(smin, "(") { minx = true smin = smin[1:] } - min, err = strconv.ParseFloat(smin, 64) - if err != nil { - err = errInvalidArgument(smin) - return - } } - if strings.ToLower(smax) == "+inf" { - max = math.Inf(+1) + smax = strings.ToLower(smax) + if smax == "+inf" || smax == "inf" { + smax = "inf" } else { if strings.HasPrefix(smax, "(") { maxx = true smax = smax[1:] } - max, err = strconv.ParseFloat(smax, 64) - if err != nil { - err = errInvalidArgument(smax) - return - } } - t.wheres = append(t.wheres, whereT{field, -1, minx, min, maxx, max}) + t.wheres = append(t.wheres, whereT{ + name: strings.ToLower(name), + minx: minx, + min: field.ValueOf(smin), + maxx: maxx, + max: field.ValueOf(smax), + }) continue case "wherein": vs = nvs - var field, nvalsStr, valStr string - if vs, field, ok = tokenval(vs); !ok || field == "" { + var name, nvalsStr, valStr string + if vs, name, ok = tokenval(vs); !ok { err = errInvalidNumberOfArguments return } - if vs, nvalsStr, ok = tokenval(vs); !ok || nvalsStr == "" { + if vs, nvalsStr, ok = tokenval(vs); !ok { err = errInvalidNumberOfArguments return } @@ -324,20 +317,18 @@ func (s *Server) parseSearchScanBaseTokens( err = errInvalidArgument(nvalsStr) return } - valArr := make([]float64, nvals) - var val float64 + valArr := make([]field.Value, nvals) for i = 0; i < nvals; i++ { - if vs, valStr, ok = tokenval(vs); !ok || valStr == "" { + if vs, valStr, ok = tokenval(vs); !ok { err = errInvalidNumberOfArguments return } - if val, err = strconv.ParseFloat(valStr, 64); err != nil { - err = errInvalidArgument(valStr) - return - } - valArr[i] = val + valArr[i] = field.ValueOf(valStr) } - t.whereins = append(t.whereins, whereinT{field, -1, valArr}) + t.whereins = append(t.whereins, whereinT{ + name: strings.ToLower(name), + valArr: valArr, + }) continue case "whereevalsha": fallthrough @@ -409,7 +400,9 @@ func (s *Server) parseSearchScanBaseTokens( } s.luascripts.Put(shaSum, fn.Proto) } - t.whereevals = append(t.whereevals, whereevalT{s, luaState, fn}) + t.whereevals = append(t.whereevals, whereevalT{ + c: s, luaState: luaState, fn: fn, + }) continue case "nofields": vs = nvs diff --git a/internal/sstring/sstring.go b/internal/sstring/sstring.go new file mode 100644 index 00000000..36d240b9 --- /dev/null +++ b/internal/sstring/sstring.go @@ -0,0 +1,54 @@ +// Package shared allows for +package sstring + +import ( + "sync" + "unsafe" + + "github.com/tidwall/hashmap" +) + +var mu sync.Mutex +var nums hashmap.Map[string, int] +var strs []string + +// Load a shared string from its number. +// Panics when there is no string assigned with that number. +func Load(num int) (str string) { + mu.Lock() + if num >= 0 && num < len(strs) { + str = strs[num] + mu.Unlock() + return str + } + mu.Unlock() + panic("string not found") +} + +// Store a shared string. +// Returns a unique number that can be used to load the string later. +// The number is al +func Store(str string) (num int) { + mu.Lock() + var ok bool + num, ok = nums.Get(str) + if !ok { + // Make a copy of the string to ensure we don't take in slices. + b := make([]byte, len(str)) + copy(b, str) + str = *(*string)(unsafe.Pointer(&b)) + num = len(strs) + strs = append(strs, str) + nums.Set(str, num) + } + mu.Unlock() + return num +} + +// Len returns the number of shared strings +func Len() int { + mu.Lock() + n := len(strs) + mu.Unlock() + return n +} diff --git a/internal/sstring/sstring_test.go b/internal/sstring/sstring_test.go new file mode 100644 index 00000000..85945747 --- /dev/null +++ b/internal/sstring/sstring_test.go @@ -0,0 +1,85 @@ +package sstring + +import ( + "math/rand" + "testing" + "time" + + "github.com/tidwall/assert" +) + +func TestShared(t *testing.T) { + for i := -1; i < 10; i++ { + var str string + func() { + defer func() { + assert.Assert(recover().(string) == "string not found") + }() + str = Load(i) + }() + assert.Assert(str == "") + } + assert.Assert(Store("hello") == 0) + assert.Assert(Store("") == 1) + assert.Assert(Store("jello") == 2) + assert.Assert(Store("hello") == 0) + assert.Assert(Store("") == 1) + assert.Assert(Store("jello") == 2) + str := Load(0) + assert.Assert(str == "hello") + str = Load(1) + assert.Assert(str == "") + str = Load(2) + assert.Assert(str == "jello") + + assert.Assert(Len() == 3) + +} + +func randStr(n int) string { + b := make([]byte, n) + rand.Read(b) + for i := 0; i < n; i++ { + b[i] = 'a' + b[i]%26 + } + return string(b) +} + +func BenchmarkStore(b *testing.B) { + rand.Seed(time.Now().UnixNano()) + wmap := make(map[string]bool, b.N) + for len(wmap) < b.N { + wmap[randStr(10)] = true + } + words := make([]string, 0, b.N) + for word := range wmap { + words = append(words, word) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + Store(words[i]) + } +} + +func BenchmarkLoad(b *testing.B) { + rand.Seed(time.Now().UnixNano()) + wmap := make(map[string]bool, b.N) + for len(wmap) < b.N { + wmap[randStr(10)] = true + } + words := make([]string, 0, b.N) + for word := range wmap { + words = append(words, word) + } + var nums []int + for i := 0; i < b.N; i++ { + nums = append(nums, Store(words[i])) + } + rand.Shuffle(len(nums), func(i, j int) { + nums[i], nums[j] = nums[j], nums[i] + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Load(nums[i]) + } +} diff --git a/tests/keys_search_test.go b/tests/keys_search_test.go index 2a164ed9..e68b9236 100644 --- a/tests/keys_search_test.go +++ b/tests/keys_search_test.go @@ -740,9 +740,9 @@ func keys_FIELDS_search_test(mc *mockServer) error { `{"id":"5","object":{"type":"Point","coordinates":[-112.2799,33.5228]},"fields":[0,15,28]}` + `],"count":4,"cursor":0}`}, {"NEARBY", "mykey", "WHERE", "field2", 0, 2, "POINT", 33.462, -112.268, 60000}, { - `{"ok":true,"fields":["field1","field2","field3"],"objects":[` + - `{"id":"6","object":{"type":"Point","coordinates":[-112.2801,33.523]},"fields":[0,0,29]},` + - `{"id":"7","object":{"type":"Point","coordinates":[-112.2803,33.5232]},"fields":[0,0,0]}` + + `{"ok":true,"fields":["field3"],"objects":[` + + `{"id":"6","object":{"type":"Point","coordinates":[-112.2801,33.523]},"fields":[29]},` + + `{"id":"7","object":{"type":"Point","coordinates":[-112.2803,33.5232]},"fields":[0]}` + `],"count":2,"cursor":0}`}, {"WITHIN", "mykey", "WHERE", "field2", 11, "+inf", "CIRCLE", 33.462, -112.268, 60000}, { @@ -753,9 +753,9 @@ func keys_FIELDS_search_test(mc *mockServer) error { `{"id":"1","object":{"type":"Point","coordinates":[-112.2791,33.522]},"fields":[10,11,0]}` + `],"count":4,"cursor":0}`}, {"WITHIN", "mykey", "WHERE", "field2", 0, 2, "CIRCLE", 33.462, -112.268, 60000}, { - `{"ok":true,"fields":["field1","field2","field3"],"objects":[` + - `{"id":"7","object":{"type":"Point","coordinates":[-112.2803,33.5232]},"fields":[0,0,0]},` + - `{"id":"6","object":{"type":"Point","coordinates":[-112.2801,33.523]},"fields":[0,0,29]}` + + `{"ok":true,"fields":["field3"],"objects":[` + + `{"id":"7","object":{"type":"Point","coordinates":[-112.2803,33.5232]},"fields":[0]},` + + `{"id":"6","object":{"type":"Point","coordinates":[-112.2801,33.523]},"fields":[29]}` + `],"count":2,"cursor":0}`}, }) } diff --git a/tests/keys_test.go b/tests/keys_test.go index b5b16464..aef23ded 100644 --- a/tests/keys_test.go +++ b/tests/keys_test.go @@ -352,8 +352,8 @@ func keys_FIELDS_test(mc *mockServer) error { return mc.DoBatch([][]interface{}{ {"SET", "mykey", "myid1a", "FIELD", "a", 1, "POINT", 33, -115}, {"OK"}, {"GET", "mykey", "myid1a", "WITHFIELDS"}, {`[{"type":"Point","coordinates":[-115,33]} [a 1]]`}, - {"SET", "mykey", "myid1a", "FIELD", "a", "a", "POINT", 33, -115}, {"ERR invalid argument 'a'"}, - {"GET", "mykey", "myid1a", "WITHFIELDS"}, {`[{"type":"Point","coordinates":[-115,33]} [a 1]]`}, + {"SET", "mykey", "myid1a", "FIELD", "a", "a", "POINT", 33, -115}, {"OK"}, + {"GET", "mykey", "myid1a", "WITHFIELDS"}, {`[{"type":"Point","coordinates":[-115,33]} [a a]]`}, {"SET", "mykey", "myid1a", "FIELD", "a", 1, "FIELD", "b", 2, "POINT", 33, -115}, {"OK"}, {"GET", "mykey", "myid1a", "WITHFIELDS"}, {`[{"type":"Point","coordinates":[-115,33]} [a 1 b 2]]`}, {"SET", "mykey", "myid1a", "FIELD", "b", 2, "POINT", 33, -115}, {"OK"}, @@ -390,7 +390,8 @@ func keys_WHEREIN_test(mc *mockServer) error { {"WITHIN", "mykey", "WHEREIN", "a", 3, 0, 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {`[0 [[myid_a1 {"type":"Point","coordinates":[-115,33]} [a 1]]]]`}, {"WITHIN", "mykey", "WHEREIN", "a", "a", 0, 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {"ERR invalid argument 'a'"}, {"WITHIN", "mykey", "WHEREIN", "a", 1, 0, 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {"ERR invalid argument '1'"}, - {"WITHIN", "mykey", "WHEREIN", "a", 3, 0, "a", 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {"ERR invalid argument 'a'"}, + {"WITHIN", "mykey", "WHEREIN", "a", 3, 0, "a", 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {"[0 []]"}, + {"WITHIN", "mykey", "WHEREIN", "a", 4, 0, "a", 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {`[0 [[myid_a1 {"type":"Point","coordinates":[-115,33]} [a 1]]]]`}, {"SET", "mykey", "myid_a2", "FIELD", "a", 2, "POINT", 32.99, -115}, {"OK"}, {"SET", "mykey", "myid_a3", "FIELD", "a", 3, "POINT", 33, -115.02}, {"OK"}, {"WITHIN", "mykey", "WHEREIN", "a", 3, 0, 1, 2, "BOUNDS", 32.8, -115.2, 33.2, -114.8}, {