// Package buntdb implements a low-level in-memory key/value store in pure Go. // It persists to disk, is ACID compliant, and uses locking for multiple // readers and a single writer. Bunt is ideal for projects that need a // dependable database, and favor speed over data size. package buntdb import ( "bufio" "errors" "fmt" "io" "os" "sort" "strconv" "strings" "sync" "time" "github.com/tidwall/btree" "github.com/tidwall/gjson" "github.com/tidwall/grect" "github.com/tidwall/match" "github.com/tidwall/rtred" ) var ( // ErrTxNotWritable is returned when performing a write operation on a // read-only transaction. ErrTxNotWritable = errors.New("tx not writable") // ErrTxClosed is returned when committing or rolling back a transaction // that has already been committed or rolled back. ErrTxClosed = errors.New("tx closed") // ErrNotFound is returned when an item or index is not in the database. ErrNotFound = errors.New("not found") // ErrInvalid is returned when the database file is an invalid format. ErrInvalid = errors.New("invalid database") // ErrDatabaseClosed is returned when the database is closed. ErrDatabaseClosed = errors.New("database closed") // ErrIndexExists is returned when an index already exists in the database. ErrIndexExists = errors.New("index exists") // ErrInvalidOperation is returned when an operation cannot be completed. ErrInvalidOperation = errors.New("invalid operation") // ErrInvalidSyncPolicy is returned for an invalid SyncPolicy value. ErrInvalidSyncPolicy = errors.New("invalid sync policy") // ErrShrinkInProcess is returned when a shrink operation is in-process. ErrShrinkInProcess = errors.New("shrink is in-process") // ErrPersistenceActive is returned when post-loading data from an database // not opened with Open(":memory:"). ErrPersistenceActive = errors.New("persistence active") // ErrTxIterating is returned when Set or Delete are called while iterating. 
ErrTxIterating = errors.New("tx is iterating") ) const useAbsEx = true // DB represents a collection of key-value pairs that persist on disk. // Transactions are used for all forms of data access to the DB. type DB struct { mu sync.RWMutex // the gatekeeper for all fields file *os.File // the underlying file buf []byte // a buffer to write to keys *btree.BTree // a tree of all item ordered by key exps *btree.BTree // a tree of items ordered by expiration idxs map[string]*index // the index trees. insIdxs []*index // a reuse buffer for gathering indexes flushes int // a count of the number of disk flushes closed bool // set when the database has been closed config Config // the database configuration persist bool // do we write to disk shrinking bool // when an aof shrink is in-process. lastaofsz int // the size of the last shrink aof size } // SyncPolicy represents how often data is synced to disk. type SyncPolicy int const ( // Never is used to disable syncing data to disk. // The faster and less safe method. Never SyncPolicy = 0 // EverySecond is used to sync data to disk every second. // It's pretty fast and you can lose 1 second of data if there // is a disaster. // This is the recommended setting. EverySecond = 1 // Always is used to sync data after every write to disk. // Slow. Very safe. Always = 2 ) // Config represents database configuration options. These // options are used to change various behaviors of the database. type Config struct { // SyncPolicy adjusts how often the data is synced to disk. // This value can be Never, EverySecond, or Always. // The default is EverySecond. SyncPolicy SyncPolicy // AutoShrinkPercentage is used by the background process to trigger // a shrink of the aof file when the size of the file is larger than the // percentage of the result of the previous shrunk file. 
// For example, if this value is 100, and the last shrink process // resulted in a 100mb file, then the new aof file must be 200mb before // a shrink is triggered. AutoShrinkPercentage int // AutoShrinkMinSize defines the minimum size of the aof file before // an automatic shrink can occur. AutoShrinkMinSize int // AutoShrinkDisabled turns off automatic background shrinking AutoShrinkDisabled bool // OnExpired is used to custom handle the deletion option when a key // has been expired. OnExpired func(keys []string) // OnExpiredSync will be called inside the same transaction that is // performing the deletion of expired items. If OnExpired is present then // this callback will not be called. If this callback is present, then the // deletion of the timeed-out item is the explicit responsibility of this // callback. OnExpiredSync func(key, value string, tx *Tx) error } // exctx is a simple b-tree context for ordering by expiration. type exctx struct { db *DB } // Open opens a database at the provided path. // If the file does not exist then it will be created automatically. func Open(path string) (*DB, error) { db := &DB{} // initialize trees and indexes db.keys = btreeNew(lessCtx(nil)) db.exps = btreeNew(lessCtx(&exctx{db})) db.idxs = make(map[string]*index) // initialize default configuration db.config = Config{ SyncPolicy: EverySecond, AutoShrinkPercentage: 100, AutoShrinkMinSize: 32 * 1024 * 1024, } // turn off persistence for pure in-memory db.persist = path != ":memory:" if db.persist { var err error // hardcoding 0666 as the default mode. db.file, err = os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0666) if err != nil { return nil, err } // load the database from disk if err := db.load(); err != nil { // close on error, ignore close error _ = db.file.Close() return nil, err } } // start the background manager. go db.backgroundManager() return db, nil } // Close releases all database resources. // All transactions must be closed before closing the database. 
func (db *DB) Close() error { db.mu.Lock() defer db.mu.Unlock() if db.closed { return ErrDatabaseClosed } db.closed = true if db.persist { db.file.Sync() // do a sync but ignore the error if err := db.file.Close(); err != nil { return err } } // Let's release all references to nil. This will help both with debugging // late usage panics and it provides a hint to the garbage collector db.keys, db.exps, db.idxs, db.file = nil, nil, nil, nil return nil } // Save writes a snapshot of the database to a writer. This operation blocks all // writes, but not reads. This can be used for snapshots and backups for pure // in-memory databases using the ":memory:". Database that persist to disk // can be snapshotted by simply copying the database file. func (db *DB) Save(wr io.Writer) error { var err error db.mu.RLock() defer db.mu.RUnlock() // use a buffered writer and flush every 4MB var buf []byte now := time.Now() // iterated through every item in the database and write to the buffer btreeAscend(db.keys, func(item interface{}) bool { dbi := item.(*dbItem) buf = dbi.writeSetTo(buf, now) if len(buf) > 1024*1024*4 { // flush when buffer is over 4MB _, err = wr.Write(buf) if err != nil { return false } buf = buf[:0] } return true }) if err != nil { return err } // one final flush if len(buf) > 0 { _, err = wr.Write(buf) if err != nil { return err } } return nil } // Load loads commands from reader. This operation blocks all reads and writes. // Note that this can only work for fully in-memory databases opened with // Open(":memory:"). func (db *DB) Load(rd io.Reader) error { db.mu.Lock() defer db.mu.Unlock() if db.persist { // cannot load into databases that persist to disk return ErrPersistenceActive } _, err := db.readLoad(rd, time.Now()) return err } // index represents a b-tree or r-tree index and also acts as the // b-tree/r-tree context for itself. 
type index struct { btr *btree.BTree // contains the items rtr *rtred.RTree // contains the items name string // name of the index pattern string // a required key pattern less func(a, b string) bool // less comparison function rect func(item string) (min, max []float64) // rect from string function db *DB // the origin database opts IndexOptions // index options } // match matches the pattern to the key func (idx *index) match(key string) bool { if idx.pattern == "*" { return true } if idx.opts.CaseInsensitiveKeyMatching { for i := 0; i < len(key); i++ { if key[i] >= 'A' && key[i] <= 'Z' { key = strings.ToLower(key) break } } } return match.Match(key, idx.pattern) } // clearCopy creates a copy of the index, but with an empty dataset. func (idx *index) clearCopy() *index { // copy the index meta information nidx := &index{ name: idx.name, pattern: idx.pattern, db: idx.db, less: idx.less, rect: idx.rect, opts: idx.opts, } // initialize with empty trees if nidx.less != nil { nidx.btr = btreeNew(lessCtx(nidx)) } if nidx.rect != nil { nidx.rtr = rtred.New(nidx) } return nidx } // rebuild rebuilds the index func (idx *index) rebuild() { // initialize trees if idx.less != nil { idx.btr = btreeNew(lessCtx(idx)) } if idx.rect != nil { idx.rtr = rtred.New(idx) } // iterate through all keys and fill the index btreeAscend(idx.db.keys, func(item interface{}) bool { dbi := item.(*dbItem) if !idx.match(dbi.key) { // does not match the pattern, continue return true } if idx.less != nil { idx.btr.Set(dbi) } if idx.rect != nil { idx.rtr.Insert(dbi) } return true }) } // CreateIndex builds a new index and populates it with items. // The items are ordered in an b-tree and can be retrieved using the // Ascend* and Descend* methods. // An error will occur if an index with the same name already exists. // // When a pattern is provided, the index will be populated with // keys that match the specified pattern. 
This is a very simple pattern // match where '*' matches on any number characters and '?' matches on // any one character. // The less function compares if string 'a' is less than string 'b'. // It allows for indexes to create custom ordering. It's possible // that the strings may be textual or binary. It's up to the provided // less function to handle the content format and comparison. // There are some default less function that can be used such as // IndexString, IndexBinary, etc. func (db *DB) CreateIndex(name, pattern string, less ...func(a, b string) bool) error { return db.Update(func(tx *Tx) error { return tx.CreateIndex(name, pattern, less...) }) } // ReplaceIndex builds a new index and populates it with items. // The items are ordered in an b-tree and can be retrieved using the // Ascend* and Descend* methods. // If a previous index with the same name exists, that index will be deleted. func (db *DB) ReplaceIndex(name, pattern string, less ...func(a, b string) bool) error { return db.Update(func(tx *Tx) error { err := tx.CreateIndex(name, pattern, less...) if err != nil { if err == ErrIndexExists { err := tx.DropIndex(name) if err != nil { return err } return tx.CreateIndex(name, pattern, less...) } return err } return nil }) } // CreateSpatialIndex builds a new index and populates it with items. // The items are organized in an r-tree and can be retrieved using the // Intersects method. // An error will occur if an index with the same name already exists. // // The rect function converts a string to a rectangle. The rectangle is // represented by two arrays, min and max. Both arrays may have a length // between 1 and 20, and both arrays must match in length. A length of 1 is a // one dimensional rectangle, and a length of 4 is a four dimension rectangle. // There is support for up to 20 dimensions. // The values of min must be less than the values of max at the same dimension. // Thus min[0] must be less-than-or-equal-to max[0]. 
// The IndexRect is a default function that can be used for the rect // parameter. func (db *DB) CreateSpatialIndex(name, pattern string, rect func(item string) (min, max []float64)) error { return db.Update(func(tx *Tx) error { return tx.CreateSpatialIndex(name, pattern, rect) }) } // ReplaceSpatialIndex builds a new index and populates it with items. // The items are organized in an r-tree and can be retrieved using the // Intersects method. // If a previous index with the same name exists, that index will be deleted. func (db *DB) ReplaceSpatialIndex(name, pattern string, rect func(item string) (min, max []float64)) error { return db.Update(func(tx *Tx) error { err := tx.CreateSpatialIndex(name, pattern, rect) if err != nil { if err == ErrIndexExists { err := tx.DropIndex(name) if err != nil { return err } return tx.CreateSpatialIndex(name, pattern, rect) } return err } return nil }) } // DropIndex removes an index. func (db *DB) DropIndex(name string) error { return db.Update(func(tx *Tx) error { return tx.DropIndex(name) }) } // Indexes returns a list of index names. func (db *DB) Indexes() ([]string, error) { var names []string var err = db.View(func(tx *Tx) error { var err error names, err = tx.Indexes() return err }) return names, err } // ReadConfig returns the database configuration. func (db *DB) ReadConfig(config *Config) error { db.mu.RLock() defer db.mu.RUnlock() if db.closed { return ErrDatabaseClosed } *config = db.config return nil } // SetConfig updates the database configuration. func (db *DB) SetConfig(config Config) error { db.mu.Lock() defer db.mu.Unlock() if db.closed { return ErrDatabaseClosed } switch config.SyncPolicy { default: return ErrInvalidSyncPolicy case Never, EverySecond, Always: } db.config = config return nil } // insertIntoDatabase performs inserts an item in to the database and updates // all indexes. 
If a previous item with the same key already exists, that item // will be replaced with the new one, and return the previous item. func (db *DB) insertIntoDatabase(item *dbItem) *dbItem { var pdbi *dbItem // Generate a list of indexes that this item will be inserted in to. idxs := db.insIdxs for _, idx := range db.idxs { if idx.match(item.key) { idxs = append(idxs, idx) } } prev := db.keys.Set(item) if prev != nil { // A previous item was removed from the keys tree. Let's // fully delete this item from all indexes. pdbi = prev.(*dbItem) if pdbi.opts != nil && pdbi.opts.ex { // Remove it from the expires tree. db.exps.Delete(pdbi) } for _, idx := range idxs { if idx.btr != nil { // Remove it from the btree index. idx.btr.Delete(pdbi) } if idx.rtr != nil { // Remove it from the rtree index. idx.rtr.Remove(pdbi) } } } if item.opts != nil && item.opts.ex { // The new item has eviction options. Add it to the // expires tree db.exps.Set(item) } for i, idx := range idxs { if idx.btr != nil { // Add new item to btree index. idx.btr.Set(item) } if idx.rtr != nil { // Add new item to rtree index. idx.rtr.Insert(item) } // clear the index idxs[i] = nil } // reuse the index list slice db.insIdxs = idxs[:0] // we must return the previous item to the caller. return pdbi } // deleteFromDatabase removes and item from the database and indexes. The input // item must only have the key field specified thus "&dbItem{key: key}" is all // that is needed to fully remove the item with the matching key. If an item // with the matching key was found in the database, it will be removed and // returned to the caller. A nil return value means that the item was not // found in the database func (db *DB) deleteFromDatabase(item *dbItem) *dbItem { var pdbi *dbItem prev := db.keys.Delete(item) if prev != nil { pdbi = prev.(*dbItem) if pdbi.opts != nil && pdbi.opts.ex { // Remove it from the exipres tree. 
db.exps.Delete(pdbi) } for _, idx := range db.idxs { if !idx.match(pdbi.key) { continue } if idx.btr != nil { // Remove it from the btree index. idx.btr.Delete(pdbi) } if idx.rtr != nil { // Remove it from the rtree index. idx.rtr.Remove(pdbi) } } } return pdbi } // backgroundManager runs continuously in the background and performs various // operations such as removing expired items and syncing to disk. func (db *DB) backgroundManager() { flushes := 0 t := time.NewTicker(time.Second) defer t.Stop() for range t.C { var shrink bool // Open a standard view. This will take a full lock of the // database thus allowing for access to anything we need. var onExpired func([]string) var expired []*dbItem var onExpiredSync func(key, value string, tx *Tx) error err := db.Update(func(tx *Tx) error { onExpired = db.config.OnExpired if onExpired == nil { onExpiredSync = db.config.OnExpiredSync } if db.persist && !db.config.AutoShrinkDisabled { pos, err := db.file.Seek(0, 1) if err != nil { return err } aofsz := int(pos) if aofsz > db.config.AutoShrinkMinSize { prc := float64(db.config.AutoShrinkPercentage) / 100.0 shrink = aofsz > db.lastaofsz+int(float64(db.lastaofsz)*prc) } } // produce a list of expired items that need removing btreeAscendLessThan(db.exps, &dbItem{ opts: &dbItemOpts{ex: true, exat: time.Now()}, }, func(item interface{}) bool { expired = append(expired, item.(*dbItem)) return true }) if onExpired == nil && onExpiredSync == nil { for _, itm := range expired { if _, err := tx.Delete(itm.key); err != nil { // it's ok to get a "not found" because the // 'Delete' method reports "not found" for // expired items. 
if err != ErrNotFound { return err } } } } else if onExpiredSync != nil { for _, itm := range expired { if err := onExpiredSync(itm.key, itm.val, tx); err != nil { return err } } } return nil }) if err == ErrDatabaseClosed { break } // send expired event, if needed if onExpired != nil && len(expired) > 0 { keys := make([]string, 0, 32) for _, itm := range expired { keys = append(keys, itm.key) } onExpired(keys) } // execute a disk sync, if needed func() { db.mu.Lock() defer db.mu.Unlock() if db.persist && db.config.SyncPolicy == EverySecond && flushes != db.flushes { _ = db.file.Sync() flushes = db.flushes } }() if shrink { if err = db.Shrink(); err != nil { if err == ErrDatabaseClosed { break } } } } } // Shrink will make the database file smaller by removing redundant // log entries. This operation does not block the database. func (db *DB) Shrink() error { db.mu.Lock() if db.closed { db.mu.Unlock() return ErrDatabaseClosed } if !db.persist { // The database was opened with ":memory:" as the path. // There is no persistence, and no need to do anything here. db.mu.Unlock() return nil } if db.shrinking { // The database is already in the process of shrinking. db.mu.Unlock() return ErrShrinkInProcess } db.shrinking = true defer func() { db.mu.Lock() db.shrinking = false db.mu.Unlock() }() fname := db.file.Name() tmpname := fname + ".tmp" // the endpos is used to return to the end of the file when we are // finished writing all of the current items. endpos, err := db.file.Seek(0, 2) if err != nil { return err } db.mu.Unlock() time.Sleep(time.Second / 4) // wait just a bit before starting f, err := os.Create(tmpname) if err != nil { return err } defer func() { _ = f.Close() _ = os.RemoveAll(tmpname) }() // we are going to read items in as chunks as to not hold up the database // for too long. 
var buf []byte pivot := "" done := false for !done { err := func() error { db.mu.RLock() defer db.mu.RUnlock() if db.closed { return ErrDatabaseClosed } done = true var n int now := time.Now() btreeAscendGreaterOrEqual(db.keys, &dbItem{key: pivot}, func(item interface{}) bool { dbi := item.(*dbItem) // 1000 items or 64MB buffer if n > 1000 || len(buf) > 64*1024*1024 { pivot = dbi.key done = false return false } buf = dbi.writeSetTo(buf, now) n++ return true }, ) if len(buf) > 0 { if _, err := f.Write(buf); err != nil { return err } buf = buf[:0] } return nil }() if err != nil { return err } } // We reached this far so all of the items have been written to a new tmp // There's some more work to do by appending the new line from the aof // to the tmp file and finally swap the files out. return func() error { // We're wrapping this in a function to get the benefit of a defered // lock/unlock. db.mu.Lock() defer db.mu.Unlock() if db.closed { return ErrDatabaseClosed } // We are going to open a new version of the aof file so that we do // not change the seek position of the previous. This may cause a // problem in the future if we choose to use syscall file locking. aof, err := os.Open(fname) if err != nil { return err } defer func() { _ = aof.Close() }() if _, err := aof.Seek(endpos, 0); err != nil { return err } // Just copy all of the new commands that have occurred since we // started the shrink process. if _, err := io.Copy(f, aof); err != nil { return err } // Close all files if err := aof.Close(); err != nil { return err } if err := f.Close(); err != nil { return err } if err := db.file.Close(); err != nil { return err } // Any failures below here are really bad. So just panic. 
if err := os.Rename(tmpname, fname); err != nil { panicErr(err) } db.file, err = os.OpenFile(fname, os.O_CREATE|os.O_RDWR, 0666) if err != nil { panicErr(err) } pos, err := db.file.Seek(0, 2) if err != nil { return err } db.lastaofsz = int(pos) return nil }() } func panicErr(err error) error { panic(fmt.Errorf("buntdb: %w", err)) } // readLoad reads from the reader and loads commands into the database. // modTime is the modified time of the reader, should be no greater than // the current time.Now(). // Returns the number of bytes of the last command read and the error if any. func (db *DB) readLoad(rd io.Reader, modTime time.Time) (n int64, err error) { defer func() { if err == io.EOF { err = io.ErrUnexpectedEOF } }() totalSize := int64(0) data := make([]byte, 4096) parts := make([]string, 0, 8) r := bufio.NewReader(rd) for { // peek at the first byte. If it's a 'nul' control character then // ignore it and move to the next byte. c, err := r.ReadByte() if err != nil { if err == io.EOF { err = nil } return totalSize, err } if c == 0 { // ignore nul control characters n += 1 continue } if err := r.UnreadByte(); err != nil { return totalSize, err } // read a single command. // first we should read the number of parts that the of the command cmdByteSize := int64(0) line, err := r.ReadBytes('\n') if err != nil { return totalSize, err } if line[0] != '*' { return totalSize, ErrInvalid } cmdByteSize += int64(len(line)) // convert the string number to and int var n int if len(line) == 4 && line[len(line)-2] == '\r' { if line[1] < '0' || line[1] > '9' { return totalSize, ErrInvalid } n = int(line[1] - '0') } else { if len(line) < 5 || line[len(line)-2] != '\r' { return totalSize, ErrInvalid } for i := 1; i < len(line)-2; i++ { if line[i] < '0' || line[i] > '9' { return totalSize, ErrInvalid } n = n*10 + int(line[i]-'0') } } // read each part of the command. parts = parts[:0] for i := 0; i < n; i++ { // read the number of bytes of the part. 
line, err := r.ReadBytes('\n') if err != nil { return totalSize, err } if line[0] != '$' { return totalSize, ErrInvalid } cmdByteSize += int64(len(line)) // convert the string number to and int var n int if len(line) == 4 && line[len(line)-2] == '\r' { if line[1] < '0' || line[1] > '9' { return totalSize, ErrInvalid } n = int(line[1] - '0') } else { if len(line) < 5 || line[len(line)-2] != '\r' { return totalSize, ErrInvalid } for i := 1; i < len(line)-2; i++ { if line[i] < '0' || line[i] > '9' { return totalSize, ErrInvalid } n = n*10 + int(line[i]-'0') } } // resize the read buffer if len(data) < n+2 { dataln := len(data) for dataln < n+2 { dataln *= 2 } data = make([]byte, dataln) } if _, err = io.ReadFull(r, data[:n+2]); err != nil { return totalSize, err } if data[n] != '\r' || data[n+1] != '\n' { return totalSize, ErrInvalid } // copy string parts = append(parts, string(data[:n])) cmdByteSize += int64(n + 2) } // finished reading the command if len(parts) == 0 { continue } if (parts[0][0] == 's' || parts[0][0] == 'S') && (parts[0][1] == 'e' || parts[0][1] == 'E') && (parts[0][2] == 't' || parts[0][2] == 'T') { // SET if len(parts) < 3 || len(parts) == 4 || len(parts) > 5 { return totalSize, ErrInvalid } if len(parts) == 5 { arg := strings.ToLower(parts[3]) if arg != "ex" && arg != "ae" { return totalSize, ErrInvalid } ex, err := strconv.ParseInt(parts[4], 10, 64) if err != nil { return totalSize, err } var exat time.Time now := time.Now() if arg == "ex" { dur := (time.Duration(ex) * time.Second) - now.Sub(modTime) exat = now.Add(dur) } else { exat = time.Unix(ex, 0) } if exat.After(now) { db.insertIntoDatabase(&dbItem{ key: parts[1], val: parts[2], opts: &dbItemOpts{ ex: true, exat: exat, }, }) } } else { db.insertIntoDatabase(&dbItem{key: parts[1], val: parts[2]}) } } else if (parts[0][0] == 'd' || parts[0][0] == 'D') && (parts[0][1] == 'e' || parts[0][1] == 'E') && (parts[0][2] == 'l' || parts[0][2] == 'L') { // DEL if len(parts) != 2 { return totalSize, 
ErrInvalid } db.deleteFromDatabase(&dbItem{key: parts[1]}) } else if (parts[0][0] == 'f' || parts[0][0] == 'F') && strings.ToLower(parts[0]) == "flushdb" { db.keys = btreeNew(lessCtx(nil)) db.exps = btreeNew(lessCtx(&exctx{db})) db.idxs = make(map[string]*index) } else { return totalSize, ErrInvalid } totalSize += cmdByteSize } } // load reads entries from the append only database file and fills the database. // The file format uses the Redis append only file format, which is and a series // of RESP commands. For more information on RESP please read // http://redis.io/topics/protocol. The only supported RESP commands are DEL and // SET. func (db *DB) load() error { fi, err := db.file.Stat() if err != nil { return err } n, err := db.readLoad(db.file, fi.ModTime()) if err != nil { if err == io.ErrUnexpectedEOF { // The db file has ended mid-command, which is allowed but the // data file should be truncated to the end of the last valid // command if err := db.file.Truncate(n); err != nil { return err } } else { return err } } if _, err := db.file.Seek(n, 0); err != nil { return err } var estaofsz int db.keys.Walk(func(items []interface{}) { for _, v := range items { estaofsz += v.(*dbItem).estAOFSetSize() } }) db.lastaofsz += estaofsz return nil } // managed calls a block of code that is fully contained in a transaction. // This method is intended to be wrapped by Update and View func (db *DB) managed(writable bool, fn func(tx *Tx) error) (err error) { var tx *Tx tx, err = db.Begin(writable) if err != nil { return } defer func() { if err != nil { // The caller returned an error. We must rollback. _ = tx.Rollback() return } if writable { // Everything went well. Lets Commit() err = tx.Commit() } else { // read-only transaction can only roll back. err = tx.Rollback() } }() tx.funcd = true defer func() { tx.funcd = false }() err = fn(tx) return } // View executes a function within a managed read-only transaction. 
// When a non-nil error is returned from the function that error will be return // to the caller of View(). // // Executing a manual commit or rollback from inside the function will result // in a panic. func (db *DB) View(fn func(tx *Tx) error) error { return db.managed(false, fn) } // Update executes a function within a managed read/write transaction. // The transaction has been committed when no error is returned. // In the event that an error is returned, the transaction will be rolled back. // When a non-nil error is returned from the function, the transaction will be // rolled back and the that error will be return to the caller of Update(). // // Executing a manual commit or rollback from inside the function will result // in a panic. func (db *DB) Update(fn func(tx *Tx) error) error { return db.managed(true, fn) } // get return an item or nil if not found. func (db *DB) get(key string) *dbItem { item := db.keys.Get(&dbItem{key: key}) if item != nil { return item.(*dbItem) } return nil } // Tx represents a transaction on the database. This transaction can either be // read-only or read/write. Read-only transactions can be used for retrieving // values for keys and iterating through keys and values. Read/write // transactions can set and delete keys. // // All transactions must be committed or rolled-back when done. type Tx struct { db *DB // the underlying database. writable bool // when false mutable operations fail. funcd bool // when true Commit and Rollback panic. wc *txWriteContext // context for writable transactions. } type txWriteContext struct { // rollback when deleteAll is called rbkeys *btree.BTree // a tree of all item ordered by key rbexps *btree.BTree // a tree of items ordered by expiration rbidxs map[string]*index // the index trees. rollbackItems map[string]*dbItem // details for rolling back tx. commitItems map[string]*dbItem // details for committing tx. 
itercount int // stack of iterators rollbackIndexes map[string]*index // details for dropped indexes. } // DeleteAll deletes all items from the database. func (tx *Tx) DeleteAll() error { if tx.db == nil { return ErrTxClosed } else if !tx.writable { return ErrTxNotWritable } else if tx.wc.itercount > 0 { return ErrTxIterating } // check to see if we've already deleted everything if tx.wc.rbkeys == nil { // we need to backup the live data in case of a rollback. tx.wc.rbkeys = tx.db.keys tx.wc.rbexps = tx.db.exps tx.wc.rbidxs = tx.db.idxs } // now reset the live database trees tx.db.keys = btreeNew(lessCtx(nil)) tx.db.exps = btreeNew(lessCtx(&exctx{tx.db})) tx.db.idxs = make(map[string]*index) // finally re-create the indexes for name, idx := range tx.wc.rbidxs { tx.db.idxs[name] = idx.clearCopy() } // always clear out the commits tx.wc.commitItems = make(map[string]*dbItem) return nil } // Begin opens a new transaction. // Multiple read-only transactions can be opened at the same time but there can // only be one read/write transaction at a time. Attempting to open a read/write // transactions while another one is in progress will result in blocking until // the current read/write transaction is completed. // // All transactions must be closed by calling Commit() or Rollback() when done. func (db *DB) Begin(writable bool) (*Tx, error) { tx := &Tx{ db: db, writable: writable, } tx.lock() if db.closed { tx.unlock() return nil, ErrDatabaseClosed } if writable { // writable transactions have a writeContext object that // contains information about changes to the database. tx.wc = &txWriteContext{} tx.wc.rollbackItems = make(map[string]*dbItem) tx.wc.rollbackIndexes = make(map[string]*index) if db.persist { tx.wc.commitItems = make(map[string]*dbItem) } } return tx, nil } // lock locks the database based on the transaction type. 
func (tx *Tx) lock() {
	mu := &tx.db.mu
	if !tx.writable {
		mu.RLock()
		return
	}
	mu.Lock()
}

