Test Transactions

This commit is contained in:
Jinzhu 2020-05-31 18:51:43 +08:00
parent ae9e4f1dd8
commit 5457fe88e6
4 changed files with 188 additions and 7 deletions

View File

@ -267,6 +267,16 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
return return
} }
// Pluck used to query single column from a model as a map
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
}
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance() tx := db.getInstance()
tx.Error = tx.Statement.Parse(dest) tx.Error = tx.Statement.Parse(dest)
@ -307,7 +317,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
opt = opts[0] opt = opts[0]
} }
if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil { if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil {
tx.AddError(err) tx.AddError(err)
} }
} else { } else {

11
gorm.go
View File

@ -167,15 +167,14 @@ func (db *DB) AddError(err error) error {
func (db *DB) getInstance() *DB { func (db *DB) getInstance() *DB {
if db.clone { if db.clone {
stmt := &Statement{ stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}}
DB: db,
ConnPool: db.ConnPool,
Clauses: map[string]clause.Clause{},
Context: context.Background(),
}
if db.Statement != nil { if db.Statement != nil {
stmt.Context = db.Statement.Context stmt.Context = db.Statement.Context
stmt.ConnPool = db.Statement.ConnPool
} else {
stmt.Context = context.Background()
stmt.ConnPool = db.ConnPool
} }
return &DB{Config: db.Config, Statement: stmt} return &DB{Config: db.Config, Statement: stmt}

37
tests/main_test.go Normal file
View File

@ -0,0 +1,37 @@
package tests_test
import (
"testing"
. "github.com/jinzhu/gorm/tests"
)
func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
var count1, count2 int64
DB.Model(&User{}).Count(&count1)
if count1 <= 0 {
t.Errorf("Should find some users")
}
if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
DB.Model(&User{}).Count(&count2)
if count1 != count2 {
t.Errorf("No user should not be deleted by invalid SQL")
}
}

135
tests/transaction_test.go Normal file
View File

@ -0,0 +1,135 @@
package tests_test
import (
"database/sql"
"errors"
"testing"
"github.com/jinzhu/gorm"
. "github.com/jinzhu/gorm/tests"
)
func TestTransaction(t *testing.T) {
tx := DB.Begin()
user := *GetUser("transcation", Config{})
if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise, but got %v", err)
}
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record, but got %v", err)
}
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil {
t.Errorf("Should return the underlying sql.Tx")
}
tx.Rollback()
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback, but got %v", err)
}
tx2 := DB.Begin()
user2 := *GetUser("transcation-2", Config{})
if err := tx2.Save(&user2).Error; err != nil {
t.Errorf("No error should raise, but got %v", err)
}
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should find saved record, but got %v", err)
}
tx2.Commit()
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record, but got %v", err)
}
}
func TestTransactionWithBlock(t *testing.T) {
assertPanic := func(f func()) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
f()
}
// rollback
err := DB.Transaction(func(tx *gorm.DB) error {
user := *GetUser("transcation-block", Config{})
if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Errorf("Should find saved record")
}
return errors.New("the error message")
})
if err.Error() != "the error message" {
t.Errorf("Transaction return error will equal the block returns error")
}
if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil {
t.Errorf("Should not find record after rollback")
}
// commit
DB.Transaction(func(tx *gorm.DB) error {
user := *GetUser("transcation-block-2", Config{})
if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Errorf("Should find saved record")
}
return nil
})
if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil {
t.Errorf("Should be able to find committed record")
}
// panic will rollback
assertPanic(func() {
DB.Transaction(func(tx *gorm.DB) error {
user := *GetUser("transcation-block-3", Config{})
if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Errorf("Should find saved record")
}
panic("force panic")
})
})
if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil {
t.Errorf("Should not find record after panic rollback")
}
}
func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) {
tx := DB.Begin()
user := User{Name: "transcation"}
if err := tx.Save(&user).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.Commit().Error; err != nil {
t.Errorf("Commit should not raise error")
}
if err := tx.Rollback().Error; err == nil {
t.Errorf("Rollback after commit should raise error")
}
}