Transaction in callbacks

This commit is contained in:
Jinzhu 2013-11-24 08:29:56 +08:00
parent 96ade8c619
commit 31c64a9c95
5 changed files with 68 additions and 17 deletions

View File

@ -598,16 +598,43 @@ db.Select("name, age").Find(&users)
## Callbacks
Callbacks are functions defined to struct's pointer, they would be run when save a struct to database.
If any callback return error, gorm will stop future operations and do rollback
If any callback return error, gorm will stop future operations and rollback all changes
Below callbacks are supported now:
Here is a list with all available callbacks,
listed in the same order in which they will get called during the respective operations.
`BeforeCreate`, `AfterCreate`
`BeforeUpdate`, `AfterUpdate`
`BeforeSave`, `AfterSave`
`BeforeDelete`, `AfterDelete`
### Creating an Object
For example:
```go
BeforeSave
BeforeCreate
// save before associations
// save self
// save after associations
AfterCreate
AfterSave
```
### Updating an Object
```go
BeforeSave
BeforeUpdate
// save before associations
// save self
// save after associations
AfterUpdate
AfterSave
```
### Destroying an Object
```go
BeforeDelete
// delete self
AfterDelete
```
Here is an example:
```go
func (u *User) BeforeUpdate() (err error) {
@ -626,6 +653,17 @@ func (u *User) AfterCreate() (err error) {
}
```
```go
// As you know, the save/delete operations are running in a transaction
// This is means all your changes will be rollbacked if get any errors
// If you want your changes in callbacks be run in the same transaction
// You have to pass the transaction as argument to the function
func (u *User) AfterCreate(tx *gorm.DB) (err error) {
tx.Model(u).Update("role", "admin")
return
}
```
## Specify Table Name
```go

4
do.go
View File

@ -173,8 +173,8 @@ func (s *Do) saveAfterAssociations() {
func (s *Do) create() (i interface{}) {
defer s.trace(time.Now())
s.model.callMethod("BeforeCreate")
s.model.callMethod("BeforeSave")
s.model.callMethod("BeforeCreate")
s.saveBeforeAssociations()
s.prepareCreateSql()
@ -274,8 +274,8 @@ func (s *Do) update() *Do {
return s
}
s.model.callMethod("BeforeUpdate")
s.model.callMethod("BeforeSave")
s.model.callMethod("BeforeUpdate")
s.saveBeforeAssociations()
s.prepareUpdateSql(true)

View File

@ -588,8 +588,8 @@ func (s *Product) BeforeSave() (err error) {
return
}
func (s *Product) AfterCreate() {
s.AfterCreateCallTimes = s.AfterCreateCallTimes + 1
func (s *Product) AfterCreate(db *DB) {
db.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
}
func (s *Product) AfterUpdate() {
@ -633,23 +633,23 @@ func TestRunCallbacks(t *testing.T) {
}
db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0, 0, 0}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
}
p.Price = 200
db.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1, 0, 0}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}
db.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 0, 0}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
}
db.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 1, 1}) {
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}

View File

@ -239,8 +239,15 @@ func (m *Model) callMethod(method string) {
}
if fm := reflect.ValueOf(m.data).MethodByName(method); fm.IsValid() {
if v := fm.Call([]reflect.Value{}); len(v) > 0 {
if verr, ok := v[0].Interface().(error); ok {
numin := fm.Type().NumIn()
var results []reflect.Value
if numin == 0 {
results = fm.Call([]reflect.Value{})
} else if numin == 1 {
results = fm.Call([]reflect.Value{reflect.ValueOf(m.do.db.new())})
}
if len(results) > 0 {
if verr, ok := results[0].Interface().(error); ok {
m.do.err(verr)
}
}

View File

@ -18,6 +18,12 @@ func (s *DB) clone() *DB {
return &db
}
func (s *DB) new() *DB {
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, data: s.data, Error: s.Error, search: &search{}}
db.search.db = &db
return &db
}
func (s *DB) do(data interface{}) *Do {
s.data = data
do := Do{db: s}