ledisdb/ledis/tx_test.go

220 lines
3.5 KiB
Go

package ledis
import (
"github.com/siddontang/ledisdb/config"
"os"
"testing"
)
// testTxRollback checks that writes made inside a transaction are readable
// through the transaction and are discarded once the transaction is rolled
// back.
//
// It seeds a string key and a hash field through db, overwrites both inside a
// transaction (and additionally creates a hash key that did not exist before),
// rolls the transaction back, and then asserts through db that the seeded
// values are unchanged and that the newly created key is absent.
func testTxRollback(t *testing.T, db *DB) {
	var err error

	key1 := []byte("tx_key1")
	key2 := []byte("tx_key2")
	field2 := []byte("tx_field2")
	newKey := []byte("no_key")

	// Seed the values that are expected to survive the rollback.
	err = db.Set(key1, []byte("value"))
	if err != nil {
		t.Fatal(err)
	}

	_, err = db.HSet(key2, field2, []byte("value"))
	if err != nil {
		t.Fatal(err)
	}

	var tx *Tx
	tx, err = db.Begin()
	if err != nil {
		t.Fatal(err)
	}

	// Safety net: make sure the transaction is ended even when an assertion
	// below aborts the test before the explicit Rollback call.
	defer tx.Rollback()

	err = tx.Set(key1, []byte("1"))
	if err != nil {
		t.Fatal(err)
	}

	_, err = tx.HSet(key2, field2, []byte("2"))
	if err != nil {
		t.Fatal(err)
	}

	// Write to a key that does not exist outside the transaction.
	_, err = tx.HSet(newKey, field2, []byte("2"))
	if err != nil {
		t.Fatal(err)
	}

	// Inside the transaction the new values must be visible.
	if v, err := tx.Get(key1); err != nil {
		t.Fatal(err)
	} else if string(v) != "1" {
		t.Fatal(string(v))
	}

	if v, err := tx.HGet(key2, field2); err != nil {
		t.Fatal(err)
	} else if string(v) != "2" {
		t.Fatal(string(v))
	}

	err = tx.Rollback()
	if err != nil {
		t.Fatal(err)
	}

	// After the rollback the seeded values must be back.
	if v, err := db.Get(key1); err != nil {
		t.Fatal(err)
	} else if string(v) != "value" {
		t.Fatal(string(v))
	}

	if v, err := db.HGet(key2, field2); err != nil {
		t.Fatal(err)
	} else if string(v) != "value" {
		t.Fatal(string(v))
	}

	// The key created only inside the transaction must not have been
	// persisted. Uses the same nil check this file applies to absent keys.
	if v, err := db.HGet(newKey, field2); err != nil {
		t.Fatal(err)
	} else if v != nil {
		t.Fatal("must nil")
	}
}
// testTxCommit checks that values written inside a transaction are readable
// through the transaction before the commit and through db after it.
//
// A string key and a hash field are first seeded with "value" via db, then
// overwritten with "1" and "2" inside a transaction. Both new values are
// asserted via tx before Commit and via db afterwards.
func testTxCommit(t *testing.T, db *DB) {
	strKey := []byte("tx_key1")
	hashKey := []byte("tx_key2")
	hashField := []byte("tx_field2")

	// Seed the initial values outside of any transaction.
	if err := db.Set(strKey, []byte("value")); err != nil {
		t.Fatal(err)
	}
	if _, err := db.HSet(hashKey, hashField, []byte("value")); err != nil {
		t.Fatal(err)
	}

	tx, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}
	// Ends the transaction if an assertion aborts the test before Commit.
	defer tx.Rollback()

	// Overwrite both values inside the transaction.
	if err := tx.Set(strKey, []byte("1")); err != nil {
		t.Fatal(err)
	}
	if _, err := tx.HSet(hashKey, hashField, []byte("2")); err != nil {
		t.Fatal(err)
	}

	// The transaction must see its own writes.
	switch got, err := tx.Get(strKey); {
	case err != nil:
		t.Fatal(err)
	case string(got) != "1":
		t.Fatal(string(got))
	}

	switch got, err := tx.HGet(hashKey, hashField); {
	case err != nil:
		t.Fatal(err)
	case string(got) != "2":
		t.Fatal(string(got))
	}

	if err := tx.Commit(); err != nil {
		t.Fatal(err)
	}

	// After the commit the same values must be readable through db.
	switch got, err := db.Get(strKey); {
	case err != nil:
		t.Fatal(err)
	case string(got) != "1":
		t.Fatal(string(got))
	}

	switch got, err := db.HGet(hashKey, hashField); {
	case err != nil:
		t.Fatal(err)
	case string(got) != "2":
		t.Fatal(string(got))
	}
}
// testTxSelect checks how tx.Select interacts with writes inside a single
// transaction.
//
// It writes "tx_select_1" before calling tx.Select(1) and "tx_select_2"
// after it, commits, and then expects:
//   - db to contain "tx_select_1" but not "tx_select_2";
//   - the database returned by db.l.Select(1) to contain "tx_select_2" but
//     not "tx_select_1".
func testTxSelect(t *testing.T, db *DB) {
	tx, err := db.Begin()
	if err != nil {
		t.Fatal(err)
	}

	// Ends the transaction if an assertion aborts the test before Commit.
	defer tx.Rollback()

	if err = tx.Set([]byte("tx_select_1"), []byte("a")); err != nil {
		t.Fatal(err)
	}

	// NOTE(review): the return signature of tx.Select is not visible in this
	// file, so the call is left exactly as it was; if it returns an error it
	// should be checked here as well.
	tx.Select(1)

	if err = tx.Set([]byte("tx_select_2"), []byte("b")); err != nil {
		t.Fatal(err)
	}

	if err = tx.Commit(); err != nil {
		t.Fatal(err)
	}

	// The database the transaction was started on holds only the first key.
	if v, err := db.Get([]byte("tx_select_1")); err != nil {
		t.Fatal(err)
	} else if string(v) != "a" {
		t.Fatal(string(v))
	}

	if v, err := db.Get([]byte("tx_select_2")); err != nil {
		t.Fatal(err)
	} else if v != nil {
		t.Fatal("must nil")
	}

	// Fetch database 1 and fail loudly if that is not possible, instead of
	// dereferencing a nil handle below.
	db1, err := db.l.Select(1)
	if err != nil {
		t.Fatal(err)
	}

	// Database 1 holds only the key written after tx.Select(1).
	if v, err := db1.Get([]byte("tx_select_2")); err != nil {
		t.Fatal(err)
	} else if string(v) != "b" {
		t.Fatal(string(v))
	}

	if v, err := db1.Get([]byte("tx_select_1")); err != nil {
		t.Fatal(err)
	} else if v != nil {
		t.Fatal("must nil")
	}
}
// testTx runs the transaction sub-tests (rollback, commit, select) against a
// freshly opened instance whose config has DBName set to name.
//
// The data directory /tmp/ledis_test_tx is removed before opening, and the
// LMDB map size is set to 10 MiB. Any failure to open the instance or to
// select database 0 aborts the test.
func testTx(t *testing.T, name string) {
	cfg := new(config.Config)
	cfg.DataDir = "/tmp/ledis_test_tx"
	cfg.DBName = name
	cfg.LMDB.MapSize = 10 * 1024 * 1024

	// Best-effort cleanup of data left behind by a previous run.
	os.RemoveAll(cfg.DataDir)

	l, err := Open(cfg)
	if err != nil {
		t.Fatal(err)
	}
	defer l.Close()

	// Check the error so a failed Select is reported here rather than as a
	// nil-pointer panic inside one of the sub-tests.
	db, err := l.Select(0)
	if err != nil {
		t.Fatal(err)
	}

	testTxRollback(t, db)
	testTxCommit(t, db)
	testTxSelect(t, db)
}
// TestTx runs the transaction tests once per backend name. Only lmdb and
// boltdb are exercised because only those backends support transactions.
func TestTx(t *testing.T) {
	for _, backend := range []string{"lmdb", "boltdb"} {
		testTx(t, backend)
	}
}