forked from mirror/gorm
Test Transactions
This commit is contained in:
parent
ae9e4f1dd8
commit
5457fe88e6
|
@ -267,6 +267,16 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pluck used to query single column from a model as a map
|
||||||
|
// var ages []int64
|
||||||
|
// db.Find(&users).Pluck("age", &ages)
|
||||||
|
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||||
|
tx = db.getInstance()
|
||||||
|
tx.Statement.Dest = dest
|
||||||
|
tx.callbacks.Query().Execute(tx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||||
tx := db.getInstance()
|
tx := db.getInstance()
|
||||||
tx.Error = tx.Statement.Parse(dest)
|
tx.Error = tx.Statement.Parse(dest)
|
||||||
|
@ -307,7 +317,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
|
||||||
opt = opts[0]
|
opt = opts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
if tx.Statement.ConnPool, err = beginner.BeginTx(db.Statement.Context, opt); err != nil {
|
if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil {
|
||||||
tx.AddError(err)
|
tx.AddError(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
11
gorm.go
11
gorm.go
|
@ -167,15 +167,14 @@ func (db *DB) AddError(err error) error {
|
||||||
|
|
||||||
func (db *DB) getInstance() *DB {
|
func (db *DB) getInstance() *DB {
|
||||||
if db.clone {
|
if db.clone {
|
||||||
stmt := &Statement{
|
stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}}
|
||||||
DB: db,
|
|
||||||
ConnPool: db.ConnPool,
|
|
||||||
Clauses: map[string]clause.Clause{},
|
|
||||||
Context: context.Background(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement != nil {
|
if db.Statement != nil {
|
||||||
stmt.Context = db.Statement.Context
|
stmt.Context = db.Statement.Context
|
||||||
|
stmt.ConnPool = db.Statement.ConnPool
|
||||||
|
} else {
|
||||||
|
stmt.Context = context.Background()
|
||||||
|
stmt.ConnPool = db.ConnPool
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DB{Config: db.Config, Statement: stmt}
|
return &DB{Config: db.Config, Statement: stmt}
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/jinzhu/gorm/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||||
|
var columns []string
|
||||||
|
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||||
|
t.Errorf("Should got error with invalid SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||||
|
t.Errorf("Should got error with invalid SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
|
||||||
|
t.Errorf("Should got error with invalid SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
var count1, count2 int64
|
||||||
|
DB.Model(&User{}).Count(&count1)
|
||||||
|
if count1 <= 0 {
|
||||||
|
t.Errorf("Should find some users")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
|
||||||
|
t.Errorf("Should got error with invalid SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&User{}).Count(&count2)
|
||||||
|
if count1 != count2 {
|
||||||
|
t.Errorf("No user should not be deleted by invalid SQL")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,135 @@
|
||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
. "github.com/jinzhu/gorm/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTransaction(t *testing.T) {
|
||||||
|
tx := DB.Begin()
|
||||||
|
user := *GetUser("transcation", Config{})
|
||||||
|
|
||||||
|
if err := tx.Save(&user).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil {
|
||||||
|
t.Errorf("Should return the underlying sql.Tx")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||||
|
t.Errorf("Should not find record after rollback, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx2 := DB.Begin()
|
||||||
|
user2 := *GetUser("transcation-2", Config{})
|
||||||
|
if err := tx2.Save(&user2).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx2.Commit()
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||||
|
t.Errorf("Should be able to find committed record, but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionWithBlock(t *testing.T) {
|
||||||
|
assertPanic := func(f func()) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("The code did not panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
|
// rollback
|
||||||
|
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
user := *GetUser("transcation-block", Config{})
|
||||||
|
if err := tx.Save(&user).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New("the error message")
|
||||||
|
})
|
||||||
|
|
||||||
|
if err.Error() != "the error message" {
|
||||||
|
t.Errorf("Transaction return error will equal the block returns error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil {
|
||||||
|
t.Errorf("Should not find record after rollback")
|
||||||
|
}
|
||||||
|
|
||||||
|
// commit
|
||||||
|
DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
user := *GetUser("transcation-block-2", Config{})
|
||||||
|
if err := tx.Save(&user).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil {
|
||||||
|
t.Errorf("Should be able to find committed record")
|
||||||
|
}
|
||||||
|
|
||||||
|
// panic will rollback
|
||||||
|
assertPanic(func() {
|
||||||
|
DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
user := *GetUser("transcation-block-3", Config{})
|
||||||
|
if err := tx.Save(&user).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("force panic")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil {
|
||||||
|
t.Errorf("Should not find record after panic rollback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) {
|
||||||
|
tx := DB.Begin()
|
||||||
|
user := User{Name: "transcation"}
|
||||||
|
if err := tx.Save(&user).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
t.Errorf("Commit should not raise error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Rollback().Error; err == nil {
|
||||||
|
t.Errorf("Rollback after commit should raise error")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue