mirror of https://github.com/go-gorm/gorm.git
Test Transactions
This commit is contained in:
parent
ae9e4f1dd8
commit
5457fe88e6
|
@ -267,6 +267,16 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
|||
return
|
||||
}
|
||||
|
||||
// Pluck used to query single column from a model as a map
|
||||
// var ages []int64
|
||||
// db.Find(&users).Pluck("age", &ages)
|
||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||
tx := db.getInstance()
|
||||
tx.Error = tx.Statement.Parse(dest)
|
||||
|
@ -307,7 +317,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
|
|||
opt = opts[0]
|
||||
}
|
||||
|
||||
if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil {
|
||||
if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil {
|
||||
tx.AddError(err)
|
||||
}
|
||||
} else {
|
||||
|
|
11
gorm.go
11
gorm.go
|
@ -167,15 +167,14 @@ func (db *DB) AddError(err error) error {
|
|||
|
||||
func (db *DB) getInstance() *DB {
|
||||
if db.clone {
|
||||
stmt := &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Context: context.Background(),
|
||||
}
|
||||
stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}}
|
||||
|
||||
if db.Statement != nil {
|
||||
stmt.Context = db.Statement.Context
|
||||
stmt.ConnPool = db.Statement.ConnPool
|
||||
} else {
|
||||
stmt.Context = context.Background()
|
||||
stmt.ConnPool = db.ConnPool
|
||||
}
|
||||
|
||||
return &DB{Config: db.Config, Statement: stmt}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||
var columns []string
|
||||
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||
t.Errorf("Should got error with invalid SQL")
|
||||
}
|
||||
|
||||
if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||
t.Errorf("Should got error with invalid SQL")
|
||||
}
|
||||
|
||||
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
|
||||
t.Errorf("Should got error with invalid SQL")
|
||||
}
|
||||
|
||||
var count1, count2 int64
|
||||
DB.Model(&User{}).Count(&count1)
|
||||
if count1 <= 0 {
|
||||
t.Errorf("Should find some users")
|
||||
}
|
||||
|
||||
if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
|
||||
t.Errorf("Should got error with invalid SQL")
|
||||
}
|
||||
|
||||
DB.Model(&User{}).Count(&count2)
|
||||
if count1 != count2 {
|
||||
t.Errorf("No user should not be deleted by invalid SQL")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,135 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
. "github.com/jinzhu/gorm/tests"
|
||||
)
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
tx := DB.Begin()
|
||||
user := *GetUser("transcation", Config{})
|
||||
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise, but got %v", err)
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||
t.Errorf("Should find saved record, but got %v", err)
|
||||
}
|
||||
|
||||
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil {
|
||||
t.Errorf("Should return the underlying sql.Tx")
|
||||
}
|
||||
|
||||
tx.Rollback()
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||
t.Errorf("Should not find record after rollback, but got %v", err)
|
||||
}
|
||||
|
||||
tx2 := DB.Begin()
|
||||
user2 := *GetUser("transcation-2", Config{})
|
||||
if err := tx2.Save(&user2).Error; err != nil {
|
||||
t.Errorf("No error should raise, but got %v", err)
|
||||
}
|
||||
|
||||
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
t.Errorf("Should find saved record, but got %v", err)
|
||||
}
|
||||
|
||||
tx2.Commit()
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||
t.Errorf("Should be able to find committed record, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionWithBlock(t *testing.T) {
|
||||
assertPanic := func(f func()) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("The code did not panic")
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}
|
||||
|
||||
// rollback
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
user := *GetUser("transcation-block", Config{})
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
}
|
||||
|
||||
return errors.New("the error message")
|
||||
})
|
||||
|
||||
if err.Error() != "the error message" {
|
||||
t.Errorf("Transaction return error will equal the block returns error")
|
||||
}
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil {
|
||||
t.Errorf("Should not find record after rollback")
|
||||
}
|
||||
|
||||
// commit
|
||||
DB.Transaction(func(tx *gorm.DB) error {
|
||||
user := *GetUser("transcation-block-2", Config{})
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil {
|
||||
t.Errorf("Should be able to find committed record")
|
||||
}
|
||||
|
||||
// panic will rollback
|
||||
assertPanic(func() {
|
||||
DB.Transaction(func(tx *gorm.DB) error {
|
||||
user := *GetUser("transcation-block-3", Config{})
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
}
|
||||
|
||||
panic("force panic")
|
||||
})
|
||||
})
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil {
|
||||
t.Errorf("Should not find record after panic rollback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) {
|
||||
tx := DB.Begin()
|
||||
user := User{Name: "transcation"}
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
t.Errorf("Commit should not raise error")
|
||||
}
|
||||
|
||||
if err := tx.Rollback().Error; err == nil {
|
||||
t.Errorf("Rollback after commit should raise error")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue