Document transaction for callbacks

This commit is contained in:
Jinzhu 2013-11-11 19:06:26 +08:00
parent ce91468922
commit aaa73fe21d
4 changed files with 64 additions and 25 deletions

View File

@ -560,7 +560,7 @@ db.Table("deleted_users").Pluck("name", &names)
## Callbacks
Callback is a function defined to a struct, the function would be run when reflect a struct to database.
If the function return an error, will prevent following operations. (for example, stop inserting, updating)
If a function return error, gorm will prevent future operations and do rollback
Those callbacks are defined now:
@ -570,12 +570,21 @@ Those callbacks are defined now:
`BeforeDelete`, `AfterDelete`
```go
// Won't update readonly user
func (u *User) BeforeUpdate() (err error) {
if u.readonly() {
err = errors.New("Read Only User")
}
return
}
// If have more than 1000 users, will rollback the insertion
func (u *User) AfterCreate() (err error) {
if (u.Id > 1000) { // just an example, don't use Id to count users
err = errors.New("Only 1000 users allowed")
}
return
}
```
## Specify Table Name

View File

@ -132,22 +132,16 @@ func (s *Chain) Select(value interface{}) *Chain {
}
func (s *Chain) Save(value interface{}) *Chain {
do := s.do(value)
tx_started := do.begin()
do := s.do(value).begin()
do.save()
if tx_started {
do.commit()
}
do.commit_or_rollback()
return s
}
func (s *Chain) Delete(value interface{}) *Chain {
do := s.do(value)
tx_started := do.begin()
do := s.do(value).begin()
do.delete()
if tx_started {
do.commit()
}
do.commit_or_rollback()
return s
}
@ -156,12 +150,9 @@ func (s *Chain) Update(attrs ...interface{}) *Chain {
}
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
do := s.do(s.value)
tx_started := do.begin()
do.setUpdateAttrs(values, ignore_protected_attrs...).update()
if tx_started {
do.commit()
}
do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...)
do.update()
do.commit_or_rollback()
return s
}

19
do.go
View File

@ -17,6 +17,7 @@ type Do struct {
db sql_common
guessedTableName string
specifiedTableName string
startedTransaction bool
model *Model
value interface{}
@ -756,20 +757,26 @@ func (s *Do) autoMigrate() *Do {
return s
}
func (s *Do) begin() bool {
func (s *Do) begin() *Do {
if db, ok := s.db.(sql_db); ok {
tx, err := db.Begin()
if err == nil {
s.db = interface{}(tx).(sql_common)
return true
s.startedTransaction = true
}
}
return false
return s
}
func (s *Do) commit() {
if db, ok := s.db.(sql_tx); ok {
s.err(db.Commit())
func (s *Do) commit_or_rollback() {
if s.startedTransaction {
if db, ok := s.db.(sql_tx); ok {
if s.chain.hasError() {
db.Rollback()
} else {
db.Commit()
}
}
}
}

View File

@ -610,8 +610,12 @@ func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
}
func (s *Product) AfterSave() {
func (s *Product) AfterSave() (err error) {
if s.Code == "after_save_error" {
err = errors.New("Can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
}
func (s *Product) BeforeDelete() (err error) {
@ -622,8 +626,12 @@ func (s *Product) BeforeDelete() (err error) {
return
}
func (s *Product) AfterDelete() {
func (s *Product) AfterDelete() (err error) {
if s.Code == "after_delete_error" {
err = errors.New("Can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
}
func (p *Product) GetCallTimes() []int64 {
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes}
@ -703,6 +711,23 @@ func TestRunCallbacksAndGetErrors(t *testing.T) {
if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("Should not delete record due to errors happened in callback")
}
p4 := Product{Code: "after_save_error", Price: 100}
db.Save(&p4)
if err := db.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error after save", err)
}
p5 := Product{Code: "after_delete_error", Price: 100}
db.Save(&p5)
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found", err)
}
db.Delete(&p5)
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found because failed to delete", err)
}
}
func TestFillSmallerStructCorrectly(t *testing.T) {
@ -1365,6 +1390,13 @@ func TestTransaction(t *testing.T) {
}
}
func (s *CreditCard) BeforeSave() (err error) {
if s.Number == "0000" {
err = errors.New("invalid credit card")
}
return
}
func BenchmarkGorm(b *testing.B) {
for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}