diff --git a/buntdb.go b/buntdb.go index 1be4e67..ae5f74d 100644 --- a/buntdb.go +++ b/buntdb.go @@ -247,6 +247,23 @@ type index struct { less func(a, b string) bool // less comparison function rect func(item string) (min, max []float64) // rect from string function db *DB // the origin database + opts IndexOptions // index options +} + +// match matches the pattern to the key +func (idx *index) match(key string) bool { + if idx.pattern == "*" { + return true + } + if idx.opts.CaseInsensitiveKeyMatching { + for i := 0; i < len(key); i++ { + if key[i] >= 'A' && key[i] <= 'Z' { + key = strings.ToLower(key) + break + } + } + } + return match.Match(key, idx.pattern) } // clearCopy creates a copy of the index, but with an empty dataset. @@ -258,6 +275,7 @@ func (idx *index) clearCopy() *index { db: idx.db, less: idx.less, rect: idx.rect, + opts: idx.opts, } // initialize with empty trees if nidx.less != nil { @@ -281,7 +299,7 @@ func (idx *index) rebuild() { // iterate through all keys and fill the index idx.db.keys.Ascend(func(item btree.Item) bool { dbi := item.(*dbItem) - if idx.pattern != "*" && !match.Match(dbi.key, idx.pattern) { + if !idx.match(dbi.key) { // does not match the pattern, conintue return true } @@ -470,7 +488,7 @@ func (db *DB) insertIntoDatabase(item *dbItem) *dbItem { db.exps.ReplaceOrInsert(item) } for _, idx := range db.idxs { - if !match.Match(item.key, idx.pattern) { + if !idx.match(item.key) { continue } if idx.btr != nil { @@ -1745,6 +1763,14 @@ func (tx *Tx) Len() (int, error) { return tx.db.keys.Len(), nil } +// IndexOptions provides an index with addtional features or +// alternate functionality. +type IndexOptions struct { + // CaseInsensitiveKeyMatching allow for case-insensitive + // matching on keys when setting key/values. + CaseInsensitiveKeyMatching bool +} + // CreateIndex builds a new index and populates it with items. // The items are ordered in an b-tree and can be retrieved using the // Ascend* and Descend* methods. @@ -1762,7 +1788,15 @@ func (tx *Tx) Len() (int, error) { // IndexString, IndexBinary, etc. func (tx *Tx) CreateIndex(name, pattern string, less ...func(a, b string) bool) error { - return tx.createIndex(name, pattern, less, nil) + return tx.createIndex(name, pattern, less, nil, nil) +} + +// CreateIndexOptions is the same as CreateIndex except that it allows +// for additonal options. +func (tx *Tx) CreateIndexOptions(name, pattern string, + opts *IndexOptions, + less ...func(a, b string) bool) error { + return tx.createIndex(name, pattern, less, nil, opts) } // CreateSpatialIndex builds a new index and populates it with items. @@ -1781,13 +1815,22 @@ func (tx *Tx) CreateIndex(name, pattern string, // parameter. func (tx *Tx) CreateSpatialIndex(name, pattern string, rect func(item string) (min, max []float64)) error { - return tx.createIndex(name, pattern, nil, rect) + return tx.createIndex(name, pattern, nil, rect, nil) +} + +// CreateSpatialIndexOptions is the same as CreateSpatialIndex except that +// it allows for additonal options. +func (tx *Tx) CreateSpatialIndexOptions(name, pattern string, + opts *IndexOptions, + rect func(item string) (min, max []float64)) error { + return tx.createIndex(name, pattern, nil, rect, nil) } // createIndex is called by CreateIndex() and CreateSpatialIndex() func (tx *Tx) createIndex(name string, pattern string, lessers []func(a, b string) bool, rect func(item string) (min, max []float64), + opts *IndexOptions, ) error { if tx.db == nil { return ErrTxClosed @@ -1828,6 +1871,13 @@ func (tx *Tx) createIndex(name string, pattern string, case 1: less = lessers[0] } + var sopts IndexOptions + if opts != nil { + sopts = *opts + } + if sopts.CaseInsensitiveKeyMatching { + pattern = strings.ToLower(pattern) + } // intialize new index idx := &index{ name: name, @@ -1835,6 +1885,7 @@ func (tx *Tx) createIndex(name string, pattern string, less: less, rect: rect, db: tx.db, + opts: sopts, } idx.rebuild() // save the index diff --git a/buntdb_test.go b/buntdb_test.go index 90732c9..3a81977 100644 --- a/buntdb_test.go +++ b/buntdb_test.go @@ -144,6 +144,7 @@ func TestSaveLoad(t *testing.T) { t.Fatal(err) } } + func TestMutatingIterator(t *testing.T) { db := testOpen(t) defer testClose(db) @@ -181,9 +182,54 @@ func TestMutatingIterator(t *testing.T) { }); err != nil { t.Fatal(err) } - } } + +func TestCaseInsensitiveIndex(t *testing.T) { + db := testOpen(t) + defer testClose(db) + count := 1000 + if err := db.Update(func(tx *Tx) error { + opts := &IndexOptions{ + CaseInsensitiveKeyMatching: true, + } + return tx.CreateIndexOptions("ages", "User:*:age", opts, IndexInt) + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *Tx) error { + for j := 0; j < count; j++ { + key := fmt.Sprintf("user:%d:age", j) + val := fmt.Sprintf("%d", rand.Intn(100)) + if _, _, err := tx.Set(key, val, nil); err != nil { + return err + } + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *Tx) error { + var vals []string + err := tx.Ascend("ages", func(key, value string) bool { + vals = append(vals, value) + return true + }) + if err != nil { + return err + } + if len(vals) != count { + return fmt.Errorf("expected '%v', got '%v'", count, len(vals)) + } + return nil + }); err != nil { + t.Fatal(err) + } + +} + func TestIndexTransaction(t *testing.T) { db := testOpen(t) defer testClose(db)