diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e604003 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +data.db \ No newline at end of file diff --git a/buntdb.go b/buntdb.go index 244748e..dcbd3da 100644 --- a/buntdb.go +++ b/buntdb.go @@ -6,6 +6,9 @@ package buntdb import ( "bufio" + "crypto/aes" + "crypto/cipher" + "crypto/rand" "errors" "fmt" "io" @@ -66,19 +69,20 @@ const useAbsEx = true // DB represents a collection of key-value pairs that persist on disk. // Transactions are used for all forms of data access to the DB. type DB struct { - mu sync.RWMutex // the gatekeeper for all fields - file *os.File // the underlying file - buf []byte // a buffer to write to - keys *btree.BTree // a tree of all item ordered by key - exps *btree.BTree // a tree of items ordered by expiration - idxs map[string]*index // the index trees. - insIdxs []*index // a reuse buffer for gathering indexes - flushes int // a count of the number of disk flushes - closed bool // set when the database has been closed - config Config // the database configuration - persist bool // do we write to disk - shrinking bool // when an aof shrink is in-process. - lastaofsz int // the size of the last shrink aof size + mu sync.RWMutex // the gatekeeper for all fields + file *os.File // the underlying file + buf []byte // a buffer to write to + keys *btree.BTree // a tree of all item ordered by key + exps *btree.BTree // a tree of items ordered by expiration + idxs map[string]*index // the index trees. + insIdxs []*index // a reuse buffer for gathering indexes + flushes int // a count of the number of disk flushes + closed bool // set when the database has been closed + config Config // the database configuration + persist bool // do we write to disk + shrinking bool // when an aof shrink is in-process. + lastaofsz int // the size of the last shrink aof size + encryptionKey []byte // store the encryption key } // SyncPolicy represents how often data is synced to disk. @@ -140,8 +144,10 @@ type exctx struct { // Open opens a database at the provided path. // If the file does not exist then it will be created automatically. -func Open(path string) (*DB, error) { - db := &DB{} +func Open(path string, encryptionKey []byte) (*DB, error) { + db := &DB{ + encryptionKey: encryptionKey, + } // initialize trees and indexes db.keys = btreeNew(lessCtx(nil)) db.exps = btreeNew(lessCtx(&exctx{db})) @@ -208,7 +214,7 @@ func (db *DB) Save(wr io.Writer) error { // iterated through every item in the database and write to the buffer btreeAscend(db.keys, func(item interface{}) bool { dbi := item.(*dbItem) - buf = dbi.writeSetTo(buf, now) + buf = dbi.writeSetTo(buf, now, db.encryptionKey) if len(buf) > 1024*1024*4 { // flush when buffer is over 4MB _, err = wr.Write(buf) @@ -246,6 +252,44 @@ func (db *DB) Load(rd io.Reader) error { return err } +func encrypt(data []byte, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + ciphertext := gcm.Seal(nonce, nonce, data, nil) + return ciphertext, nil +} + +func decrypt(data []byte, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return nil, errors.New("ciphertext too short") + } + nonce, ciphertext := data[:nonceSize], data[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + return plaintext, nil +} + // index represents a b-tree or r-tree index and also acts as the // b-tree/r-tree context for itself. type index struct { @@ -698,7 +742,7 @@ func (db *DB) Shrink() error { done = false return false } - buf = dbi.writeSetTo(buf, now) + buf = dbi.writeSetTo(buf, now, db.encryptionKey) n++ return true }, @@ -914,6 +958,14 @@ func (db *DB) readLoad(rd io.Reader, modTime time.Time) (n int64, err error) { exat = time.Unix(ex, 0) } if exat.After(now) { + if db.encryptionKey != nil { + // Decrypt the value with the provided encryption key + plaintext, err := decrypt([]byte(parts[2]), db.encryptionKey) + if err != nil { + return totalSize, err + } + parts[2] = string(plaintext) + } db.insertIntoDatabase(&dbItem{ key: parts[1], val: parts[2], @@ -924,6 +976,14 @@ func (db *DB) readLoad(rd io.Reader, modTime time.Time) (n int64, err error) { }) } } else { + if db.encryptionKey != nil { + // Decrypt the value with the provided encryption key + plaintext, err := decrypt([]byte(parts[2]), db.encryptionKey) + if err != nil { + return totalSize, err + } + parts[2] = string(plaintext) + } db.insertIntoDatabase(&dbItem{key: parts[1], val: parts[2]}) } } else if (parts[0][0] == 'd' || parts[0][0] == 'D') && @@ -1204,7 +1264,7 @@ func (tx *Tx) Commit() error { if item == nil { tx.db.buf = (&dbItem{key: key}).writeDeleteTo(tx.db.buf) } else { - tx.db.buf = item.writeSetTo(tx.db.buf, now) + tx.db.buf = item.writeSetTo(tx.db.buf, now, tx.db.encryptionKey) } } // Flushing the buffer only once per transaction. @@ -1336,23 +1396,44 @@ func appendBulkString(buf []byte, s string) []byte { return buf } -// writeSetTo writes an item as a single SET record to the a bufio Writer. -func (dbi *dbItem) writeSetTo(buf []byte, now time.Time) []byte { +// writeSetTo writes an item as a single SET record to the buffer. +func (dbi *dbItem) writeSetTo(buf []byte, now time.Time, encryptionKey []byte) []byte { if dbi.opts != nil && dbi.opts.ex { - buf = appendArray(buf, 5) - buf = appendBulkString(buf, "set") - buf = appendBulkString(buf, dbi.key) - buf = appendBulkString(buf, dbi.val) + // The item has expiration options, encode them in the value + if encryptionKey != nil { + // Encrypt the value with the provided encryption key + val, err := encrypt([]byte(dbi.val), encryptionKey) + if err != nil { + panic(err) + } + dbi.val = string(val) + } if useAbsEx { ex := dbi.opts.exat.Unix() + buf = appendArray(buf, 5) + buf = appendBulkString(buf, "set") + buf = appendBulkString(buf, dbi.key) + buf = appendBulkString(buf, dbi.val) buf = appendBulkString(buf, "ae") buf = appendBulkString(buf, strconv.FormatUint(uint64(ex), 10)) } else { ex := dbi.opts.exat.Sub(now) / time.Second + buf = appendArray(buf, 5) + buf = appendBulkString(buf, "set") + buf = appendBulkString(buf, dbi.key) + buf = appendBulkString(buf, dbi.val) buf = appendBulkString(buf, "ex") buf = appendBulkString(buf, strconv.FormatUint(uint64(ex), 10)) } } else { + if encryptionKey != nil { + // Encrypt the value with the provided encryption key + val, err := encrypt([]byte(dbi.val), encryptionKey) + if err != nil { + panic(err) + } + dbi.val = string(val) + } buf = appendArray(buf, 3) buf = appendBulkString(buf, "set") buf = appendBulkString(buf, dbi.key) diff --git a/buntdb_test.go b/buntdb_test.go index 4b37db2..7318810 100644 --- a/buntdb_test.go +++ b/buntdb_test.go @@ -34,7 +34,7 @@ func testReOpenDelay(t testing.TB, db *DB, dur time.Duration) *DB { } } time.Sleep(dur) - db, err := Open("data.db") + db, err := Open("data.db", []byte("supersecretkey16")) if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func TestBackgroudOperations(t *testing.T) { } } func TestSaveLoad(t *testing.T) { - db, _ := Open(":memory:") + db, _ := Open(":memory:", []byte("supersecretkey16")) defer db.Close() if err := db.Update(func(tx *Tx) error { for i := 0; i < 20; i++ { @@ -121,7 +121,7 @@ func TestSaveLoad(t *testing.T) { t.Fatal(err) } db.Close() - db, _ = Open(":memory:") + db, _ = Open(":memory:", []byte("supersecretkey16")) defer db.Close() f, err = os.Open("temp.db") if err != nil { @@ -1027,7 +1027,7 @@ func TestVariousTx(t *testing.T) { func TestNearby(t *testing.T) { rand.Seed(time.Now().UnixNano()) N := 100000 - db, _ := Open(":memory:") + db, _ := Open(":memory:", []byte("supersecretkey16")) db.CreateSpatialIndex("points", "*", IndexRect) db.Update(func(tx *Tx) error { for i := 0; i < N; i++ { @@ -1065,7 +1065,7 @@ func TestNearby(t *testing.T) { } func Example_descKeys() { - db, _ := Open(":memory:") + db, _ := Open(":memory:", []byte("supersecretkey16")) db.CreateIndex("name", "*", IndexString) db.Update(func(tx *Tx) error { tx.Set("user:100:first", "Tom", nil) @@ -1127,7 +1127,7 @@ func Example_descKeys() { } func ExampleDesc() { - db, _ := Open(":memory:") + db, _ := Open(":memory:", []byte("supersecretkey16")) db.CreateIndex("last_name_age", "*", IndexJSON("name.last"), Desc(IndexJSON("age"))) db.Update(func(tx *Tx) error { tx.Set("1", `{"name":{"first":"Tom","last":"Johnson"},"age":38}`, nil) @@ -1156,7 +1156,7 @@ func ExampleDesc() { } func ExampleDB_CreateIndex_jSON() { - db, _ := Open(":memory:") + db, _ := Open(":memory:", []byte("supersecretkey16")) db.CreateIndex("last_name", "*", IndexJSON("name.last")) db.CreateIndex("age", "*", IndexJSON("age")) db.Update(func(tx *Tx) error { @@ -1202,7 +1202,7 @@ func ExampleDB_CreateIndex_jSON() { } func ExampleDB_CreateIndex_strings() { - db, _ := Open(":memory:") + db, _ := Open(":memory:", []byte("supersecretkey16")) db.CreateIndex("name", "*", IndexString) db.Update(func(tx *Tx) error { tx.Set("1", "Tom", nil) @@ -1231,7 +1231,7 @@ func ExampleDB_CreateIndex_strings() { } func ExampleDB_CreateIndex_ints() { - db, _ := Open(":memory:") + db, _ := Open(":memory:", []byte("supersecretkey16")) db.CreateIndex("age", "*", IndexInt) db.Update(func(tx *Tx) error { tx.Set("1", "30", nil) @@ -1259,7 +1259,7 @@ func ExampleDB_CreateIndex_ints() { //4: 76 } func ExampleDB_CreateIndex_multipleFields() { - db, _ := Open(":memory:") + db, _ := Open(":memory:", []byte("supersecretkey16")) db.CreateIndex("last_name_age", "*", IndexJSON("name.last"), IndexJSON("age")) db.Update(func(tx *Tx) error { tx.Set("1", `{"name":{"first":"Tom","last":"Johnson"},"age":38}`, nil) @@ -1373,7 +1373,7 @@ func TestDatabaseFormat(t *testing.T) { } defer os.RemoveAll("data.db") - db, err := Open("data.db") + db, err := Open("data.db", []byte("supersecretkey16")) if err == nil { if do != nil { if err := do(db); err != nil { @@ -1596,7 +1596,7 @@ func TestOpeningAFolder(t *testing.T) { t.Fatal(err) } defer func() { _ = os.RemoveAll("dir.tmp") }() - db, err := Open("dir.tmp") + db, err := Open("dir.tmp", []byte("supersecretkey16")) if err == nil { if err := db.Close(); err != nil { t.Fatal(err) @@ -1614,7 +1614,7 @@ func TestOpeningInvalidDatabaseFile(t *testing.T) { t.Fatal(err) } defer func() { _ = os.RemoveAll("data.db") }() - db, err := Open("data.db") + db, err := Open("data.db", []byte("supersecretkey16")) if err == nil { if err := db.Close(); err != nil { t.Fatal(err) @@ -1628,7 +1628,7 @@ func TestOpeningClosedDatabase(t *testing.T) { if err := os.RemoveAll("data.db"); err != nil { t.Fatal(err) } - db, err := Open("data.db") + db, err := Open("data.db", []byte("supersecretkey16")) if err != nil { t.Fatal(err) } @@ -1639,7 +1639,7 @@ func TestOpeningClosedDatabase(t *testing.T) { if err := db.Close(); err != ErrDatabaseClosed { t.Fatal("should not be able to close a closed database") } - db, err = Open(":memory:") + db, err = Open(":memory:", []byte("supersecretkey16")) if err != nil { t.Fatal(err) } @@ -1716,7 +1716,7 @@ func TestShrink(t *testing.T) { t.Fatal("shrink on a closed databse should not be allowed") } // Now we will open a db that does not persist - db, err = Open(":memory:") + db, err = Open(":memory:", []byte("supersecretkey16")) if err != nil { t.Fatal(err) } @@ -2295,9 +2295,9 @@ func benchOpenFillData(t *testing.B, N int, if err := os.RemoveAll("data.db"); err != nil { t.Fatal(err) } - db, err = Open("data.db") + db, err = Open("data.db", []byte("supersecretkey16")) } else { - db, err = Open(":memory:") + db, err = Open(":memory:", []byte("supersecretkey16")) } if err != nil { t.Fatal(err) @@ -2499,10 +2499,10 @@ func Benchmark_Descend_10000(t *testing.B) { } /* -func Benchmark_Spatial_2D(t *testing.B) { - N := 100000 - db, _, _ := benchOpenFillData(t, N, true, true, false, true, 100) - defer benchClose(t, false, db) + func Benchmark_Spatial_2D(t *testing.B) { + N := 100000 + db, _, _ := benchOpenFillData(t, N, true, true, false, true, 100) + defer benchClose(t, false, db) } */ @@ -2756,7 +2756,7 @@ func TestTransactionLeak(t *testing.T) { // This tests an bug identified in Issue #69. When inside a Update // transaction, a Set after a Delete for a key that previously exists will // remove the key when the transaction was rolledback. - buntDB, err := Open(":memory:") + buntDB, err := Open(":memory:", []byte("supersecretkey16")) if err != nil { t.Fatal(err) } @@ -2841,7 +2841,7 @@ func TestReloadNotInvalid(t *testing.T) { ii := 0 for time.Since(start) < time.Second*5 { func() { - db, err := Open("data.db") + db, err := Open("data.db", []byte("supersecretkey16")) if err != nil { t.Fatal(err) }