diff --git a/README.md b/README.md index b81c3a7..6656232 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,9 @@ err := tx.Ascend("", func(key, value string) bool{ }) ``` -There is also `AscendGreaterOrEqual`, `AscendLessThan`, `AscendRange`, `Descend`, `DescendLessOrEqual`, `DescendGreaterThan`, and `DescendRange`. Please see the [documentation](https://godoc.org/github.com/tidwall/buntdb) for more information on these functions. +There is also `AscendGreaterOrEqual`, `AscendLessThan`, `AscendRange`, `AscendEqual`, `Descend`, `DescendLessOrEqual`, `DescendGreaterThan`, `DescendRange`, and `DescendEqual`. Please see the [documentation](https://godoc.org/github.com/tidwall/buntdb) for more information on these functions. + + ## Custom Indexes diff --git a/buntdb.go b/buntdb.go index 3a08935..9fcaf31 100644 --- a/buntdb.go +++ b/buntdb.go @@ -1185,6 +1185,7 @@ type dbItemOpts struct { type dbItem struct { key, val string // the binary key and value opts *dbItemOpts // optional meta information + keyless bool // keyless item for scanning } func appendArray(buf []byte, count int) []byte { @@ -1278,6 +1279,11 @@ func (dbi *dbItem) Less(item btree.Item, ctx interface{}) bool { } } // Always fall back to the key comparison. This creates absolute uniqueness. + if dbi.keyless { + return false + } else if dbi2.keyless { + return true + } return dbi.key < dbi2.key } @@ -1507,6 +1513,10 @@ func (tx *Tx) scan(desc, gt, lt bool, index, start, stop string, } else { itemA = &dbItem{val: start} itemB = &dbItem{val: stop} + if desc { + itemA.keyless = true + itemB.keyless = true + } } } // execute the scan on the underlying tree. @@ -1709,6 +1719,62 @@ func (tx *Tx) DescendRange(index, lessOrEqual, greaterThan string, ) } +// AscendEqual calls the iterator for every item in the database that equals +// pivot, until iterator returns false. +// When an index is provided, the results will be ordered by the item values +// as specified by the less() function of the defined index. +// When an index is not provided, the results will be ordered by the item key. +// An invalid index will return an error. +func (tx *Tx) AscendEqual(index, pivot string, + iterator func(key, value string) bool) error { + var err error + var less func(a, b string) bool + if index != "" { + less, err = tx.GetLess(index) + if err != nil { + return err + } + } + return tx.AscendGreaterOrEqual(index, pivot, func(key, value string) bool { + if less == nil { + if key != pivot { + return false + } + } else if less(pivot, value) { + return false + } + return iterator(key, value) + }) +} + +// DescendEqual calls the iterator for every item in the database that equals +// pivot, until iterator returns false. +// When an index is provided, the results will be ordered by the item values +// as specified by the less() function of the defined index. +// When an index is not provided, the results will be ordered by the item key. +// An invalid index will return an error. +func (tx *Tx) DescendEqual(index, pivot string, + iterator func(key, value string) bool) error { + var err error + var less func(a, b string) bool + if index != "" { + less, err = tx.GetLess(index) + if err != nil { + return err + } + } + return tx.DescendLessOrEqual(index, pivot, func(key, value string) bool { + if less == nil { + if key != pivot { + return false + } + } else if less(value, pivot) { + return false + } + return iterator(key, value) + }) +} + // rect is used by Intersects type rect struct { min, max []float64 diff --git a/buntdb_test.go b/buntdb_test.go index e250742..9668f2a 100644 --- a/buntdb_test.go +++ b/buntdb_test.go @@ -495,6 +495,110 @@ func TestDeleteAll(t *testing.T) { } } +func TestAscendEqual(t *testing.T) { + db := testOpen(t) + defer testClose(db) + if err := db.Update(func(tx *Tx) error { + for i := 0; i < 300; i++ { + _, _, err := tx.Set(fmt.Sprintf("key:%05dA", i), fmt.Sprintf("%d", i+1000), nil) + if err != nil { + return err + } + _, _, err = tx.Set(fmt.Sprintf("key:%05dB", i), fmt.Sprintf("%d", i+1000), nil) + if err != nil { + return err + } + } + return tx.CreateIndex("num", "*", IndexInt) + }); err != nil { + t.Fatal(err) + } + var res []string + if err := db.View(func(tx *Tx) error { + return tx.AscendEqual("", "key:00055A", func(key, value string) bool { + res = append(res, key) + return true + }) + }); err != nil { + t.Fatal(err) + } + if len(res) != 1 { + t.Fatalf("expected %v, got %v", 1, len(res)) + } + if res[0] != "key:00055A" { + t.Fatalf("expected %v, got %v", "key:00055A", res[0]) + } + res = nil + if err := db.View(func(tx *Tx) error { + return tx.AscendEqual("num", "1125", func(key, value string) bool { + res = append(res, key) + return true + }) + }); err != nil { + t.Fatal(err) + } + if len(res) != 2 { + t.Fatalf("expected %v, got %v", 2, len(res)) + } + if res[0] != "key:00125A" { + t.Fatalf("expected %v, got %v", "key:00125A", res[0]) + } + if res[1] != "key:00125B" { + t.Fatalf("expected %v, got %v", "key:00125B", res[1]) + } +} +func TestDescendEqual(t *testing.T) { + db := testOpen(t) + defer testClose(db) + if err := db.Update(func(tx *Tx) error { + for i := 0; i < 300; i++ { + _, _, err := tx.Set(fmt.Sprintf("key:%05dA", i), fmt.Sprintf("%d", i+1000), nil) + if err != nil { + return err + } + _, _, err = tx.Set(fmt.Sprintf("key:%05dB", i), fmt.Sprintf("%d", i+1000), nil) + if err != nil { + return err + } + } + return tx.CreateIndex("num", "*", IndexInt) + }); err != nil { + t.Fatal(err) + } + var res []string + if err := db.View(func(tx *Tx) error { + return tx.DescendEqual("", "key:00055A", func(key, value string) bool { + res = append(res, key) + return true + }) + }); err != nil { + t.Fatal(err) + } + if len(res) != 1 { + t.Fatalf("expected %v, got %v", 1, len(res)) + } + if res[0] != "key:00055A" { + t.Fatalf("expected %v, got %v", "key:00055A", res[0]) + } + res = nil + if err := db.View(func(tx *Tx) error { + return tx.DescendEqual("num", "1125", func(key, value string) bool { + res = append(res, key) + return true + }) + }); err != nil { + t.Fatal(err) + } + if len(res) != 2 { + t.Fatalf("expected %v, got %v", 2, len(res)) + } + if res[0] != "key:00125B" { + t.Fatalf("expected %v, got %v", "key:00125B", res[0]) + } + if res[1] != "key:00125A" { + t.Fatalf("expected %v, got %v", "key:00125A", res[1]) + } +} func TestVariousTx(t *testing.T) { db := testOpen(t) defer testClose(db)