Merge branch 'dc0d-onexpiredsync-01'

This commit is contained in:
Josh Baker 2018-03-15 17:57:49 -07:00
commit cd4f5a9b05
2 changed files with 129 additions and 9 deletions

View File

@ -121,6 +121,12 @@ type Config struct {
// OnExpired is used to custom handle the deletion option when a key
// has been expired.
OnExpired func(keys []string)
// OnExpiredSync will be called inside the same transaction that is performing
// the deletion of expired items. If OnExpired is present then this callback
// will not be called. If this callback is present, then the deletion of the
// timeed-out item is the explicit responsibility of this callback.
OnExpiredSync func(key, value string, tx *Tx) error
}
// exctx is a simple b-tree context for ordering by expiration.
@ -544,9 +550,13 @@ func (db *DB) backgroundManager() {
// Open a standard view. This will take a full lock of the
// database thus allowing for access to anything we need.
var onExpired func([]string)
var expired []string
var expired []*dbItem
var onExpiredSync func(key, value string, tx *Tx) error
err := db.Update(func(tx *Tx) error {
onExpired = db.config.OnExpired
if onExpired == nil {
onExpiredSync = db.config.OnExpiredSync
}
if db.persist && !db.config.AutoShrinkDisabled {
pos, err := db.file.Seek(0, 1)
if err != nil {
@ -562,12 +572,12 @@ func (db *DB) backgroundManager() {
db.exps.AscendLessThan(&dbItem{
opts: &dbItemOpts{ex: true, exat: time.Now()},
}, func(item btree.Item) bool {
expired = append(expired, item.(*dbItem).key)
expired = append(expired, item.(*dbItem))
return true
})
if onExpired == nil {
for _, key := range expired {
if _, err := tx.Delete(key); err != nil {
if onExpired == nil && onExpiredSync == nil {
for _, itm := range expired {
if _, err := tx.Delete(itm.key); err != nil {
// it's ok to get a "not found" because the
// 'Delete' method reports "not found" for
// expired items.
@ -576,6 +586,12 @@ func (db *DB) backgroundManager() {
}
}
}
} else if onExpiredSync != nil {
for _, itm := range expired {
if err := onExpiredSync(itm.key, itm.val, tx); err != nil {
return err
}
}
}
return nil
})
@ -585,7 +601,11 @@ func (db *DB) backgroundManager() {
// send expired event, if needed
if onExpired != nil && len(expired) > 0 {
onExpired(expired)
keys := make([]string, 0, 32)
for _, itm := range expired {
keys = append(keys, itm.key)
}
onExpired(keys)
}
// execute a disk sync, if needed
@ -1399,13 +1419,18 @@ func (tx *Tx) Set(key, value string, opts *SetOptions) (previousValue string,
}
// Get returns a value for a key. If the item does not exist or if the item
// has expired then ErrNotFound is returned.
func (tx *Tx) Get(key string) (val string, err error) {
// has expired then ErrNotFound is returned. If ignoreExpired is true, then
// the found value will be returned even if it is expired.
func (tx *Tx) Get(key string, ignoreExpired ...bool) (val string, err error) {
if tx.db == nil {
return "", ErrTxClosed
}
var ignore bool
if len(ignoreExpired) != 0 {
ignore = ignoreExpired[0]
}
item := tx.db.get(key)
if item == nil || item.expired() {
if item == nil || (item.expired() && !ignore) {
// The item does not exists or has expired. Let's assume that
// the caller is only interested in items that have not expired.
return "", ErrNotFound

View File

@ -2551,3 +2551,98 @@ func TestJSONIndex(t *testing.T) {
t.Fatalf("expected %v, got %v", expect, strings.Join(keys, ","))
}
}
func TestOnExpiredSync(t *testing.T) {
db := testOpen(t)
defer testClose(db)
var config Config
if err := db.ReadConfig(&config); err != nil {
t.Fatal(err)
}
hits := make(chan int, 3)
config.OnExpiredSync = func(key, value string, tx *Tx) error {
n, err := strconv.Atoi(value)
if err != nil {
return err
}
defer func() { hits <- n }()
if n >= 2 {
_, err = tx.Delete(key)
if err != ErrNotFound {
return err
}
return nil
}
n++
_, _, err = tx.Set(key, strconv.Itoa(n), &SetOptions{Expires: true, TTL: time.Millisecond * 100})
return err
}
if err := db.SetConfig(config); err != nil {
t.Fatal(err)
}
err := db.Update(func(tx *Tx) error {
_, _, err := tx.Set("K", "0", &SetOptions{Expires: true, TTL: time.Millisecond * 100})
return err
})
if err != nil {
t.Fail()
}
done := make(chan struct{})
go func() {
ticks := time.NewTicker(time.Millisecond * 50)
defer ticks.Stop()
for {
select {
case <-done:
return
case <-ticks.C:
err := db.View(func(tx *Tx) error {
v, err := tx.Get("K", true)
if err != nil {
return err
}
n, err := strconv.Atoi(v)
if err != nil {
return err
}
if n < 0 || n > 2 {
t.Fail()
}
return nil
})
if err != nil {
t.Fail()
}
}
}
}()
OUTER1:
for {
select {
case <-time.After(time.Second * 2):
t.Fail()
case v := <-hits:
if v >= 2 {
break OUTER1
}
}
}
err = db.View(func(tx *Tx) error {
defer close(done)
v, err := tx.Get("K")
if err != nil {
t.Fail()
return err
}
if v != "2" {
t.Fail()
}
return nil
})
if err != nil {
t.Fail()
}
}