Document transaction for callbacks

This commit is contained in:
Jinzhu 2013-11-11 19:06:26 +08:00
parent ce91468922
commit aaa73fe21d
4 changed files with 64 additions and 25 deletions

View File

@ -560,7 +560,7 @@ db.Table("deleted_users").Pluck("name", &names)
## Callbacks ## Callbacks
Callback is a function defined to a struct, the function would be run when reflect a struct to database. Callback is a function defined to a struct, the function would be run when reflect a struct to database.
If the function return an error, will prevent following operations. (for example, stop inserting, updating) If a function return error, gorm will prevent future operations and do rollback
Those callbacks are defined now: Those callbacks are defined now:
@ -570,12 +570,21 @@ Those callbacks are defined now:
`BeforeDelete`, `AfterDelete` `BeforeDelete`, `AfterDelete`
```go ```go
// Won't update readonly user
func (u *User) BeforeUpdate() (err error) { func (u *User) BeforeUpdate() (err error) {
if u.readonly() { if u.readonly() {
err = errors.New("Read Only User") err = errors.New("Read Only User")
} }
return return
} }
// If have more than 1000 users, will rollback the insertion
func (u *User) AfterCreate() (err error) {
if (u.Id > 1000) { // just an example, don't use Id to count users
err = errors.New("Only 1000 users allowed")
}
return
}
``` ```
## Specify Table Name ## Specify Table Name

View File

@ -132,22 +132,16 @@ func (s *Chain) Select(value interface{}) *Chain {
} }
func (s *Chain) Save(value interface{}) *Chain { func (s *Chain) Save(value interface{}) *Chain {
do := s.do(value) do := s.do(value).begin()
tx_started := do.begin()
do.save() do.save()
if tx_started { do.commit_or_rollback()
do.commit()
}
return s return s
} }
func (s *Chain) Delete(value interface{}) *Chain { func (s *Chain) Delete(value interface{}) *Chain {
do := s.do(value) do := s.do(value).begin()
tx_started := do.begin()
do.delete() do.delete()
if tx_started { do.commit_or_rollback()
do.commit()
}
return s return s
} }
@ -156,12 +150,9 @@ func (s *Chain) Update(attrs ...interface{}) *Chain {
} }
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain { func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
do := s.do(s.value) do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...)
tx_started := do.begin() do.update()
do.setUpdateAttrs(values, ignore_protected_attrs...).update() do.commit_or_rollback()
if tx_started {
do.commit()
}
return s return s
} }

17
do.go
View File

@ -17,6 +17,7 @@ type Do struct {
db sql_common db sql_common
guessedTableName string guessedTableName string
specifiedTableName string specifiedTableName string
startedTransaction bool
model *Model model *Model
value interface{} value interface{}
@ -756,20 +757,26 @@ func (s *Do) autoMigrate() *Do {
return s return s
} }
func (s *Do) begin() bool { func (s *Do) begin() *Do {
if db, ok := s.db.(sql_db); ok { if db, ok := s.db.(sql_db); ok {
tx, err := db.Begin() tx, err := db.Begin()
if err == nil { if err == nil {
s.db = interface{}(tx).(sql_common) s.db = interface{}(tx).(sql_common)
return true s.startedTransaction = true
} }
} }
return false return s
} }
func (s *Do) commit() { func (s *Do) commit_or_rollback() {
if s.startedTransaction {
if db, ok := s.db.(sql_tx); ok { if db, ok := s.db.(sql_tx); ok {
s.err(db.Commit()) if s.chain.hasError() {
db.Rollback()
} else {
db.Commit()
}
}
} }
} }

View File

@ -610,8 +610,12 @@ func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
} }
func (s *Product) AfterSave() { func (s *Product) AfterSave() (err error) {
if s.Code == "after_save_error" {
err = errors.New("Can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
} }
func (s *Product) BeforeDelete() (err error) { func (s *Product) BeforeDelete() (err error) {
@ -622,8 +626,12 @@ func (s *Product) BeforeDelete() (err error) {
return return
} }
func (s *Product) AfterDelete() { func (s *Product) AfterDelete() (err error) {
if s.Code == "after_delete_error" {
err = errors.New("Can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
} }
func (p *Product) GetCallTimes() []int64 { func (p *Product) GetCallTimes() []int64 {
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes} return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes}
@ -703,6 +711,23 @@ func TestRunCallbacksAndGetErrors(t *testing.T) {
if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil { if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("Should not delete record due to errors happened in callback") t.Errorf("Should not delete record due to errors happened in callback")
} }
p4 := Product{Code: "after_save_error", Price: 100}
db.Save(&p4)
if err := db.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error after save", err)
}
p5 := Product{Code: "after_delete_error", Price: 100}
db.Save(&p5)
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found", err)
}
db.Delete(&p5)
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found because failed to delete", err)
}
} }
func TestFillSmallerStructCorrectly(t *testing.T) { func TestFillSmallerStructCorrectly(t *testing.T) {
@ -1365,6 +1390,13 @@ func TestTransaction(t *testing.T) {
} }
} }
func (s *CreditCard) BeforeSave() (err error) {
if s.Number == "0000" {
err = errors.New("invalid credit card")
}
return
}
func BenchmarkGorm(b *testing.B) { func BenchmarkGorm(b *testing.B) {
for x := 0; x < b.N; x++ { for x := 0; x < b.N; x++ {
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()} email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}