diff --git a/store/mdb/batch.go b/store/mdb/batch.go index d8debf6..d68ee43 100644 --- a/store/mdb/batch.go +++ b/store/mdb/batch.go @@ -15,6 +15,9 @@ func (w *WriteBatch) Close() error { } func (w *WriteBatch) Put(key, value []byte) { + if value == nil { + value = []byte{} + } w.wb = append(w.wb, Write{key, value}) } diff --git a/store/store_test.go b/store/store_test.go index f37c496..c79bddd 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -37,6 +37,16 @@ func testSimple(db *DB, t *testing.T) { } else if v != nil { t.Fatal("must nil") } + + if err := db.Put(key, nil); err != nil { + t.Fatal(err) + } + + if v, err := db.Get(key); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v, []byte{}) { + t.Fatal("must empty") + } } func testBatch(db *DB, t *testing.T) { @@ -80,7 +90,27 @@ func testBatch(db *DB, t *testing.T) { t.Fatal(string(v)) } + wb.Put(key1, nil) + wb.Put(key2, []byte{}) + + if err := wb.Commit(); err != nil { + t.Fatal(err) + } + + if v, err := db.Get(key1); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v, []byte{}) { + t.Fatal("must empty") + } + + if v, err := db.Get(key2); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v, []byte{}) { + t.Fatal("must empty") + } + db.Delete(key1) + db.Delete(key2) } func checkIterator(it *RangeLimitIterator, cv ...int) error {