// unlock unlocks the database based on the transaction type.
func (tx *Tx) unlock() {
	mu := &tx.db.mu
	if !tx.writable {
		mu.RUnlock()
		return
	}
	mu.Unlock()
}

// rollbackInner handles the underlying rollback logic.
// Intended to be called from Commit() and Rollback().
func (tx *Tx) rollbackInner() {
	db, wc := tx.db, tx.wc

	// Undo a DeleteAll by restoring the backed-up trees.
	if wc.rbkeys != nil {
		db.keys = wc.rbkeys
		db.idxs = wc.rbidxs
		db.exps = wc.rbexps
	}

	// Undo item changes: remove the current value and, when the key had a
	// previous value, put that value back.
	for key, old := range wc.rollbackItems {
		db.deleteFromDatabase(&dbItem{key: key})
		if old != nil {
			db.insertIntoDatabase(old)
		}
	}

	// Undo index changes.
	for name, old := range wc.rollbackIndexes {
		if old == nil {
			// the index did not exist before the transaction
			delete(db.idxs, name)
			continue
		}
		// Restore the previous index and rebuild it. This could be an
		// expensive process if the database has many items or the index
		// is complex.
		db.idxs[name] = old
		old.rebuild()
	}
}

// Commit writes all changes to disk.
// An error is returned when a write error occurs, or when a Commit() is called
// from a read-only transaction.
func (tx *Tx) Commit() error {
	if tx.funcd {
		panic("managed tx commit not allowed")
	}
	if tx.db == nil {
		return ErrTxClosed
	}
	if !tx.writable {
		return ErrTxNotWritable
	}
	db, wc := tx.db, tx.wc

	var err error
	deletedAll := wc.rbkeys != nil
	if db.persist && (len(wc.commitItems) > 0 || deletedAll) {
		buf := db.buf[:0]
		// write a flushdb if a deleteAll was called.
		if deletedAll {
			buf = append(buf, "*1\r\n$7\r\nflushdb\r\n"...)
		}
		now := time.Now()
		// Each committed record is appended to the buffer; a nil item
		// stands for a deletion of that key.
		for key, item := range wc.commitItems {
			if item != nil {
				buf = item.writeSetTo(buf, now)
				continue
			}
			buf = (&dbItem{key: key}).writeDeleteTo(buf)
		}
		db.buf = buf

		// Flushing the buffer only once per transaction.
		// If this operation fails then the write failed and we must
		// rollback.
		var written int
		written, err = db.file.Write(db.buf)
		if err != nil {
			if written > 0 {
				// There was a partial write to disk.
				// We are possibly out of disk space.
				// Delete the partially written bytes from the data file by
				// seeking to the previously known position and performing
				// a truncate operation.
				// At this point a syscall failure is fatal and the process
				// should be killed to avoid corrupting the file.
				pos, serr := db.file.Seek(-int64(written), 1)
				if serr != nil {
					panicErr(serr)
				}
				if terr := db.file.Truncate(pos); terr != nil {
					panicErr(terr)
				}
			}
			tx.rollbackInner()
		}
		if db.config.SyncPolicy == Always {
			// sync error is deliberately ignored
			_ = db.file.Sync()
		}
		// Increment the number of flushes. The background syncing uses this.
		db.flushes++
	}

	// Unlock the database and allow for another writable transaction.
	tx.unlock()
	// Clear the db field to disable this transaction from future use.
	tx.db = nil
	return err
}

