encryption

Signed-off-by: Finbarrs Oketunji <f@finbarrs.eu>
This commit is contained in:
Finbarrs Oketunji 2024-04-01 21:18:33 +01:00
parent 4ac2e321b1
commit 83142e3842
No known key found for this signature in database
GPG Key ID: D9127942BDC44778
3 changed files with 130 additions and 48 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
data.db

129
buntdb.go
View File

@ -6,6 +6,9 @@ package buntdb
import ( import (
"bufio" "bufio"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -66,19 +69,20 @@ const useAbsEx = true
// DB represents a collection of key-value pairs that persist on disk. // DB represents a collection of key-value pairs that persist on disk.
// Transactions are used for all forms of data access to the DB. // Transactions are used for all forms of data access to the DB.
type DB struct { type DB struct {
mu sync.RWMutex // the gatekeeper for all fields mu sync.RWMutex // the gatekeeper for all fields
file *os.File // the underlying file file *os.File // the underlying file
buf []byte // a buffer to write to buf []byte // a buffer to write to
keys *btree.BTree // a tree of all item ordered by key keys *btree.BTree // a tree of all item ordered by key
exps *btree.BTree // a tree of items ordered by expiration exps *btree.BTree // a tree of items ordered by expiration
idxs map[string]*index // the index trees. idxs map[string]*index // the index trees.
insIdxs []*index // a reuse buffer for gathering indexes insIdxs []*index // a reuse buffer for gathering indexes
flushes int // a count of the number of disk flushes flushes int // a count of the number of disk flushes
closed bool // set when the database has been closed closed bool // set when the database has been closed
config Config // the database configuration config Config // the database configuration
persist bool // do we write to disk persist bool // do we write to disk
shrinking bool // when an aof shrink is in-process. shrinking bool // when an aof shrink is in-process.
lastaofsz int // the size of the last shrink aof size lastaofsz int // the size of the last shrink aof size
encryptionKey []byte // store the encryption key
} }
// SyncPolicy represents how often data is synced to disk. // SyncPolicy represents how often data is synced to disk.
@ -140,8 +144,10 @@ type exctx struct {
// Open opens a database at the provided path. // Open opens a database at the provided path.
// If the file does not exist then it will be created automatically. // If the file does not exist then it will be created automatically.
func Open(path string) (*DB, error) { func Open(path string, encryptionKey []byte) (*DB, error) {
db := &DB{} db := &DB{
encryptionKey: encryptionKey,
}
// initialize trees and indexes // initialize trees and indexes
db.keys = btreeNew(lessCtx(nil)) db.keys = btreeNew(lessCtx(nil))
db.exps = btreeNew(lessCtx(&exctx{db})) db.exps = btreeNew(lessCtx(&exctx{db}))
@ -208,7 +214,7 @@ func (db *DB) Save(wr io.Writer) error {
// iterated through every item in the database and write to the buffer // iterated through every item in the database and write to the buffer
btreeAscend(db.keys, func(item interface{}) bool { btreeAscend(db.keys, func(item interface{}) bool {
dbi := item.(*dbItem) dbi := item.(*dbItem)
buf = dbi.writeSetTo(buf, now) buf = dbi.writeSetTo(buf, now, db.encryptionKey)
if len(buf) > 1024*1024*4 { if len(buf) > 1024*1024*4 {
// flush when buffer is over 4MB // flush when buffer is over 4MB
_, err = wr.Write(buf) _, err = wr.Write(buf)
@ -246,6 +252,44 @@ func (db *DB) Load(rd io.Reader) error {
return err return err
} }
func encrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
ciphertext := gcm.Seal(nonce, nonce, data, nil)
return ciphertext, nil
}
func decrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
// index represents a b-tree or r-tree index and also acts as the // index represents a b-tree or r-tree index and also acts as the
// b-tree/r-tree context for itself. // b-tree/r-tree context for itself.
type index struct { type index struct {
@ -698,7 +742,7 @@ func (db *DB) Shrink() error {
done = false done = false
return false return false
} }
buf = dbi.writeSetTo(buf, now) buf = dbi.writeSetTo(buf, now, db.encryptionKey)
n++ n++
return true return true
}, },
@ -914,6 +958,14 @@ func (db *DB) readLoad(rd io.Reader, modTime time.Time) (n int64, err error) {
exat = time.Unix(ex, 0) exat = time.Unix(ex, 0)
} }
if exat.After(now) { if exat.After(now) {
if db.encryptionKey != nil {
// Decrypt the value with the provided encryption key
plaintext, err := decrypt([]byte(parts[2]), db.encryptionKey)
if err != nil {
return totalSize, err
}
parts[2] = string(plaintext)
}
db.insertIntoDatabase(&dbItem{ db.insertIntoDatabase(&dbItem{
key: parts[1], key: parts[1],
val: parts[2], val: parts[2],
@ -924,6 +976,14 @@ func (db *DB) readLoad(rd io.Reader, modTime time.Time) (n int64, err error) {
}) })
} }
} else { } else {
if db.encryptionKey != nil {
// Decrypt the value with the provided encryption key
plaintext, err := decrypt([]byte(parts[2]), db.encryptionKey)
if err != nil {
return totalSize, err
}
parts[2] = string(plaintext)
}
db.insertIntoDatabase(&dbItem{key: parts[1], val: parts[2]}) db.insertIntoDatabase(&dbItem{key: parts[1], val: parts[2]})
} }
} else if (parts[0][0] == 'd' || parts[0][0] == 'D') && } else if (parts[0][0] == 'd' || parts[0][0] == 'D') &&
@ -1204,7 +1264,7 @@ func (tx *Tx) Commit() error {
if item == nil { if item == nil {
tx.db.buf = (&dbItem{key: key}).writeDeleteTo(tx.db.buf) tx.db.buf = (&dbItem{key: key}).writeDeleteTo(tx.db.buf)
} else { } else {
tx.db.buf = item.writeSetTo(tx.db.buf, now) tx.db.buf = item.writeSetTo(tx.db.buf, now, tx.db.encryptionKey)
} }
} }
// Flushing the buffer only once per transaction. // Flushing the buffer only once per transaction.
@ -1336,23 +1396,44 @@ func appendBulkString(buf []byte, s string) []byte {
return buf return buf
} }
// writeSetTo writes an item as a single SET record to the a bufio Writer. // writeSetTo writes an item as a single SET record to the buffer.
func (dbi *dbItem) writeSetTo(buf []byte, now time.Time) []byte { func (dbi *dbItem) writeSetTo(buf []byte, now time.Time, encryptionKey []byte) []byte {
if dbi.opts != nil && dbi.opts.ex { if dbi.opts != nil && dbi.opts.ex {
buf = appendArray(buf, 5) // The item has expiration options, encode them in the value
buf = appendBulkString(buf, "set") if encryptionKey != nil {
buf = appendBulkString(buf, dbi.key) // Encrypt the value with the provided encryption key
buf = appendBulkString(buf, dbi.val) val, err := encrypt([]byte(dbi.val), encryptionKey)
if err != nil {
panic(err)
}
dbi.val = string(val)
}
if useAbsEx { if useAbsEx {
ex := dbi.opts.exat.Unix() ex := dbi.opts.exat.Unix()
buf = appendArray(buf, 5)
buf = appendBulkString(buf, "set")
buf = appendBulkString(buf, dbi.key)
buf = appendBulkString(buf, dbi.val)
buf = appendBulkString(buf, "ae") buf = appendBulkString(buf, "ae")
buf = appendBulkString(buf, strconv.FormatUint(uint64(ex), 10)) buf = appendBulkString(buf, strconv.FormatUint(uint64(ex), 10))
} else { } else {
ex := dbi.opts.exat.Sub(now) / time.Second ex := dbi.opts.exat.Sub(now) / time.Second
buf = appendArray(buf, 5)
buf = appendBulkString(buf, "set")
buf = appendBulkString(buf, dbi.key)
buf = appendBulkString(buf, dbi.val)
buf = appendBulkString(buf, "ex") buf = appendBulkString(buf, "ex")
buf = appendBulkString(buf, strconv.FormatUint(uint64(ex), 10)) buf = appendBulkString(buf, strconv.FormatUint(uint64(ex), 10))
} }
} else { } else {
if encryptionKey != nil {
// Encrypt the value with the provided encryption key
val, err := encrypt([]byte(dbi.val), encryptionKey)
if err != nil {
panic(err)
}
dbi.val = string(val)
}
buf = appendArray(buf, 3) buf = appendArray(buf, 3)
buf = appendBulkString(buf, "set") buf = appendBulkString(buf, "set")
buf = appendBulkString(buf, dbi.key) buf = appendBulkString(buf, dbi.key)

View File

@ -34,7 +34,7 @@ func testReOpenDelay(t testing.TB, db *DB, dur time.Duration) *DB {
} }
} }
time.Sleep(dur) time.Sleep(dur)
db, err := Open("data.db") db, err := Open("data.db", []byte("supersecretkey16"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -93,7 +93,7 @@ func TestBackgroudOperations(t *testing.T) {
} }
} }
func TestSaveLoad(t *testing.T) { func TestSaveLoad(t *testing.T) {
db, _ := Open(":memory:") db, _ := Open(":memory:", []byte("supersecretkey16"))
defer db.Close() defer db.Close()
if err := db.Update(func(tx *Tx) error { if err := db.Update(func(tx *Tx) error {
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
@ -121,7 +121,7 @@ func TestSaveLoad(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
db.Close() db.Close()
db, _ = Open(":memory:") db, _ = Open(":memory:", []byte("supersecretkey16"))
defer db.Close() defer db.Close()
f, err = os.Open("temp.db") f, err = os.Open("temp.db")
if err != nil { if err != nil {
@ -1027,7 +1027,7 @@ func TestVariousTx(t *testing.T) {
func TestNearby(t *testing.T) { func TestNearby(t *testing.T) {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
N := 100000 N := 100000
db, _ := Open(":memory:") db, _ := Open(":memory:", []byte("supersecretkey16"))
db.CreateSpatialIndex("points", "*", IndexRect) db.CreateSpatialIndex("points", "*", IndexRect)
db.Update(func(tx *Tx) error { db.Update(func(tx *Tx) error {
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
@ -1065,7 +1065,7 @@ func TestNearby(t *testing.T) {
} }
func Example_descKeys() { func Example_descKeys() {
db, _ := Open(":memory:") db, _ := Open(":memory:", []byte("supersecretkey16"))
db.CreateIndex("name", "*", IndexString) db.CreateIndex("name", "*", IndexString)
db.Update(func(tx *Tx) error { db.Update(func(tx *Tx) error {
tx.Set("user:100:first", "Tom", nil) tx.Set("user:100:first", "Tom", nil)
@ -1127,7 +1127,7 @@ func Example_descKeys() {
} }
func ExampleDesc() { func ExampleDesc() {
db, _ := Open(":memory:") db, _ := Open(":memory:", []byte("supersecretkey16"))
db.CreateIndex("last_name_age", "*", IndexJSON("name.last"), Desc(IndexJSON("age"))) db.CreateIndex("last_name_age", "*", IndexJSON("name.last"), Desc(IndexJSON("age")))
db.Update(func(tx *Tx) error { db.Update(func(tx *Tx) error {
tx.Set("1", `{"name":{"first":"Tom","last":"Johnson"},"age":38}`, nil) tx.Set("1", `{"name":{"first":"Tom","last":"Johnson"},"age":38}`, nil)
@ -1156,7 +1156,7 @@ func ExampleDesc() {
} }
func ExampleDB_CreateIndex_jSON() { func ExampleDB_CreateIndex_jSON() {
db, _ := Open(":memory:") db, _ := Open(":memory:", []byte("supersecretkey16"))
db.CreateIndex("last_name", "*", IndexJSON("name.last")) db.CreateIndex("last_name", "*", IndexJSON("name.last"))
db.CreateIndex("age", "*", IndexJSON("age")) db.CreateIndex("age", "*", IndexJSON("age"))
db.Update(func(tx *Tx) error { db.Update(func(tx *Tx) error {
@ -1202,7 +1202,7 @@ func ExampleDB_CreateIndex_jSON() {
} }
func ExampleDB_CreateIndex_strings() { func ExampleDB_CreateIndex_strings() {
db, _ := Open(":memory:") db, _ := Open(":memory:", []byte("supersecretkey16"))
db.CreateIndex("name", "*", IndexString) db.CreateIndex("name", "*", IndexString)
db.Update(func(tx *Tx) error { db.Update(func(tx *Tx) error {
tx.Set("1", "Tom", nil) tx.Set("1", "Tom", nil)
@ -1231,7 +1231,7 @@ func ExampleDB_CreateIndex_strings() {
} }
func ExampleDB_CreateIndex_ints() { func ExampleDB_CreateIndex_ints() {
db, _ := Open(":memory:") db, _ := Open(":memory:", []byte("supersecretkey16"))
db.CreateIndex("age", "*", IndexInt) db.CreateIndex("age", "*", IndexInt)
db.Update(func(tx *Tx) error { db.Update(func(tx *Tx) error {
tx.Set("1", "30", nil) tx.Set("1", "30", nil)
@ -1259,7 +1259,7 @@ func ExampleDB_CreateIndex_ints() {
//4: 76 //4: 76
} }
func ExampleDB_CreateIndex_multipleFields() { func ExampleDB_CreateIndex_multipleFields() {
db, _ := Open(":memory:") db, _ := Open(":memory:", []byte("supersecretkey16"))
db.CreateIndex("last_name_age", "*", IndexJSON("name.last"), IndexJSON("age")) db.CreateIndex("last_name_age", "*", IndexJSON("name.last"), IndexJSON("age"))
db.Update(func(tx *Tx) error { db.Update(func(tx *Tx) error {
tx.Set("1", `{"name":{"first":"Tom","last":"Johnson"},"age":38}`, nil) tx.Set("1", `{"name":{"first":"Tom","last":"Johnson"},"age":38}`, nil)
@ -1373,7 +1373,7 @@ func TestDatabaseFormat(t *testing.T) {
} }
defer os.RemoveAll("data.db") defer os.RemoveAll("data.db")
db, err := Open("data.db") db, err := Open("data.db", []byte("supersecretkey16"))
if err == nil { if err == nil {
if do != nil { if do != nil {
if err := do(db); err != nil { if err := do(db); err != nil {
@ -1596,7 +1596,7 @@ func TestOpeningAFolder(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer func() { _ = os.RemoveAll("dir.tmp") }() defer func() { _ = os.RemoveAll("dir.tmp") }()
db, err := Open("dir.tmp") db, err := Open("dir.tmp", []byte("supersecretkey16"))
if err == nil { if err == nil {
if err := db.Close(); err != nil { if err := db.Close(); err != nil {
t.Fatal(err) t.Fatal(err)
@ -1614,7 +1614,7 @@ func TestOpeningInvalidDatabaseFile(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer func() { _ = os.RemoveAll("data.db") }() defer func() { _ = os.RemoveAll("data.db") }()
db, err := Open("data.db") db, err := Open("data.db", []byte("supersecretkey16"))
if err == nil { if err == nil {
if err := db.Close(); err != nil { if err := db.Close(); err != nil {
t.Fatal(err) t.Fatal(err)
@ -1628,7 +1628,7 @@ func TestOpeningClosedDatabase(t *testing.T) {
if err := os.RemoveAll("data.db"); err != nil { if err := os.RemoveAll("data.db"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
db, err := Open("data.db") db, err := Open("data.db", []byte("supersecretkey16"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1639,7 +1639,7 @@ func TestOpeningClosedDatabase(t *testing.T) {
if err := db.Close(); err != ErrDatabaseClosed { if err := db.Close(); err != ErrDatabaseClosed {
t.Fatal("should not be able to close a closed database") t.Fatal("should not be able to close a closed database")
} }
db, err = Open(":memory:") db, err = Open(":memory:", []byte("supersecretkey16"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1716,7 +1716,7 @@ func TestShrink(t *testing.T) {
t.Fatal("shrink on a closed databse should not be allowed") t.Fatal("shrink on a closed databse should not be allowed")
} }
// Now we will open a db that does not persist // Now we will open a db that does not persist
db, err = Open(":memory:") db, err = Open(":memory:", []byte("supersecretkey16"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2295,9 +2295,9 @@ func benchOpenFillData(t *testing.B, N int,
if err := os.RemoveAll("data.db"); err != nil { if err := os.RemoveAll("data.db"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
db, err = Open("data.db") db, err = Open("data.db", []byte("supersecretkey16"))
} else { } else {
db, err = Open(":memory:") db, err = Open(":memory:", []byte("supersecretkey16"))
} }
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -2499,10 +2499,10 @@ func Benchmark_Descend_10000(t *testing.B) {
} }
/* /*
func Benchmark_Spatial_2D(t *testing.B) { func Benchmark_Spatial_2D(t *testing.B) {
N := 100000 N := 100000
db, _, _ := benchOpenFillData(t, N, true, true, false, true, 100) db, _, _ := benchOpenFillData(t, N, true, true, false, true, 100)
defer benchClose(t, false, db) defer benchClose(t, false, db)
} }
*/ */
@ -2756,7 +2756,7 @@ func TestTransactionLeak(t *testing.T) {
// This tests an bug identified in Issue #69. When inside a Update // This tests an bug identified in Issue #69. When inside a Update
// transaction, a Set after a Delete for a key that previously exists will // transaction, a Set after a Delete for a key that previously exists will
// remove the key when the transaction was rolledback. // remove the key when the transaction was rolledback.
buntDB, err := Open(":memory:") buntDB, err := Open(":memory:", []byte("supersecretkey16"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2841,7 +2841,7 @@ func TestReloadNotInvalid(t *testing.T) {
ii := 0 ii := 0
for time.Since(start) < time.Second*5 { for time.Since(start) < time.Second*5 {
func() { func() {
db, err := Open("data.db") db, err := Open("data.db", []byte("supersecretkey16"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }