From b551fee2769101d347ccef6726d2920f0136cb69 Mon Sep 17 00:00:00 2001 From: Jinzhu <wosmvp@gmail.com> Date: Sun, 27 Oct 2013 14:51:23 +0800 Subject: [PATCH] Add callbacks support --- model.go | 11 +++++++++++ orm.go | 2 +- sql.go | 17 +++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/model.go b/model.go index bdee653e..d3a1fd14 100644 --- a/model.go +++ b/model.go @@ -122,6 +122,17 @@ func (m *Model) TableName() string { return reg.ReplaceAllString(toSnake(t.Name()), "s") } +func (model *Model) callMethod(method string) error { + fm := reflect.ValueOf(model).MethodByName(method) + if fm.IsValid() { + v := fm.Call([]reflect.Value{}) + if verr, ok := v[0].Interface().(error); ok { + return verr + } + } + return nil +} + func (model *Model) MissingColumns() (results []string) { return } diff --git a/orm.go b/orm.go index fc282268..e98fb01a 100644 --- a/orm.go +++ b/orm.go @@ -118,7 +118,7 @@ func (s *Orm) Save(value interface{}) *Orm { } func (s *Orm) Delete(value interface{}) *Orm { - s.explain(value, "Delete").Exec() + s.explain(value, "Delete").delete(value) return s } diff --git a/sql.go b/sql.go index 4e75a947..c8ce61e1 100644 --- a/sql.go +++ b/sql.go @@ -129,6 +129,9 @@ func (s *Orm) createSql(value interface{}) { func (s *Orm) create(value interface{}) { var id int64 + s.err(s.model.callMethod("BeforeCreate")) + s.err(s.model.callMethod("BeforeSave")) + if s.driver == "postgres" { s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id)) } else { @@ -139,6 +142,9 @@ func (s *Orm) create(value interface{}) { s.err(err) } + s.err(s.model.callMethod("AfterCreate")) + s.err(s.model.callMethod("AfterSave")) + result := reflect.ValueOf(s.model.Data).Elem() result.FieldByName(s.model.PrimaryKey()).SetInt(id) } @@ -161,7 +167,11 @@ func (s *Orm) updateSql(value interface{}) { } func (s *Orm) update(value interface{}) { + s.err(s.model.callMethod("BeforeUpdate")) + s.err(s.model.callMethod("BeforeSave")) s.Exec() + s.err(s.model.callMethod("AfterUpdate")) + s.err(s.model.callMethod("AfterSave")) return } @@ -169,6 +179,13 @@ func (s *Orm) deleteSql(value interface{}) { s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql()) return } + +func (s *Orm) delete(value interface{}) { + s.err(s.model.callMethod("BeforeDelete")) + s.Exec() + s.err(s.model.callMethod("AfterDelete")) +} + func (s *Orm) buildWhereCondition(clause map[string]interface{}) string { str := "( " + clause["query"].(string) + " )"