diff --git a/buntdb.go b/buntdb.go index e99ba26..b2ec7e3 100644 --- a/buntdb.go +++ b/buntdb.go @@ -121,6 +121,12 @@ type Config struct { // OnExpired is used to custom handle the deletion option when a key // has been expired. OnExpired func(keys []string) + + // OnExpiredSync will be called inside the same transaction that is performing + // the deletion of expired items. If OnExpired is present then this callback + // will not be called. If this callback is present, then the deletion of the + // timeed-out item is the explicit responsibility of this callback. + OnExpiredSync func(key, value string, tx *Tx) error } // exctx is a simple b-tree context for ordering by expiration. @@ -544,9 +550,13 @@ func (db *DB) backgroundManager() { // Open a standard view. This will take a full lock of the // database thus allowing for access to anything we need. var onExpired func([]string) - var expired []string + var expired []*dbItem + var onExpiredSync func(key, value string, tx *Tx) error err := db.Update(func(tx *Tx) error { onExpired = db.config.OnExpired + if onExpired == nil { + onExpiredSync = db.config.OnExpiredSync + } if db.persist && !db.config.AutoShrinkDisabled { pos, err := db.file.Seek(0, 1) if err != nil { @@ -562,12 +572,12 @@ func (db *DB) backgroundManager() { db.exps.AscendLessThan(&dbItem{ opts: &dbItemOpts{ex: true, exat: time.Now()}, }, func(item btree.Item) bool { - expired = append(expired, item.(*dbItem).key) + expired = append(expired, item.(*dbItem)) return true }) - if onExpired == nil { - for _, key := range expired { - if _, err := tx.Delete(key); err != nil { + if onExpired == nil && onExpiredSync == nil { + for _, itm := range expired { + if _, err := tx.Delete(itm.key); err != nil { // it's ok to get a "not found" because the // 'Delete' method reports "not found" for // expired items. @@ -576,6 +586,12 @@ func (db *DB) backgroundManager() { } } } + } else if onExpiredSync != nil { + for _, itm := range expired { + if err := onExpiredSync(itm.key, itm.val, tx); err != nil { + return err + } + } } return nil }) @@ -585,7 +601,11 @@ func (db *DB) backgroundManager() { // send expired event, if needed if onExpired != nil && len(expired) > 0 { - onExpired(expired) + keys := make([]string, 0, 32) + for _, itm := range expired { + keys = append(keys, itm.key) + } + onExpired(keys) } // execute a disk sync, if needed @@ -1399,13 +1419,18 @@ func (tx *Tx) Set(key, value string, opts *SetOptions) (previousValue string, } // Get returns a value for a key. If the item does not exist or if the item -// has expired then ErrNotFound is returned. -func (tx *Tx) Get(key string) (val string, err error) { +// has expired then ErrNotFound is returned. If ignoreExpired is true, then +// the found value will be returned even if it is expired. +func (tx *Tx) Get(key string, ignoreExpired ...bool) (val string, err error) { if tx.db == nil { return "", ErrTxClosed } + var ignore bool + if len(ignoreExpired) != 0 { + ignore = ignoreExpired[0] + } item := tx.db.get(key) - if item == nil || item.expired() { + if item == nil || (item.expired() && !ignore) { // The item does not exists or has expired. Let's assume that // the caller is only interested in items that have not expired. return "", ErrNotFound diff --git a/buntdb_test.go b/buntdb_test.go index cef8ccb..01dd712 100644 --- a/buntdb_test.go +++ b/buntdb_test.go @@ -2551,3 +2551,98 @@ func TestJSONIndex(t *testing.T) { t.Fatalf("expected %v, got %v", expect, strings.Join(keys, ",")) } } + +func TestOnExpiredSync(t *testing.T) { + db := testOpen(t) + defer testClose(db) + + var config Config + if err := db.ReadConfig(&config); err != nil { + t.Fatal(err) + } + hits := make(chan int, 3) + config.OnExpiredSync = func(key, value string, tx *Tx) error { + n, err := strconv.Atoi(value) + if err != nil { + return err + } + defer func() { hits <- n }() + if n >= 2 { + _, err = tx.Delete(key) + if err != ErrNotFound { + return err + } + return nil + } + n++ + _, _, err = tx.Set(key, strconv.Itoa(n), &SetOptions{Expires: true, TTL: time.Millisecond * 100}) + return err + } + if err := db.SetConfig(config); err != nil { + t.Fatal(err) + } + err := db.Update(func(tx *Tx) error { + _, _, err := tx.Set("K", "0", &SetOptions{Expires: true, TTL: time.Millisecond * 100}) + return err + }) + if err != nil { + t.Fail() + } + + done := make(chan struct{}) + go func() { + ticks := time.NewTicker(time.Millisecond * 50) + defer ticks.Stop() + for { + select { + case <-done: + return + case <-ticks.C: + err := db.View(func(tx *Tx) error { + v, err := tx.Get("K", true) + if err != nil { + return err + } + n, err := strconv.Atoi(v) + if err != nil { + return err + } + if n < 0 || n > 2 { + t.Fail() + } + return nil + }) + if err != nil { + t.Fail() + } + } + } + }() + +OUTER1: + for { + select { + case <-time.After(time.Second * 2): + t.Fail() + case v := <-hits: + if v >= 2 { + break OUTER1 + } + } + } + err = db.View(func(tx *Tx) error { + defer close(done) + v, err := tx.Get("K") + if err != nil { + t.Fail() + return err + } + if v != "2" { + t.Fail() + } + return nil + }) + if err != nil { + t.Fail() + } +}