diff --git a/store/driver/batch.go b/store/driver/batch.go index 33c07ed..5a47c1b 100644 --- a/store/driver/batch.go +++ b/store/driver/batch.go @@ -22,6 +22,9 @@ type WriteBatch struct { } func (wb *WriteBatch) Put(key, value []byte) { + if value == nil { + value = []byte{} + } wb.wb = append(wb.wb, Write{key, value}) } @@ -46,9 +49,9 @@ func (wb *WriteBatch) Data() []byte { wb.d.Reset() for _, w := range wb.wb { if w.Value == nil { - wb.Delete(w.Key) + wb.d.Delete(w.Key) } else { - wb.Put(w.Key, w.Value) + wb.d.Put(w.Key, w.Value) } } return wb.d.Dump() diff --git a/store/store_test.go b/store/store_test.go index 7626d5b..edb20bb 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -6,6 +6,7 @@ import ( "github.com/siddontang/ledisdb/config" "github.com/siddontang/ledisdb/store/driver" "os" + "reflect" "testing" ) @@ -38,6 +39,7 @@ func testStore(db *DB, t *testing.T) { testBatch(db, t) testIterator(db, t) testSnapshot(db, t) + testBatchData(db, t) } func testClear(db *DB, t *testing.T) { @@ -342,3 +344,52 @@ func testSnapshot(db *DB, t *testing.T) { } } + +func testBatchData(db *DB, t *testing.T) { + w := db.NewWriteBatch() + + w.Put([]byte("a"), []byte("1")) + w.Put([]byte("b"), nil) + w.Delete([]byte("c")) + + d, err := w.Data() + if err != nil { + t.Fatal(err) + } + + if kvs, err := d.Items(); err != nil { + t.Fatal(err) + } else if len(kvs) != 3 { + t.Fatal(len(kvs)) + } else if !reflect.DeepEqual(kvs[0], BatchItem{[]byte("a"), []byte("1")}) { + t.Fatal("must equal") + } else if !reflect.DeepEqual(kvs[1], BatchItem{[]byte("b"), []byte{}}) { + t.Fatal("must equal") + } else if !reflect.DeepEqual(kvs[2], BatchItem{[]byte("c"), nil}) { + t.Fatal("must equal") + } + + if err := d.Append(d); err != nil { + t.Fatal(err) + } else if d.Len() != 6 { + t.Fatal(d.Len()) + } + + if kvs, err := d.Items(); err != nil { + t.Fatal(err) + } else if len(kvs) != 6 { + t.Fatal(len(kvs)) + } else if !reflect.DeepEqual(kvs[0], BatchItem{[]byte("a"), []byte("1")}) { + t.Fatal("must equal") + } else if !reflect.DeepEqual(kvs[1], BatchItem{[]byte("b"), []byte{}}) { + t.Fatal("must equal") + } else if !reflect.DeepEqual(kvs[2], BatchItem{[]byte("c"), nil}) { + t.Fatal("must equal") + } else if !reflect.DeepEqual(kvs[3], BatchItem{[]byte("a"), []byte("1")}) { + t.Fatal("must equal") + } else if !reflect.DeepEqual(kvs[4], BatchItem{[]byte("b"), []byte{}}) { + t.Fatal("must equal") + } else if !reflect.DeepEqual(kvs[5], BatchItem{[]byte("c"), nil}) { + t.Fatal("must equal") + } +} diff --git a/store/writebatch.go b/store/writebatch.go index a7fd84a..134ad42 100644 --- a/store/writebatch.go +++ b/store/writebatch.go @@ -1,6 +1,8 @@ package store import ( + "encoding/binary" + "github.com/siddontang/goleveldb/leveldb" "github.com/siddontang/ledisdb/store/driver" "time" ) @@ -50,3 +52,78 @@ func (wb *WriteBatch) Rollback() error { return wb.wb.Rollback() } + +func (wb *WriteBatch) Data() (*BatchData, error) { + data := wb.wb.Data() + return NewBatchData(data) +} + +const BatchDataHeadLen = 12 + +/* + see leveldb batch data format for more information +*/ + +type BatchData struct { + leveldb.Batch +} + +func NewBatchData(data []byte) (*BatchData, error) { + b := new(BatchData) + + if err := b.Load(data); err != nil { + return nil, err + } + + return b, nil +} + +func (d *BatchData) Append(do *BatchData) error { + d1 := d.Dump() + d2 := do.Dump() + + n := d.Len() + do.Len() + + binary.LittleEndian.PutUint32(d1[8:], uint32(n)) + d1 = append(d1, d2[BatchDataHeadLen:]...) + + return d.Load(d1) +} + +func (d *BatchData) Data() []byte { + return d.Dump() +} + +type BatchDataReplay interface { + Put(key, value []byte) + Delete(key []byte) +} + +type BatchItem struct { + Key []byte + Value []byte +} + +type batchItems []BatchItem + +func (bs *batchItems) Put(key, value []byte) { + *bs = append(*bs, BatchItem{key, value}) +} + +func (bs *batchItems) Delete(key []byte) { + *bs = append(*bs, BatchItem{key, nil}) +} + +func (d *BatchData) Replay(r BatchDataReplay) error { + return d.Batch.Replay(r) +} + +func (d *BatchData) Items() ([]BatchItem, error) { + is := make(batchItems, 0, d.Len()) + + if err := d.Replay(&is); err != nil { + return nil, err + } + + return []BatchItem(is), nil +}