diff --git a/model.go b/model.go index d3a1fd14..c033637c 100644 --- a/model.go +++ b/model.go @@ -122,12 +122,14 @@ func (m *Model) TableName() string { return reg.ReplaceAllString(toSnake(t.Name()), "s") } -func (model *Model) callMethod(method string) error { - fm := reflect.ValueOf(model).MethodByName(method) +func (m *Model) callMethod(method string) error { + fm := reflect.ValueOf(m.Data).MethodByName(method) if fm.IsValid() { v := fm.Call([]reflect.Value{}) - if verr, ok := v[0].Interface().(error); ok { - return verr + if len(v) > 0 { + if verr, ok := v[0].Interface().(error); ok { + return verr + } } } return nil diff --git a/orm.go b/orm.go index e98fb01a..df6ef77f 100644 --- a/orm.go +++ b/orm.go @@ -110,9 +110,9 @@ func (s *Orm) Select(value interface{}) *Orm { func (s *Orm) Save(value interface{}) *Orm { s.Model(value) if s.model.PrimaryKeyIsEmpty() { - s.explain(value, "Create").create(value) + s.create(value) } else { - s.explain(value, "Update").update(value) + s.update(value) } return s } diff --git a/orm_test.go b/orm_test.go index 1a127f68..984592a6 100644 --- a/orm_test.go +++ b/orm_test.go @@ -1,6 +1,7 @@ package gorm import ( + "reflect" "testing" "time" ) @@ -14,6 +15,20 @@ type User struct { UpdatedAt time.Time } +type Product struct { + Id int64 + Code string + Price int64 + CreatedAt time.Time + UpdatedAt time.Time + BeforeCreateCallTimes int64 + AfterCreateCallTimes int64 + BeforeUpdateCallTimes int64 + AfterUpdateCallTimes int64 + BeforeSaveCallTimes int64 + AfterSaveCallTimes int64 +} + var ( db DB t1, t2, t3, t4, t5 time.Time @@ -22,11 +37,13 @@ var ( func init() { db, _ = Open("postgres", "user=gorm dbname=gorm sslmode=disable") db.Exec("drop table users;") + db.Exec("drop table products;") orm := db.CreateTable(&User{}) if orm.Error != nil { panic("No error should raise when create table") } + db.CreateTable(&Product{}) var shortForm = "2006-01-02 15:04:05" t1, _ = time.Parse(shortForm, "2000-10-27 12:02:40") @@ -311,3 +328,55 @@ func TestCreatedAtAndUpdatedAt(t *testing.T) { t.Errorf("Updated At should be changed after update") } } + +func (s *Product) BeforeCreate() { + s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 +} + +func (s *Product) BeforeUpdate() { + s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 +} + +func (s *Product) BeforeSave() { + s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 +} + +func (s *Product) AfterCreate() { + s.AfterCreateCallTimes = s.AfterCreateCallTimes + 1 +} + +func (s *Product) AfterUpdate() { + s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 +} + +func (s *Product) AfterSave() { + s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 +} + +func (p *Product) GetCallTimes() []int64 { + return []int64{p.BeforeCreateCallTimes, p.BeforeSaveCallTimes, p.BeforeUpdateCallTimes, p.AfterCreateCallTimes, p.AfterSaveCallTimes, p.AfterUpdateCallTimes} +} + +func TestRunCallbacks(t *testing.T) { + p := Product{Code: "unique_code", Price: 100} + db.Save(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0}) { + t.Errorf("Some errors happened when run create callbacks, %v", p.GetCallTimes()) + } + + db.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 0, 0, 0}) { + t.Errorf("Should be able to query about saved values in before filters, %v", p.GetCallTimes()) + } + + p.Price = 200 + db.Save(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 1, 1}) { + t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) + } + + db.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 0, 0, 0}) { + t.Errorf("Some errors happened when run update callbacks, %v", p.GetCallTimes()) + } +} diff --git a/sql.go b/sql.go index c8ce61e1..eeab9365 100644 --- a/sql.go +++ b/sql.go @@ -131,7 +131,7 @@ func (s *Orm) create(value interface{}) { var id int64 s.err(s.model.callMethod("BeforeCreate")) s.err(s.model.callMethod("BeforeSave")) - + s.explain(value, "Create") if s.driver == "postgres" { s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) } else { @@ -141,12 +141,11 @@ func (s *Orm) create(value interface{}) { id, err = s.SqlResult.LastInsertId() s.err(err) } + result := reflect.ValueOf(s.model.Data).Elem() + result.FieldByName(s.model.PrimaryKey()).SetInt(id) s.err(s.model.callMethod("AfterCreate")) s.err(s.model.callMethod("AfterSave")) - - result := reflect.ValueOf(s.model.Data).Elem() - result.FieldByName(s.model.PrimaryKey()).SetInt(id) } func (s *Orm) updateSql(value interface{}) { @@ -169,7 +168,7 @@ func (s *Orm) updateSql(value interface{}) { func (s *Orm) update(value interface{}) { s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeSave")) - s.Exec() + s.explain(value, "Update").Exec() s.err(s.model.callMethod("AfterUpdate")) s.err(s.model.callMethod("AfterSave")) return