From da25e2b37f39a65dcf3da04cfac2d05f07776242 Mon Sep 17 00:00:00 2001 From: siddontang Date: Wed, 7 May 2014 13:40:48 +0800 Subject: [PATCH] iterator use range to support close/open interval --- leveldb/db.go | 22 ++-- leveldb/iterator.go | 85 +++++++++--- leveldb/leveldb_test.go | 286 ++++++++++++---------------------------- leveldb/snapshot.go | 8 +- 4 files changed, 164 insertions(+), 237 deletions(-) diff --git a/leveldb/db.go b/leveldb/db.go index 0fa0462..a7a3221 100644 --- a/leveldb/db.go +++ b/leveldb/db.go @@ -115,6 +115,13 @@ func (db *DB) Destroy() { levigo.DestroyDatabase(db.cfg.Path, opts) } +func (db *DB) Clear() { + it := db.Iterator(NewRange(nil, nil), 0) + for ; it.Valid(); it.Next() { + db.Delete(it.Key()) + } +} + func (db *DB) Put(key, value []byte) error { return db.db.Put(db.writeOpts, key, value) } @@ -134,21 +141,14 @@ func (db *DB) NewWriteBatch() *WriteBatch { return wb } -//[begin, end] close(inclusive) interval -//if begin is nil, we will seek to first -//if end is nil, we will seek to last //limit <= 0, no limit -func (db *DB) Iterator(begin []byte, end []byte, limit int) *Iterator { - return newIterator(db, db.iteratorOpts, begin, end, limit, forward) +func (db *DB) Iterator(r *Range, limit int) *Iterator { + return newIterator(db, db.iteratorOpts, r, limit, forward) } -//[rbegin, rend] close(inclusive) interval -//rbegin should bigger than rend -//if rbegin is nil, we will seek to last -//if end is nil, we will seek to first //limit <= 0, no limit -func (db *DB) ReverseIterator(rbegin []byte, rend []byte, limit int) *Iterator { - return newIterator(db, db.iteratorOpts, rbegin, rend, limit, backward) +func (db *DB) ReverseIterator(r *Range, limit int) *Iterator { + return newIterator(db, db.iteratorOpts, r, limit, backward) } func (db *DB) NewSnapshot() *Snapshot { diff --git a/leveldb/iterator.go b/leveldb/iterator.go index 7cbc213..588f09b 100644 --- a/leveldb/iterator.go +++ b/leveldb/iterator.go @@ -8,11 +8,39 @@ import ( const forward uint8 = 0 const backward uint8 = 1 +//min must less or equal than max +//MinEx if true, range is left open interval (min, ... +//MaxEx if true, range is right open interval ..., max) +//Default range is close interval +type Range struct { + Min []byte + Max []byte + + MinEx bool + MaxEx bool +} + +func NewRange(min []byte, max []byte) *Range { + return &Range{min, max, false, false} +} + +func NewOpenRange(min []byte, max []byte) *Range { + return &Range{min, max, true, true} +} + +func NewLOpenRange(min []byte, max []byte) *Range { + return &Range{min, max, true, false} +} + +func NewROpenRange(min []byte, max []byte) *Range { + return &Range{min, max, true, true} +} + type Iterator struct { it *levigo.Iterator - start []byte - stop []byte + r *Range + limit int step int @@ -21,31 +49,42 @@ type Iterator struct { direction uint8 } -func newIterator(db *DB, opts *levigo.ReadOptions, start []byte, stop []byte, limit int, direction uint8) *Iterator { +func newIterator(db *DB, opts *levigo.ReadOptions, r *Range, limit int, direction uint8) *Iterator { it := new(Iterator) it.it = db.db.NewIterator(opts) - it.start = start - it.stop = stop + it.r = r it.limit = limit it.direction = direction it.step = 0 - if start == nil { - if direction == forward { + if direction == forward { + if r.Min == nil { it.it.SeekToFirst() } else { - it.it.SeekToLast() + it.it.Seek(r.Min) + + if r.MinEx { + if it.Valid() && bytes.Equal(it.Key(), r.Min) { + it.it.Next() + } + } } } else { - it.it.Seek(start) - - if it.Valid() && !bytes.Equal(it.Key(), start) { - //for forward, key is the next bigger than start - //for backward, key is the next bigger than start, so must go prev - if direction == backward { + if r.Max == nil { + it.it.SeekToLast() + } else { + it.it.Seek(r.Max) + if it.Valid() && !bytes.Equal(it.Key(), r.Max) { + //key must bigger than max, so we must go prev it.it.Prev() } + + if r.MaxEx { + if it.Valid() && bytes.Equal(it.Key(), r.Max) { + it.it.Prev() + } + } } } @@ -62,12 +101,22 @@ func (it *Iterator) Valid() bool { } if it.direction == forward { - if it.stop != nil && bytes.Compare(it.Key(), it.stop) > 0 { - return false + if it.r.Max != nil { + r := bytes.Compare(it.Key(), it.r.Max) + if !it.r.MaxEx { + return !(r > 0) + } else { + return !(r >= 0) + } } } else { - if it.stop != nil && bytes.Compare(it.Key(), it.stop) < 0 { - return false + if it.r.Min != nil { + r := bytes.Compare(it.Key(), it.r.Min) + if !it.r.MinEx { + return !(r < 0) + } else { + return !(r <= 0) + } } } diff --git a/leveldb/leveldb_test.go b/leveldb/leveldb_test.go index 6bee93c..62f53b9 100644 --- a/leveldb/leveldb_test.go +++ b/leveldb/leveldb_test.go @@ -107,210 +107,118 @@ func TestBatch(t *testing.T) { db.Delete(key1) } -func TestIterator(t *testing.T) { - db := getTestDB() - for it := db.Iterator(nil, nil, 0); it.Valid(); it.Next() { - db.Delete(it.Key()) - } - - for i := 0; i < 10; i++ { - key := []byte(fmt.Sprintf("key_%d", i)) - value := []byte(fmt.Sprintf("value_%d", i)) - db.Put(key, value) - } - - step := 0 - var it *Iterator - for it = db.Iterator(nil, nil, 0); it.Valid(); it.Next() { - key := it.Key() - value := it.Value() - - if string(key) != fmt.Sprintf("key_%d", step) { - t.Fatal(string(key), step) - } - - if string(value) != fmt.Sprintf("value_%d", step) { - t.Fatal(string(value), step) - } - - step++ +func checkIterator(it *Iterator, cv ...int) error { + v := make([]string, 0, len(cv)) + for ; it.Valid(); it.Next() { + k := it.Key() + v = append(v, string(k)) } it.Close() - step = 2 - for it = db.Iterator([]byte("key_2"), nil, 3); it.Valid(); it.Next() { - key := it.Key() - value := it.Value() + if len(v) != len(cv) { + return fmt.Errorf("len error %d != %d", len(v), len(cv)) + } - if string(key) != fmt.Sprintf("key_%d", step) { - t.Fatal(string(key), step) + for k, i := range cv { + if fmt.Sprintf("key_%d", i) != v[k] { + return fmt.Errorf("%s, %d", v[k], i) } - - if string(value) != fmt.Sprintf("value_%d", step) { - t.Fatal(string(value), step) - } - - step++ - } - it.Close() - - if step != 5 { - t.Fatal("invalid step", step) } - step = 2 - for it = db.Iterator([]byte("key_2"), []byte("key_5"), 0); it.Valid(); it.Next() { - key := it.Key() - value := it.Value() + return nil +} - if string(key) != fmt.Sprintf("key_%d", step) { - t.Fatal(string(key), step) - } - - if string(value) != fmt.Sprintf("value_%d", step) { - t.Fatal(string(value), step) - } - - step++ - } - it.Close() - - if step != 6 { - t.Fatal("invalid step", step) - } - - step = 2 - for it = db.Iterator([]byte("key_5"), []byte("key_2"), 0); it.Valid(); it.Next() { - step++ - } - it.Close() - - if step != 2 { - t.Fatal("must 0") - } - - step = 9 - for it = db.ReverseIterator(nil, nil, 0); it.Valid(); it.Next() { - key := it.Key() - value := it.Value() - - if string(key) != fmt.Sprintf("key_%d", step) { - t.Fatal(string(key), step) - } - - if string(value) != fmt.Sprintf("value_%d", step) { - t.Fatal(string(value), step) - } - - step-- - } - it.Close() - - step = 5 - for it = db.ReverseIterator([]byte("key_5"), nil, 3); it.Valid(); it.Next() { - key := it.Key() - value := it.Value() - - if string(key) != fmt.Sprintf("key_%d", step) { - t.Fatal(string(key), step) - } - - if string(value) != fmt.Sprintf("value_%d", step) { - t.Fatal(string(value), step) - } - - step-- - } - it.Close() - - if step != 2 { - t.Fatal("invalid step", step) - } - - step = 5 - for it = db.ReverseIterator([]byte("key_5"), []byte("key_2"), 0); it.Valid(); it.Next() { - key := it.Key() - value := it.Value() - - if string(key) != fmt.Sprintf("key_%d", step) { - t.Fatal(string(key), step) - } - - if string(value) != fmt.Sprintf("value_%d", step) { - t.Fatal(string(value), step) - } - - step-- - } - it.Close() - - if step != 1 { - t.Fatal("invalid step", step) - } - - step = 5 - for it = db.ReverseIterator([]byte("key_2"), []byte("key_5"), 0); it.Valid(); it.Next() { - step-- - } - it.Close() - - if step != 5 { - t.Fatal("must 5") +func testKeyRange(min int, max int) *Range { + return &Range{ + []byte(fmt.Sprintf("key_%d", min)), + []byte(fmt.Sprintf("key_%d", max)), + false, + false, } } -func TestIterator_2(t *testing.T) { +func testLKeyRange(min int, max int) *Range { + r := testKeyRange(min, max) + r.MinEx = true + return r +} + +func testRKeyRange(min int, max int) *Range { + r := testKeyRange(min, max) + r.MaxEx = true + return r +} + +func testOpenKeyRange(min int, max int) *Range { + r := testKeyRange(min, max) + r.MinEx = true + r.MaxEx = true + return r +} + +func TestIterator(t *testing.T) { db := getTestDB() - for it := db.Iterator(nil, nil, 0); it.Valid(); it.Next() { - db.Delete(it.Key()) + + db.Clear() + + for i := 0; i < 10; i++ { + key := []byte(fmt.Sprintf("key_%d", i)) + value := []byte("") + db.Put(key, value) } - db.Put([]byte("key_1"), []byte("value_1")) - db.Put([]byte("key_7"), []byte("value_9")) - db.Put([]byte("key_9"), []byte("value_9")) + var it *Iterator - it := db.Iterator([]byte("key_0"), []byte("key_8"), 0) - if !it.Valid() { - t.Fatal("must valid") + it = db.Iterator(testKeyRange(1, 5), 0) + if err := checkIterator(it, 1, 2, 3, 4, 5); err != nil { + t.Fatal(err) } - if string(it.Key()) != "key_1" { - t.Fatal(string(it.Key())) + it = db.Iterator(testKeyRange(1, 9), 5) + if err := checkIterator(it, 1, 2, 3, 4, 5); err != nil { + t.Fatal(err) } - it.Close() - - it = db.ReverseIterator([]byte("key_8"), []byte("key_0"), 0) - if !it.Valid() { - t.Fatal("must valid") + it = db.Iterator(testLKeyRange(1, 5), 0) + if err := checkIterator(it, 2, 3, 4, 5); err != nil { + t.Fatal(err) } - if string(it.Key()) != "key_7" { - t.Fatal(string(it.Key())) + it = db.Iterator(testRKeyRange(1, 5), 0) + if err := checkIterator(it, 1, 2, 3, 4); err != nil { + t.Fatal(err) } - it.Close() - - for it = db.Iterator(nil, nil, 0); it.Valid(); it.Next() { - db.Delete(it.Key()) + it = db.Iterator(testOpenKeyRange(1, 5), 0) + if err := checkIterator(it, 2, 3, 4); err != nil { + t.Fatal(err) } - it.Close() - - it = db.Iterator([]byte("key_0"), []byte("key_8"), 0) - if it.Valid() { - t.Fatal("must not valid") + it = db.ReverseIterator(testKeyRange(1, 5), 0) + if err := checkIterator(it, 5, 4, 3, 2, 1); err != nil { + t.Fatal(err) } - it.Close() - - it = db.ReverseIterator([]byte("key_8"), []byte("key_0"), 0) - if it.Valid() { - t.Fatal("must not valid") + it = db.ReverseIterator(testKeyRange(1, 9), 5) + if err := checkIterator(it, 9, 8, 7, 6, 5); err != nil { + t.Fatal(err) } - it.Close() + it = db.ReverseIterator(testLKeyRange(1, 5), 0) + if err := checkIterator(it, 5, 4, 3, 2); err != nil { + t.Fatal(err) + } + + it = db.ReverseIterator(testRKeyRange(1, 5), 0) + if err := checkIterator(it, 4, 3, 2, 1); err != nil { + t.Fatal(err) + } + + it = db.ReverseIterator(testOpenKeyRange(1, 5), 0) + if err := checkIterator(it, 4, 3, 2); err != nil { + t.Fatal(err) + } } func TestSnapshot(t *testing.T) { @@ -331,36 +239,6 @@ func TestSnapshot(t *testing.T) { } else if string(v) != string(value) { t.Fatal(string(v)) } - - found := false - var it *Iterator - for it = s.Iterator(nil, nil, 0); it.Valid(); it.Next() { - if string(it.Key()) == string(key) { - found = true - break - } - } - - it.Close() - - if !found { - t.Fatal("must found") - } - - found = false - for it = s.ReverseIterator(nil, nil, 0); it.Valid(); it.Next() { - if string(it.Key()) == string(key) { - found = true - break - } - } - - it.Close() - - if !found { - t.Fatal("must found") - } - } func TestDestroy(t *testing.T) { diff --git a/leveldb/snapshot.go b/leveldb/snapshot.go index ab7b37a..72eb40f 100644 --- a/leveldb/snapshot.go +++ b/leveldb/snapshot.go @@ -38,12 +38,12 @@ func (s *Snapshot) Get(key []byte) ([]byte, error) { } //same as db iterator and reverse iterator -func (s *Snapshot) Iterator(begin []byte, end []byte, limit int) *Iterator { - return newIterator(s.db, s.iteratorOpts, begin, end, limit, forward) +func (s *Snapshot) Iterator(r *Range, limit int) *Iterator { + return newIterator(s.db, s.iteratorOpts, r, limit, forward) } -func (s *Snapshot) ReverseIterator(rbegin []byte, rend []byte, limit int) *Iterator { - return newIterator(s.db, s.iteratorOpts, rbegin, rend, limit, backward) +func (s *Snapshot) ReverseIterator(r *Range, limit int) *Iterator { + return newIterator(s.db, s.iteratorOpts, r, limit, backward) } func (s *Snapshot) GetInt(key []byte) (int64, error) {