// Rollback closes the transaction and reverts all mutable operations that
// were performed on the transaction such as Set() and Delete().
//
// Read-only transactions can only be rolled back, not committed.
func (tx *Tx) Rollback() error {
	switch {
	case tx.funcd:
		panic("managed tx rollback not allowed")
	case tx.db == nil:
		return ErrTxClosed
	case tx.writable:
		// Only writable transactions have changes to revert.
		tx.rollbackInner()
	}
	// unlock the database for more transactions.
	tx.unlock()
	// Clear the db field to disable this transaction from future use.
	tx.db = nil
	return nil
}

// dbItemOpts holds various meta information about an item.
type dbItemOpts struct {
	ex   bool      // does this item expire?
	exat time.Time // when does this item expire?
}

// dbItem is a single key/value entry together with its optional meta
// information.
type dbItem struct {
	key, val string      // the binary key and value
	opts     *dbItemOpts // optional meta information
	keyless  bool        // keyless item for scanning
}

// estIntSize returns the string representions size.
// Has the same result as len(strconv.Itoa(x)).
func estIntSize(x int) int { n := 1 if x < 0 { n++ x *= -1 } for x >= 10 { n++ x /= 10 } return n } func estArraySize(count int) int { return 1 + estIntSize(count) + 2 } func estBulkStringSize(s string) int { return 1 + estIntSize(len(s)) + 2 + len(s) + 2 } // estAOFSetSize returns an estimated number of bytes that this item will use // when stored in the aof file. func (dbi *dbItem) estAOFSetSize() int { var n int if dbi.opts != nil && dbi.opts.ex { n += estArraySize(5) n += estBulkStringSize("set") n += estBulkStringSize(dbi.key) n += estBulkStringSize(dbi.val) n += estBulkStringSize("ex") n += estBulkStringSize("99") // estimate two byte bulk string } else { n += estArraySize(3) n += estBulkStringSize("set") n += estBulkStringSize(dbi.key) n += estBulkStringSize(dbi.val) } return n } func appendArray(buf []byte, count int) []byte { buf = append(buf, '*') buf = strconv.AppendInt(buf, int64(count), 10) buf = append(buf, '\r', '\n') return buf } func appendBulkString(buf []byte, s string) []byte { buf = append(buf, '$') buf = strconv.AppendInt(buf, int64(len(s)), 10) buf = append(buf, '\r', '\n') buf = append(buf, s...) buf = append(buf, '\r', '\n') return buf } // writeSetTo writes an item as a single SET record to the a bufio Writer. 
func (dbi *dbItem) writeSetTo(buf []byte, now time.Time) []byte { if dbi.opts != nil && dbi.opts.ex { buf = appendArray(buf, 5) buf = appendBulkString(buf, "set") buf = appendBulkString(buf, dbi.key) buf = appendBulkString(buf, dbi.val) if useAbsEx { ex := dbi.opts.exat.Unix() buf = appendBulkString(buf, "ae") buf = appendBulkString(buf, strconv.FormatUint(uint64(ex), 10)) } else { ex := dbi.opts.exat.Sub(now) / time.Second buf = appendBulkString(buf, "ex") buf = appendBulkString(buf, strconv.FormatUint(uint64(ex), 10)) } } else { buf = appendArray(buf, 3) buf = appendBulkString(buf, "set") buf = appendBulkString(buf, dbi.key) buf = appendBulkString(buf, dbi.val) } return buf } // writeSetTo writes an item as a single DEL record to the a bufio Writer. func (dbi *dbItem) writeDeleteTo(buf []byte) []byte { buf = appendArray(buf, 2) buf = appendBulkString(buf, "del") buf = appendBulkString(buf, dbi.key) return buf } // expired evaluates id the item has expired. This will always return false when // the item does not have `opts.ex` set to true. func (dbi *dbItem) expired() bool { return dbi.opts != nil && dbi.opts.ex && time.Now().After(dbi.opts.exat) } // MaxTime from http://stackoverflow.com/questions/25065055#32620397 // This is a long time in the future. It's an imaginary number that is // used for b-tree ordering. var maxTime = time.Unix(1<<63-62135596801, 999999999) // expiresAt will return the time when the item will expire. When an item does // not expire `maxTime` is used. func (dbi *dbItem) expiresAt() time.Time { if dbi.opts == nil || !dbi.opts.ex { return maxTime } return dbi.opts.exat } // Less determines if a b-tree item is less than another. This is required // for ordering, inserting, and deleting items from a b-tree. It's important // to note that the ctx parameter is used to help with determine which // formula to use on an item. Each b-tree should use a different ctx when // sharing the same item. 
func (dbi *dbItem) Less(dbi2 *dbItem, ctx interface{}) bool { switch ctx := ctx.(type) { case *exctx: // The expires b-tree formula if dbi2.expiresAt().After(dbi.expiresAt()) { return true } if dbi.expiresAt().After(dbi2.expiresAt()) { return false } case *index: if ctx.less != nil { // Using an index if ctx.less(dbi.val, dbi2.val) { return true } if ctx.less(dbi2.val, dbi.val) { return false } } } // Always fall back to the key comparison. This creates absolute uniqueness. if dbi.keyless { return false } else if dbi2.keyless { return true } return dbi.key < dbi2.key } func lessCtx(ctx interface{}) func(a, b interface{}) bool { return func(a, b interface{}) bool { return a.(*dbItem).Less(b.(*dbItem), ctx) } } // Rect converts a string to a rectangle. // An invalid rectangle will cause a panic. func (dbi *dbItem) Rect(ctx interface{}) (min, max []float64) { switch ctx := ctx.(type) { case *index: return ctx.rect(dbi.val) } return nil, nil } // SetOptions represents options that may be included with the Set() command. type SetOptions struct { // Expires indicates that the Set() key-value will expire Expires bool // TTL is how much time the key-value will exist in the database // before being evicted. The Expires field must also be set to true. // TTL stands for Time-To-Live. TTL time.Duration } // GetLess returns the less function for an index. This is handy for // doing ad-hoc compares inside a transaction. // Returns ErrNotFound if the index is not found or there is no less // function bound to the index func (tx *Tx) GetLess(index string) (func(a, b string) bool, error) { if tx.db == nil { return nil, ErrTxClosed } idx, ok := tx.db.idxs[index] if !ok || idx.less == nil { return nil, ErrNotFound } return idx.less, nil } // GetRect returns the rect function for an index. This is handy for // doing ad-hoc searches inside a transaction. 
// Returns ErrNotFound if the index is not found or there is no rect // function bound to the index func (tx *Tx) GetRect(index string) (func(s string) (min, max []float64), error) { if tx.db == nil { return nil, ErrTxClosed } idx, ok := tx.db.idxs[index] if !ok || idx.rect == nil { return nil, ErrNotFound } return idx.rect, nil } // Set inserts or replaces an item in the database based on the key. // The opt params may be used for additional functionality such as forcing // the item to be evicted at a specified time. When the return value // for err is nil the operation succeeded. When the return value of // replaced is true, then the operaton replaced an existing item whose // value will be returned through the previousValue variable. // The results of this operation will not be available to other // transactions until the current transaction has successfully committed. // // Only a writable transaction can be used with this operation. // This operation is not allowed during iterations such as Ascend* & Descend*. func (tx *Tx) Set(key, value string, opts *SetOptions) (previousValue string, replaced bool, err error) { if tx.db == nil { return "", false, ErrTxClosed } else if !tx.writable { return "", false, ErrTxNotWritable } else if tx.wc.itercount > 0 { return "", false, ErrTxIterating } item := &dbItem{key: key, val: value} if opts != nil { if opts.Expires { // The caller is requesting that this item expires. Convert the // TTL to an absolute time and bind it to the item. item.opts = &dbItemOpts{ex: true, exat: time.Now().Add(opts.TTL)} } } // Insert the item into the keys tree. prev := tx.db.insertIntoDatabase(item) // insert into the rollback map if there has not been a deleteAll. if tx.wc.rbkeys == nil { if prev == nil { // An item with the same key did not previously exist. Let's // create a rollback entry with a nil value. A nil value indicates // that the entry should be deleted on rollback. 
When the value is // *not* nil, that means the entry should be reverted. if _, ok := tx.wc.rollbackItems[key]; !ok { tx.wc.rollbackItems[key] = nil } } else { // A previous item already exists in the database. Let's create a // rollback entry with the item as the value. We need to check the // map to see if there isn't already an item that matches the // same key. if _, ok := tx.wc.rollbackItems[key]; !ok { tx.wc.rollbackItems[key] = prev } if !prev.expired() { previousValue, replaced = prev.val, true } } } // For commits we simply assign the item to the map. We use this map to // write the entry to disk. if tx.db.persist { tx.wc.commitItems[key] = item } return previousValue, replaced, nil } // Get returns a value for a key. If the item does not exist or if the item // has expired then ErrNotFound is returned. If ignoreExpired is true, then // the found value will be returned even if it is expired. func (tx *Tx) Get(key string, ignoreExpired ...bool) (val string, err error) { if tx.db == nil { return "", ErrTxClosed } var ignore bool if len(ignoreExpired) != 0 { ignore = ignoreExpired[0] } item := tx.db.get(key) if item == nil || (item.expired() && !ignore) { // The item does not exists or has expired. Let's assume that // the caller is only interested in items that have not expired. return "", ErrNotFound } return item.val, nil } // GetWithTTL returns a value for a key together with it's TTL. // If the item does not exist or if the item has expired then ErrNotFound is returned. // A negative duration will be returned for items that do not have an expiration. 
func (tx *Tx) GetWithTTL(key string) (string, time.Duration, error) { if tx.db == nil { return "", 0, ErrTxClosed } item := tx.db.get(key) if item == nil || item.expired() { return "", 0, ErrNotFound } else if item.opts == nil || !item.opts.ex { return item.val, -1, nil } else if ttl := time.Until(item.opts.exat); ttl < 0 { return "", 0, ErrNotFound } else { return item.val, ttl, nil } } // Delete removes an item from the database based on the item's key. If the item // does not exist or if the item has expired then ErrNotFound is returned. // // Only a writable transaction can be used for this operation. // This operation is not allowed during iterations such as Ascend* & Descend*. func (tx *Tx) Delete(key string) (val string, err error) { if tx.db == nil { return "", ErrTxClosed } else if !tx.writable { return "", ErrTxNotWritable } else if tx.wc.itercount > 0 { return "", ErrTxIterating } item := tx.db.deleteFromDatabase(&dbItem{key: key}) if item == nil { return "", ErrNotFound } // create a rollback entry if there has not been a deleteAll call. if tx.wc.rbkeys == nil { if _, ok := tx.wc.rollbackItems[key]; !ok { tx.wc.rollbackItems[key] = item } } if tx.db.persist { tx.wc.commitItems[key] = nil } // Even though the item has been deleted, we still want to check // if it has expired. An expired item should not be returned. if item.expired() { // The item exists in the tree, but has expired. Let's assume that // the caller is only interested in items that have not expired. return "", ErrNotFound } return item.val, nil } // TTL returns the remaining time-to-live for an item. // A negative duration will be returned for items that do not have an // expiration. 
func (tx *Tx) TTL(key string) (time.Duration, error) { if tx.db == nil { return 0, ErrTxClosed } item := tx.db.get(key) if item == nil { return 0, ErrNotFound } else if item.opts == nil || !item.opts.ex { return -1, nil } dur := time.Until(item.opts.exat) if dur < 0 { return 0, ErrNotFound } return dur, nil } // scan iterates through a specified index and calls user-defined iterator // function for each item encountered. // The desc param indicates that the iterator should descend. // The gt param indicates that there is a greaterThan limit. // The lt param indicates that there is a lessThan limit. // The index param tells the scanner to use the specified index tree. An // empty string for the index means to scan the keys, not the values. // The start and stop params are the greaterThan, lessThan limits. For // descending order, these will be lessThan, greaterThan. // An error will be returned if the tx is closed or the index is not found. func (tx *Tx) scan(desc, gt, lt bool, index, start, stop string, iterator func(key, value string) bool) error { if tx.db == nil { return ErrTxClosed } // wrap a btree specific iterator around the user-defined iterator. iter := func(item interface{}) bool { dbi := item.(*dbItem) return iterator(dbi.key, dbi.val) } var tr *btree.BTree if index == "" { // empty index means we will use the keys tree. tr = tx.db.keys } else { idx := tx.db.idxs[index] if idx == nil { // index was not found. return error return ErrNotFound } tr = idx.btr if tr == nil { return nil } } // create some limit items var itemA, itemB *dbItem if gt || lt { if index == "" { itemA = &dbItem{key: start} itemB = &dbItem{key: stop} } else { itemA = &dbItem{val: start} itemB = &dbItem{val: stop} if desc { itemA.keyless = true itemB.keyless = true } } } // execute the scan on the underlying tree. 
if tx.wc != nil { tx.wc.itercount++ defer func() { tx.wc.itercount-- }() } if desc { if gt { if lt { btreeDescendRange(tr, itemA, itemB, iter) } else { btreeDescendGreaterThan(tr, itemA, iter) } } else if lt { btreeDescendLessOrEqual(tr, itemA, iter) } else { btreeDescend(tr, iter) } } else { if gt { if lt { btreeAscendRange(tr, itemA, itemB, iter) } else { btreeAscendGreaterOrEqual(tr, itemA, iter) } } else if lt { btreeAscendLessThan(tr, itemA, iter) } else { btreeAscend(tr, iter) } } return nil } // Match returns true if the specified key matches the pattern. This is a very // simple pattern matcher where '*' matches on any number characters and '?' // matches on any one character. func Match(key, pattern string) bool { return match.Match(key, pattern) } // AscendKeys allows for iterating through keys based on the specified pattern. func (tx *Tx) AscendKeys(pattern string, iterator func(key, value string) bool) error { if pattern == "" { return nil } if pattern[0] == '*' { if pattern == "*" { return tx.Ascend("", iterator) } return tx.Ascend("", func(key, value string) bool { if match.Match(key, pattern) { if !iterator(key, value) { return false } } return true }) } min, max := match.Allowable(pattern) return tx.AscendGreaterOrEqual("", min, func(key, value string) bool { if key > max { return false } if match.Match(key, pattern) { if !iterator(key, value) { return false } } return true }) } // DescendKeys allows for iterating through keys based on the specified pattern. 
func (tx *Tx) DescendKeys(pattern string, iterator func(key, value string) bool) error { if pattern == "" { return nil } if pattern[0] == '*' { if pattern == "*" { return tx.Descend("", iterator) } return tx.Descend("", func(key, value string) bool { if match.Match(key, pattern) { if !iterator(key, value) { return false } } return true }) } min, max := match.Allowable(pattern) return tx.DescendLessOrEqual("", max, func(key, value string) bool { if key < min { return false } if match.Match(key, pattern) { if !iterator(key, value) { return false } } return true }) } // Ascend calls the iterator for every item in the database within the range // [first, last], until iterator returns false. // When an index is provided, the results will be ordered by the item values // as specified by the less() function of the defined index. // When an index is not provided, the results will be ordered by the item key. // An invalid index will return an error. func (tx *Tx) Ascend(index string, iterator func(key, value string) bool) error { return tx.scan(false, false, false, index, "", "", iterator) } // AscendGreaterOrEqual calls the iterator for every item in the database within // the range [pivot, last], until iterator returns false. // When an index is provided, the results will be ordered by the item values // as specified by the less() function of the defined index. // When an index is not provided, the results will be ordered by the item key. // An invalid index will return an error. func (tx *Tx) AscendGreaterOrEqual(index, pivot string, iterator func(key, value string) bool) error { return tx.scan(false, true, false, index, pivot, "", iterator) } // AscendLessThan calls the iterator for every item in the database within the // range [first, pivot), until iterator returns false. // When an index is provided, the results will be ordered by the item values // as specified by the less() function of the defined index. 
// When an index is not provided, the results will be ordered by the item key.
// An invalid index will return an error.
func (tx *Tx) AscendLessThan(index, pivot string,
	iterator func(key, value string) bool) error {
	// ascending scan with only the lessThan limit set; the pivot is passed
	// in the start position.
	return tx.scan(false, false, true, index, pivot, "", iterator)
}

// AscendRange calls the iterator for every item in the database within
// the range [greaterOrEqual, lessThan), until iterator returns false.
// When an index is provided, the results will be ordered by the item values
// as specified by the less() function of the defined index.
// When an index is not provided, the results will be ordered by the item key.
// An invalid index will return an error.
func (tx *Tx) AscendRange(index, greaterOrEqual, lessThan string,
	iterator func(key, value string) bool) error {
	// ascending scan with both limits set
	return tx.scan(
		false, true, true, index, greaterOrEqual, lessThan, iterator,
	)
}

// Descend calls the iterator for every item in the database within the range
// [last, first], until iterator returns false.
// When an index is provided, the results will be ordered by the item values
// as specified by the less() function of the defined index.
// When an index is not provided, the results will be ordered by the item key.
// An invalid index will return an error.
func (tx *Tx) Descend(index string,
	iterator func(key, value string) bool) error {
	// descending scan without limits
	return tx.scan(true, false, false, index, "", "", iterator)
}

// DescendGreaterThan calls the iterator for every item in the database within
// the range [last, pivot), until iterator returns false.
// When an index is provided, the results will be ordered by the item values
// as specified by the less() function of the defined index.
// When an index is not provided, the results will be ordered by the item key.
// An invalid index will return an error.
func (tx *Tx) DescendGreaterThan(index, pivot string,
	iterator func(key, value string) bool) error {
	// descending scan with only the greaterThan limit set
	return tx.scan(true, true, false, index, pivot, "", iterator)
}

// DescendLessOrEqual calls the iterator for every item in the database within
// the range [pivot, first], until iterator returns false.
// When an index is provided, the results will be ordered by the item values
// as specified by the less() function of the defined index.
// When an index is not provided, the results will be ordered by the item key.
// An invalid index will return an error.
func (tx *Tx) DescendLessOrEqual(index, pivot string,
	iterator func(key, value string) bool) error {
	// descending scan with only the lessThan flag set; the pivot is passed
	// in the start position.
	return tx.scan(true, false, true, index, pivot, "", iterator)
}

// DescendRange calls the iterator for every item in the database within
// the range [lessOrEqual, greaterThan), until iterator returns false.
// When an index is provided, the results will be ordered by the item values
// as specified by the less() function of the defined index.
// When an index is not provided, the results will be ordered by the item key.
// An invalid index will return an error.
func (tx *Tx) DescendRange(index, lessOrEqual, greaterThan string,
	iterator func(key, value string) bool) error {
	// descending scan with both limits set
	return tx.scan(
		true, true, true, index, lessOrEqual, greaterThan, iterator,
	)
}

// AscendEqual calls the iterator for every item in the database that equals
// pivot, until iterator returns false.
// When an index is provided, the results will be ordered by the item values
// as specified by the less() function of the defined index.
// When an index is not provided, the results will be ordered by the item key.
// An invalid index will return an error.
func (tx *Tx) AscendEqual(index, pivot string, iterator func(key, value string) bool) error { var err error var less func(a, b string) bool if index != "" { less, err = tx.GetLess(index) if err != nil { return err } } return tx.AscendGreaterOrEqual(index, pivot, func(key, value string) bool { if less == nil { if key != pivot { return false } } else if less(pivot, value) { return false } return iterator(key, value) }) } // DescendEqual calls the iterator for every item in the database that equals // pivot, until iterator returns false. // When an index is provided, the results will be ordered by the item values // as specified by the less() function of the defined index. // When an index is not provided, the results will be ordered by the item key. // An invalid index will return an error. func (tx *Tx) DescendEqual(index, pivot string, iterator func(key, value string) bool) error { var err error var less func(a, b string) bool if index != "" { less, err = tx.GetLess(index) if err != nil { return err } } return tx.DescendLessOrEqual(index, pivot, func(key, value string) bool { if less == nil { if key != pivot { return false } } else if less(value, pivot) { return false } return iterator(key, value) }) } // rect is used by Intersects and Nearby type rect struct { min, max []float64 } func (r *rect) Rect(ctx interface{}) (min, max []float64) { return r.min, r.max } // Nearby searches for rectangle items that are nearby a target rect. // All items belonging to the specified index will be returned in order of // nearest to farthest. // The specified index must have been created by AddIndex() and the target // is represented by the rect string. This string will be processed by the // same bounds function that was passed to the CreateSpatialIndex() function. // An invalid index will return an error. // The dist param is the distance of the bounding boxes. In the case of // simple 2D points, it's the distance of the two 2D points squared. 
func (tx *Tx) Nearby(index, bounds string, iterator func(key, value string, dist float64) bool) error { if tx.db == nil { return ErrTxClosed } if index == "" { // cannot search on keys tree. just return nil. return nil } // // wrap a rtree specific iterator around the user-defined iterator. iter := func(item rtred.Item, dist float64) bool { dbi := item.(*dbItem) return iterator(dbi.key, dbi.val, dist) } idx := tx.db.idxs[index] if idx == nil { // index was not found. return error return ErrNotFound } if idx.rtr == nil { // not an r-tree index. just return nil return nil } // execute the nearby search var min, max []float64 if idx.rect != nil { min, max = idx.rect(bounds) } // set the center param to false, which uses the box dist calc. idx.rtr.KNN(&rect{min, max}, false, iter) return nil } // Intersects searches for rectangle items that intersect a target rect. // The specified index must have been created by AddIndex() and the target // is represented by the rect string. This string will be processed by the // same bounds function that was passed to the CreateSpatialIndex() function. // An invalid index will return an error. func (tx *Tx) Intersects(index, bounds string, iterator func(key, value string) bool) error { if tx.db == nil { return ErrTxClosed } if index == "" { // cannot search on keys tree. just return nil. return nil } // wrap a rtree specific iterator around the user-defined iterator. iter := func(item rtred.Item) bool { dbi := item.(*dbItem) return iterator(dbi.key, dbi.val) } idx := tx.db.idxs[index] if idx == nil { // index was not found. return error return ErrNotFound } if idx.rtr == nil { // not an r-tree index. 
just return nil return nil } // execute the search var min, max []float64 if idx.rect != nil { min, max = idx.rect(bounds) } idx.rtr.Search(&rect{min, max}, iter) return nil } // Len returns the number of items in the database func (tx *Tx) Len() (int, error) { if tx.db == nil { return 0, ErrTxClosed } return tx.db.keys.Len(), nil } // IndexOptions provides an index with additional features or // alternate functionality. type IndexOptions struct { // CaseInsensitiveKeyMatching allow for case-insensitive // matching on keys when setting key/values. CaseInsensitiveKeyMatching bool } // CreateIndex builds a new index and populates it with items. // The items are ordered in an b-tree and can be retrieved using the // Ascend* and Descend* methods. // An error will occur if an index with the same name already exists. // // When a pattern is provided, the index will be populated with // keys that match the specified pattern. This is a very simple pattern // match where '*' matches on any number characters and '?' matches on // any one character. // The less function compares if string 'a' is less than string 'b'. // It allows for indexes to create custom ordering. It's possible // that the strings may be textual or binary. It's up to the provided // less function to handle the content format and comparison. // There are some default less function that can be used such as // IndexString, IndexBinary, etc. func (tx *Tx) CreateIndex(name, pattern string, less ...func(a, b string) bool) error { return tx.createIndex(name, pattern, less, nil, nil) } // CreateIndexOptions is the same as CreateIndex except that it allows // for additional options. func (tx *Tx) CreateIndexOptions(name, pattern string, opts *IndexOptions, less ...func(a, b string) bool) error { return tx.createIndex(name, pattern, less, nil, opts) } // CreateSpatialIndex builds a new index and populates it with items. // The items are organized in an r-tree and can be retrieved using the // Intersects method. 
// An error will occur if an index with the same name already exists. // // The rect function converts a string to a rectangle. The rectangle is // represented by two arrays, min and max. Both arrays may have a length // between 1 and 20, and both arrays must match in length. A length of 1 is a // one dimensional rectangle, and a length of 4 is a four dimension rectangle. // There is support for up to 20 dimensions. // The values of min must be less than the values of max at the same dimension. // Thus min[0] must be less-than-or-equal-to max[0]. // The IndexRect is a default function that can be used for the rect // parameter. func (tx *Tx) CreateSpatialIndex(name, pattern string, rect func(item string) (min, max []float64)) error { return tx.createIndex(name, pattern, nil, rect, nil) } // CreateSpatialIndexOptions is the same as CreateSpatialIndex except that // it allows for additional options. func (tx *Tx) CreateSpatialIndexOptions(name, pattern string, opts *IndexOptions, rect func(item string) (min, max []float64)) error { return tx.createIndex(name, pattern, nil, rect, nil) } // createIndex is called by CreateIndex() and CreateSpatialIndex() func (tx *Tx) createIndex(name string, pattern string, lessers []func(a, b string) bool, rect func(item string) (min, max []float64), opts *IndexOptions, ) error { if tx.db == nil { return ErrTxClosed } else if !tx.writable { return ErrTxNotWritable } else if tx.wc.itercount > 0 { return ErrTxIterating } if name == "" { // cannot create an index without a name. // an empty name index is designated for the main "keys" tree. return ErrIndexExists } // check if an index with that name already exists. if _, ok := tx.db.idxs[name]; ok { // index with name already exists. error. return ErrIndexExists } // genreate a less function var less func(a, b string) bool switch len(lessers) { default: // multiple less functions specified. // create a compound less function. 
less = func(a, b string) bool { for i := 0; i < len(lessers)-1; i++ { if lessers[i](a, b) { return true } if lessers[i](b, a) { return false } } return lessers[len(lessers)-1](a, b) } case 0: // no less function case 1: less = lessers[0] } var sopts IndexOptions if opts != nil { sopts = *opts } if sopts.CaseInsensitiveKeyMatching { pattern = strings.ToLower(pattern) } // intialize new index idx := &index{ name: name, pattern: pattern, less: less, rect: rect, db: tx.db, opts: sopts, } idx.rebuild() // save the index tx.db.idxs[name] = idx if tx.wc.rbkeys == nil { // store the index in the rollback map. if _, ok := tx.wc.rollbackIndexes[name]; !ok { // we use nil to indicate that the index should be removed upon // rollback. tx.wc.rollbackIndexes[name] = nil } } return nil } // DropIndex removes an index. func (tx *Tx) DropIndex(name string) error { if tx.db == nil { return ErrTxClosed } else if !tx.writable { return ErrTxNotWritable } else if tx.wc.itercount > 0 { return ErrTxIterating } if name == "" { // cannot drop the default "keys" index return ErrInvalidOperation } idx, ok := tx.db.idxs[name] if !ok { return ErrNotFound } // delete from the map. // this is all that is needed to delete an index. delete(tx.db.idxs, name) if tx.wc.rbkeys == nil { // store the index in the rollback map. if _, ok := tx.wc.rollbackIndexes[name]; !ok { // we use a non-nil copy of the index without the data to indicate // that the index should be rebuilt upon rollback. tx.wc.rollbackIndexes[name] = idx.clearCopy() } } return nil } // Indexes returns a list of index names. func (tx *Tx) Indexes() ([]string, error) { if tx.db == nil { return nil, ErrTxClosed } names := make([]string, 0, len(tx.db.idxs)) for name := range tx.db.idxs { names = append(names, name) } sort.Strings(names) return names, nil } // Rect is helper function that returns a string representation // of a rect. IndexRect() is the reverse function and can be used // to generate a rect from a string. 
func Rect(min, max []float64) string { r := grect.Rect{Min: min, Max: max} return r.String() } // Point is a helper function that converts a series of float64s // to a rectangle for a spatial index. func Point(coords ...float64) string { return Rect(coords, coords) } // IndexRect is a helper function that converts string to a rect. // Rect() is the reverse function and can be used to generate a string // from a rect. func IndexRect(a string) (min, max []float64) { r := grect.Get(a) return r.Min, r.Max } // IndexString is a helper function that return true if 'a' is less than 'b'. // This is a case-insensitive comparison. Use the IndexBinary() for comparing // case-sensitive strings. func IndexString(a, b string) bool { for i := 0; i < len(a) && i < len(b); i++ { if a[i] >= 'A' && a[i] <= 'Z' { if b[i] >= 'A' && b[i] <= 'Z' { // both are uppercase, do nothing if a[i] < b[i] { return true } else if a[i] > b[i] { return false } } else { // a is uppercase, convert a to lowercase if a[i]+32 < b[i] { return true } else if a[i]+32 > b[i] { return false } } } else if b[i] >= 'A' && b[i] <= 'Z' { // b is uppercase, convert b to lowercase if a[i] < b[i]+32 { return true } else if a[i] > b[i]+32 { return false } } else { // neither are uppercase if a[i] < b[i] { return true } else if a[i] > b[i] { return false } } } return len(a) < len(b) } // IndexBinary is a helper function that returns true if 'a' is less than 'b'. // This compares the raw binary of the string. func IndexBinary(a, b string) bool { return a < b } // IndexInt is a helper function that returns true if 'a' is less than 'b'. func IndexInt(a, b string) bool { ia, _ := strconv.ParseInt(a, 10, 64) ib, _ := strconv.ParseInt(b, 10, 64) return ia < ib } // IndexUint is a helper function that returns true if 'a' is less than 'b'. // This compares uint64s that are added to the database using the // Uint() conversion function. 
func IndexUint(a, b string) bool { ia, _ := strconv.ParseUint(a, 10, 64) ib, _ := strconv.ParseUint(b, 10, 64) return ia < ib } // IndexFloat is a helper function that returns true if 'a' is less than 'b'. // This compares float64s that are added to the database using the // Float() conversion function. func IndexFloat(a, b string) bool { ia, _ := strconv.ParseFloat(a, 64) ib, _ := strconv.ParseFloat(b, 64) return ia < ib } // IndexJSON provides for the ability to create an index on any JSON field. // When the field is a string, the comparison will be case-insensitive. // It returns a helper function used by CreateIndex. func IndexJSON(path string) func(a, b string) bool { return func(a, b string) bool { return gjson.Get(a, path).Less(gjson.Get(b, path), false) } } // IndexJSONCaseSensitive provides for the ability to create an index on // any JSON field. // When the field is a string, the comparison will be case-sensitive. // It returns a helper function used by CreateIndex. func IndexJSONCaseSensitive(path string) func(a, b string) bool { return func(a, b string) bool { return gjson.Get(a, path).Less(gjson.Get(b, path), true) } } // Desc is a helper function that changes the order of an index. 
func Desc(less func(a, b string) bool) func(a, b string) bool { return func(a, b string) bool { return less(b, a) } } //// Wrappers around btree Ascend/Descend func bLT(tr *btree.BTree, a, b interface{}) bool { return tr.Less(a, b) } func bGT(tr *btree.BTree, a, b interface{}) bool { return tr.Less(b, a) } // func bLTE(tr *btree.BTree, a, b interface{}) bool { return !tr.Less(b, a) } // func bGTE(tr *btree.BTree, a, b interface{}) bool { return !tr.Less(a, b) } // Ascend func btreeAscend(tr *btree.BTree, iter func(item interface{}) bool) { tr.Ascend(nil, iter) } func btreeAscendLessThan(tr *btree.BTree, pivot interface{}, iter func(item interface{}) bool, ) { tr.Ascend(nil, func(item interface{}) bool { return bLT(tr, item, pivot) && iter(item) }) } func btreeAscendGreaterOrEqual(tr *btree.BTree, pivot interface{}, iter func(item interface{}) bool, ) { tr.Ascend(pivot, iter) } func btreeAscendRange(tr *btree.BTree, greaterOrEqual, lessThan interface{}, iter func(item interface{}) bool, ) { tr.Ascend(greaterOrEqual, func(item interface{}) bool { return bLT(tr, item, lessThan) && iter(item) }) } // Descend func btreeDescend(tr *btree.BTree, iter func(item interface{}) bool) { tr.Descend(nil, iter) } func btreeDescendGreaterThan(tr *btree.BTree, pivot interface{}, iter func(item interface{}) bool, ) { tr.Descend(nil, func(item interface{}) bool { return bGT(tr, item, pivot) && iter(item) }) } func btreeDescendRange(tr *btree.BTree, lessOrEqual, greaterThan interface{}, iter func(item interface{}) bool, ) { tr.Descend(lessOrEqual, func(item interface{}) bool { return bGT(tr, item, greaterThan) && iter(item) }) } func btreeDescendLessOrEqual(tr *btree.BTree, pivot interface{}, iter func(item interface{}) bool, ) { tr.Descend(pivot, iter) } func btreeNew(less func(a, b interface{}) bool) *btree.BTree { // Using NewNonConcurrent because we're managing our own locks. return btree.NewNonConcurrent(less) }