forked from mirror/gorm
Document transaction for callbacks
This commit is contained in:
parent
ce91468922
commit
aaa73fe21d
11
README.md
11
README.md
|
@ -560,7 +560,7 @@ db.Table("deleted_users").Pluck("name", &names)
|
|||
## Callbacks
|
||||
|
||||
Callback is a function defined to a struct, the function would be run when reflect a struct to database.
|
||||
If the function return an error, will prevent following operations. (for example, stop inserting, updating)
|
||||
If a function return error, gorm will prevent future operations and do rollback
|
||||
|
||||
Those callbacks are defined now:
|
||||
|
||||
|
@ -570,12 +570,21 @@ Those callbacks are defined now:
|
|||
`BeforeDelete`, `AfterDelete`
|
||||
|
||||
```go
|
||||
// Won't update readonly user
|
||||
func (u *User) BeforeUpdate() (err error) {
|
||||
if u.readonly() {
|
||||
err = errors.New("Read Only User")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If have more than 1000 users, will rollback the insertion
|
||||
func (u *User) AfterCreate() (err error) {
|
||||
if (u.Id > 1000) { // just an example, don't use Id to count users
|
||||
err = errors.New("Only 1000 users allowed")
|
||||
}
|
||||
return
|
||||
}
|
||||
```
|
||||
|
||||
## Specify Table Name
|
||||
|
|
23
chain.go
23
chain.go
|
@ -132,22 +132,16 @@ func (s *Chain) Select(value interface{}) *Chain {
|
|||
}
|
||||
|
||||
func (s *Chain) Save(value interface{}) *Chain {
|
||||
do := s.do(value)
|
||||
tx_started := do.begin()
|
||||
do := s.do(value).begin()
|
||||
do.save()
|
||||
if tx_started {
|
||||
do.commit()
|
||||
}
|
||||
do.commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Chain) Delete(value interface{}) *Chain {
|
||||
do := s.do(value)
|
||||
tx_started := do.begin()
|
||||
do := s.do(value).begin()
|
||||
do.delete()
|
||||
if tx_started {
|
||||
do.commit()
|
||||
}
|
||||
do.commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -156,12 +150,9 @@ func (s *Chain) Update(attrs ...interface{}) *Chain {
|
|||
}
|
||||
|
||||
func (s *Chain) Updates(values interface{}, ignore_protected_attrs ...bool) *Chain {
|
||||
do := s.do(s.value)
|
||||
tx_started := do.begin()
|
||||
do.setUpdateAttrs(values, ignore_protected_attrs...).update()
|
||||
if tx_started {
|
||||
do.commit()
|
||||
}
|
||||
do := s.do(s.value).begin().setUpdateAttrs(values, ignore_protected_attrs...)
|
||||
do.update()
|
||||
do.commit_or_rollback()
|
||||
return s
|
||||
}
|
||||
|
||||
|
|
19
do.go
19
do.go
|
@ -17,6 +17,7 @@ type Do struct {
|
|||
db sql_common
|
||||
guessedTableName string
|
||||
specifiedTableName string
|
||||
startedTransaction bool
|
||||
|
||||
model *Model
|
||||
value interface{}
|
||||
|
@ -756,20 +757,26 @@ func (s *Do) autoMigrate() *Do {
|
|||
return s
|
||||
}
|
||||
|
||||
func (s *Do) begin() bool {
|
||||
func (s *Do) begin() *Do {
|
||||
if db, ok := s.db.(sql_db); ok {
|
||||
tx, err := db.Begin()
|
||||
if err == nil {
|
||||
s.db = interface{}(tx).(sql_common)
|
||||
return true
|
||||
s.startedTransaction = true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Do) commit() {
|
||||
if db, ok := s.db.(sql_tx); ok {
|
||||
s.err(db.Commit())
|
||||
func (s *Do) commit_or_rollback() {
|
||||
if s.startedTransaction {
|
||||
if db, ok := s.db.(sql_tx); ok {
|
||||
if s.chain.hasError() {
|
||||
db.Rollback()
|
||||
} else {
|
||||
db.Commit()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
36
gorm_test.go
36
gorm_test.go
|
@ -610,8 +610,12 @@ func (s *Product) AfterUpdate() {
|
|||
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
|
||||
}
|
||||
|
||||
func (s *Product) AfterSave() {
|
||||
func (s *Product) AfterSave() (err error) {
|
||||
if s.Code == "after_save_error" {
|
||||
err = errors.New("Can't save")
|
||||
}
|
||||
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) BeforeDelete() (err error) {
|
||||
|
@ -622,8 +626,12 @@ func (s *Product) BeforeDelete() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Product) AfterDelete() {
|
||||
func (s *Product) AfterDelete() (err error) {
|
||||
if s.Code == "after_delete_error" {
|
||||
err = errors.New("Can't delete")
|
||||
}
|
||||
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
|
||||
return
|
||||
}
|
||||
func (p *Product) GetCallTimes() []int64 {
|
||||
return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes, p.BeforeDeleteCallTimes, p.AfterDeleteCallTimes}
|
||||
|
@ -703,6 +711,23 @@ func TestRunCallbacksAndGetErrors(t *testing.T) {
|
|||
if db.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
|
||||
t.Errorf("Should not delete record due to errors happened in callback")
|
||||
}
|
||||
|
||||
p4 := Product{Code: "after_save_error", Price: 100}
|
||||
db.Save(&p4)
|
||||
if err := db.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
|
||||
t.Errorf("Record should be reverted if get an error after save", err)
|
||||
}
|
||||
|
||||
p5 := Product{Code: "after_delete_error", Price: 100}
|
||||
db.Save(&p5)
|
||||
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||
t.Errorf("Record should be found", err)
|
||||
}
|
||||
|
||||
db.Delete(&p5)
|
||||
if err := db.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||
t.Errorf("Record should be found because failed to delete", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillSmallerStructCorrectly(t *testing.T) {
|
||||
|
@ -1365,6 +1390,13 @@ func TestTransaction(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *CreditCard) BeforeSave() (err error) {
|
||||
if s.Number == "0000" {
|
||||
err = errors.New("invalid credit card")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func BenchmarkGorm(b *testing.B) {
|
||||
for x := 0; x < b.N; x++ {
|
||||
email := BigEmail{Email: "benchmark@example.org", UserAgent: "pc", RegisteredAt: time.Now()}
|
||||
|
|
Loading…
Reference in New Issue