From 5457fe88e6f8df372aecef18570fa1b62c318ad3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 31 May 2020 18:51:43 +0800 Subject: [PATCH] Test Transactions --- finisher_api.go | 12 +++- gorm.go | 11 ++-- tests/main_test.go | 37 +++++++++++ tests/transaction_test.go | 135 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 7 deletions(-) create mode 100644 tests/main_test.go create mode 100644 tests/transaction_test.go diff --git a/finisher_api.go b/finisher_api.go index f14bcfbe..cfbb98c1 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -267,6 +267,16 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } +// Pluck used to query single column from a model as a map +// var ages []int64 +// db.Find(&users).Pluck("age", &ages) +func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() tx.Error = tx.Statement.Parse(dest) @@ -307,7 +317,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { opt = opts[0] } - if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil { + if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil { tx.AddError(err) } } else { diff --git a/gorm.go b/gorm.go index 70751cb3..ac4bff5e 100644 --- a/gorm.go +++ b/gorm.go @@ -167,15 +167,14 @@ func (db *DB) AddError(err error) error { func (db *DB) getInstance() *DB { if db.clone { - stmt := &Statement{ - DB: db, - ConnPool: db.ConnPool, - Clauses: map[string]clause.Clause{}, - Context: context.Background(), - } + stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}} if db.Statement != nil { stmt.Context = db.Statement.Context + stmt.ConnPool = db.Statement.ConnPool + } else { + stmt.Context = context.Background() + stmt.ConnPool = db.ConnPool } return &DB{Config: db.Config, Statement: stmt} diff --git a/tests/main_test.go b/tests/main_test.go new file mode 100644 index 00000000..da2003d6 --- /dev/null +++ b/tests/main_test.go @@ -0,0 +1,37 @@ +package tests_test + +import ( + "testing" + + . "github.com/jinzhu/gorm/tests" +) + +func TestExceptionsWithInvalidSql(t *testing.T) { + var columns []string + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + var count1, count2 int64 + DB.Model(&User{}).Count(&count1) + if count1 <= 0 { + t.Errorf("Should find some users") + } + + if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + DB.Model(&User{}).Count(&count2) + if count1 != count2 { + t.Errorf("No user should not be deleted by invalid SQL") + } +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go new file mode 100644 index 00000000..9405fd76 --- /dev/null +++ b/tests/transaction_test.go @@ -0,0 +1,135 @@ +package tests_test + +import ( + "database/sql" + "errors" + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +func TestTransaction(t *testing.T) { + tx := DB.Begin() + user := *GetUser("transcation", Config{}) + + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise, but got %v", err) + } + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { + t.Errorf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback, but got %v", err) + } + + tx2 := DB.Begin() + user2 := *GetUser("transcation-2", Config{}) + if err := tx2.Save(&user2).Error; err != nil { + t.Errorf("No error should raise, but got %v", err) + } + + if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should be able to find committed record, but got %v", err) + } +} + +func TestTransactionWithBlock(t *testing.T) { + assertPanic := func(f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + f() + } + + // rollback + err := DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transcation-block", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find saved record") + } + + return errors.New("the error message") + }) + + if err.Error() != "the error message" { + t.Errorf("Transaction return error will equal the block returns error") + } + + if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + // commit + DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transcation-block-2", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find saved record") + } + return nil + }) + + if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil { + t.Errorf("Should be able to find committed record") + } + + // panic will rollback + assertPanic(func() { + DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transcation-block-3", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find saved record") + } + + panic("force panic") + }) + }) + + if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil { + t.Errorf("Should not find record after panic rollback") + } +} + +func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + user := User{Name: "transcation"} + if err := tx.Save(&user).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err == nil { + t.Errorf("Rollback after commit should raise error") + } +}