From b551fee2769101d347ccef6726d2920f0136cb69 Mon Sep 17 00:00:00 2001
From: Jinzhu <wosmvp@gmail.com>
Date: Sun, 27 Oct 2013 14:51:23 +0800
Subject: [PATCH] Add callbacks support

---
 model.go | 11 +++++++++++
 orm.go   |  2 +-
 sql.go   | 17 +++++++++++++++++
 3 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/model.go b/model.go
index bdee653e..d3a1fd14 100644
--- a/model.go
+++ b/model.go
@@ -122,6 +122,17 @@ func (m *Model) TableName() string {
 	return reg.ReplaceAllString(toSnake(t.Name()), "s")
 }
 
+func (model *Model) callMethod(method string) error {
+	fm := reflect.ValueOf(model).MethodByName(method)
+	if fm.IsValid() {
+		v := fm.Call([]reflect.Value{})
+		if verr, ok := v[0].Interface().(error); ok {
+			return verr
+		}
+	}
+	return nil
+}
+
 func (model *Model) MissingColumns() (results []string) {
 	return
 }
diff --git a/orm.go b/orm.go
index fc282268..e98fb01a 100644
--- a/orm.go
+++ b/orm.go
@@ -118,7 +118,7 @@ func (s *Orm) Save(value interface{}) *Orm {
 }
 
 func (s *Orm) Delete(value interface{}) *Orm {
-	s.explain(value, "Delete").Exec()
+	s.explain(value, "Delete").delete(value)
 	return s
 }
 
diff --git a/sql.go b/sql.go
index 4e75a947..c8ce61e1 100644
--- a/sql.go
+++ b/sql.go
@@ -129,6 +129,9 @@ func (s *Orm) createSql(value interface{}) {
 
 func (s *Orm) create(value interface{}) {
 	var id int64
+	s.err(s.model.callMethod("BeforeCreate"))
+	s.err(s.model.callMethod("BeforeSave"))
+
 	if s.driver == "postgres" {
 		s.err(s.db.QueryRow(s.Sql, s.SqlVars...).Scan(&id))
 	} else {
@@ -139,6 +142,9 @@ func (s *Orm) create(value interface{}) {
 		s.err(err)
 	}
 
+	s.err(s.model.callMethod("AfterCreate"))
+	s.err(s.model.callMethod("AfterSave"))
+
 	result := reflect.ValueOf(s.model.Data).Elem()
 	result.FieldByName(s.model.PrimaryKey()).SetInt(id)
 }
@@ -161,7 +167,11 @@ func (s *Orm) updateSql(value interface{}) {
 }
 
 func (s *Orm) update(value interface{}) {
+	s.err(s.model.callMethod("BeforeUpdate"))
+	s.err(s.model.callMethod("BeforeSave"))
 	s.Exec()
+	s.err(s.model.callMethod("AfterUpdate"))
+	s.err(s.model.callMethod("AfterSave"))
 	return
 }
 
@@ -169,6 +179,13 @@ func (s *Orm) deleteSql(value interface{}) {
 	s.Sql = fmt.Sprintf("DELETE FROM %v %v", s.TableName, s.combinedSql())
 	return
 }
+
+func (s *Orm) delete(value interface{}) {
+	s.err(s.model.callMethod("BeforeDelete"))
+	s.Exec()
+	s.err(s.model.callMethod("AfterDelete"))
+}
+
 func (s *Orm) buildWhereCondition(clause map[string]interface{}) string {
 	str := "( " + clause["query"].(string) + " )"