Return error when mutating during iteration. Fixes #12

This commit is contained in:
Josh Baker 2016-09-02 10:05:14 -07:00
parent 2d071d5cbe
commit d570a6fba9
2 changed files with 56 additions and 1 deletions

View File

@ -55,6 +55,9 @@ var (
// ErrPersistenceActive is returned when post-loading data from an database
// not opened with Open(":memory:").
ErrPersistenceActive = errors.New("persistence active")
// ErrTxIterating is returned when Set or Delete are called while iterating.
ErrTxIterating = errors.New("tx is iterating")
)
// DB represents a collection of key-value pairs that persist on disk.
@ -928,6 +931,7 @@ type Tx struct {
funcd bool // when true Commit and Rollback panic.
rollbacks map[string]*dbItem // cotnains details for rolling back tx.
commits map[string]*dbItem // contains details for committing tx.
itercount int // stack of iterators
}
// begin opens a new transaction.
@ -1183,12 +1187,17 @@ type SetOptions struct {
// value will be returned through the previousValue variable.
// The results of this operation will not be available to other
// transactions until the current transaction has successfully committed.
//
// Only a writable transaction can be used with this operation.
// This operation is not allowed during iterations such as Ascend* & Descend*.
func (tx *Tx) Set(key, value string, opts *SetOptions) (previousValue string,
replaced bool, err error) {
if tx.db == nil {
return "", false, ErrTxClosed
} else if !tx.writable {
return "", false, ErrTxNotWritable
} else if tx.itercount > 0 {
return "", false, ErrTxIterating
}
item := &dbItem{key: key, val: value}
if opts != nil {
@ -1243,12 +1252,15 @@ func (tx *Tx) Get(key string) (val string, err error) {
// Delete removes an item from the database based on the item's key. If the item
// does not exist or if the item has expired then ErrNotFound is returned.
//
// Only writable transaction can be used for Delete() calls.
// Only a writable transaction can be used for this operation.
// This operation is not allowed during iterations such as Ascend* & Descend*.
func (tx *Tx) Delete(key string) (val string, err error) {
if tx.db == nil {
return "", ErrTxClosed
} else if !tx.writable {
return "", ErrTxNotWritable
} else if tx.itercount > 0 {
return "", ErrTxIterating
}
item := tx.db.deleteFromDatabase(&dbItem{key: key})
if item == nil {
@ -1337,6 +1349,10 @@ func (tx *Tx) scan(desc, gt, lt bool, index, start, stop string,
}
}
// execute the scan on the underlying tree.
tx.itercount++
defer func() {
tx.itercount--
}()
if desc {
if gt {
if lt {

View File

@ -144,7 +144,46 @@ func TestSaveLoad(t *testing.T) {
t.Fatal(err)
}
}
func TestMutatingIterator(t *testing.T) {
db := testOpen(t)
defer testClose(db)
count := 1000
if err := db.CreateIndex("ages", "user:*:age", IndexInt); err != nil {
t.Fatal(err)
}
for i := 0; i < 10; i++ {
if err := db.Update(func(tx *Tx) error {
for j := 0; j < count; j++ {
key := fmt.Sprintf("user:%d:age", j)
val := fmt.Sprintf("%d", rand.Intn(100))
if _, _, err := tx.Set(key, val, nil); err != nil {
return err
}
}
return nil
}); err != nil {
t.Fatal(err)
}
if err := db.Update(func(tx *Tx) error {
return tx.Ascend("ages", func(key, val string) bool {
_, err := tx.Delete(key)
if err != ErrTxIterating {
t.Fatal("should not be able to call Delete while iterating.")
}
_, _, err = tx.Set(key, "", nil)
if err != ErrTxIterating {
t.Fatal("should not be able to call Set while iterating.")
}
return true
})
}); err != nil {
t.Fatal(err)
}
}
}
func TestVariousTx(t *testing.T) {
db := testOpen(t)
defer testClose(db)