diff --git a/README.md b/README.md index 43ba6e27..c30a6268 100644 --- a/README.md +++ b/README.md @@ -598,16 +598,43 @@ db.Select("name, age").Find(&users) ## Callbacks Callbacks are functions defined to struct's pointer, they would be run when save a struct to database. -If any callback return error, gorm will stop future operations and do rollback +If any callback return error, gorm will stop future operations and rollback all changes -Below callbacks are supported now: +Here is a list with all available callbacks, +listed in the same order in which they will get called during the respective operations. -`BeforeCreate`, `AfterCreate` -`BeforeUpdate`, `AfterUpdate` -`BeforeSave`, `AfterSave` -`BeforeDelete`, `AfterDelete` +### Creating an Object -For example: +```go +BeforeSave +BeforeCreate +// save before associations +// save self +// save after associations +AfterCreate +AfterSave +``` +### Updating an Object + +```go +BeforeSave +BeforeUpdate +// save before associations +// save self +// save after associations +AfterUpdate +AfterSave +``` + +### Destroying an Object + +```go +BeforeDelete +// delete self +AfterDelete +``` + +Here is an example: ```go func (u *User) BeforeUpdate() (err error) { @@ -626,6 +653,17 @@ func (u *User) AfterCreate() (err error) { } ``` +```go +// As you know, the save/delete operations are running in a transaction +// This is means all your changes will be rollbacked if get any errors +// If you want your changes in callbacks be run in the same transaction +// You have to pass the transaction as argument to the function +func (u *User) AfterCreate(tx *gorm.DB) (err error) { + tx.Model(u).Update("role", "admin") + return +} +``` + ## Specify Table Name ```go diff --git a/do.go b/do.go index 83e90293..49b76835 100644 --- a/do.go +++ b/do.go @@ -173,8 +173,8 @@ func (s *Do) saveAfterAssociations() { func (s *Do) create() (i interface{}) { defer s.trace(time.Now()) - s.model.callMethod("BeforeCreate") s.model.callMethod("BeforeSave") + s.model.callMethod("BeforeCreate") s.saveBeforeAssociations() s.prepareCreateSql() @@ -274,8 +274,8 @@ func (s *Do) update() *Do { return s } - s.model.callMethod("BeforeUpdate") s.model.callMethod("BeforeSave") + s.model.callMethod("BeforeUpdate") s.saveBeforeAssociations() s.prepareUpdateSql(true) diff --git a/gorm_test.go b/gorm_test.go index 816b47cb..ee703a5c 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -588,8 +588,8 @@ func (s *Product) BeforeSave() (err error) { return } -func (s *Product) AfterCreate() { - s.AfterCreateCallTimes = s.AfterCreateCallTimes + 1 +func (s *Product) AfterCreate(db *DB) { + db.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) } func (s *Product) AfterUpdate() { @@ -633,23 +633,23 @@ func TestRunCallbacks(t *testing.T) { } db.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0, 0, 0}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0}) { t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) } p.Price = 200 db.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1, 0, 0}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0}) { t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) } db.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 0, 0}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0}) { t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) } db.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0, 1, 1}) { + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1}) { t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) } diff --git a/model.go b/model.go index 2d3cec47..4494306f 100644 --- a/model.go +++ b/model.go @@ -239,8 +239,15 @@ func (m *Model) callMethod(method string) { } if fm := reflect.ValueOf(m.data).MethodByName(method); fm.IsValid() { - if v := fm.Call([]reflect.Value{}); len(v) > 0 { - if verr, ok := v[0].Interface().(error); ok { + numin := fm.Type().NumIn() + var results []reflect.Value + if numin == 0 { + results = fm.Call([]reflect.Value{}) + } else if numin == 1 { + results = fm.Call([]reflect.Value{reflect.ValueOf(m.do.db.new())}) + } + if len(results) > 0 { + if verr, ok := results[0].Interface().(error); ok { m.do.err(verr) } } diff --git a/private.go b/private.go index 9206cfd6..541a0d54 100644 --- a/private.go +++ b/private.go @@ -18,6 +18,12 @@ func (s *DB) clone() *DB { return &db } +func (s *DB) new() *DB { + db := DB{db: s.db, parent: s.parent, logMode: s.logMode, data: s.data, Error: s.Error, search: &search{}} + db.search.db = &db + return &db +} + func (s *DB) do(data interface{}) *Do { s.data = data do := Do{db: s}