From 973acd6339ae58587c61db82e31d1d48181a95bb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 26 Jan 2014 12:41:37 +0800 Subject: [PATCH] Add callback create, delete --- callback.go | 24 ++++++++++-------- callback_create.go | 41 ++++++++++++++++++++++++++++++ callback_delete.go | 33 ++++++++++++++++++++++++ callback_test.go | 14 +++++------ callbacks/create.go | 36 ++------------------------ main.go | 4 +-- private.go | 14 ----------- scope.go | 61 +++++++++++++++++++++++++++++++++++++++++++++ utils.go | 14 +++++++++++ 9 files changed, 173 insertions(+), 68 deletions(-) create mode 100644 callback_create.go create mode 100644 callback_delete.go create mode 100644 scope.go diff --git a/callback.go b/callback.go index 1cdeeffb..11eb0006 100644 --- a/callback.go +++ b/callback.go @@ -1,12 +1,14 @@ package gorm -import "fmt" +import ( + "fmt" +) type callback struct { - creates []*func() - updates []*func() - deletes []*func() - queries []*func() + creates []*func(scope *Scope) + updates []*func(scope *Scope) + deletes []*func(scope *Scope) + queries []*func(scope *Scope) processors []*callback_processor } @@ -17,7 +19,7 @@ type callback_processor struct { replace bool remove bool typ string - processor *func() + processor *func(scope *Scope) callback *callback } @@ -53,7 +55,7 @@ func (cp *callback_processor) After(name string) *callback_processor { return cp } -func (cp *callback_processor) Register(name string, fc func()) { +func (cp *callback_processor) Register(name string, fc func(scope *Scope)) { cp.name = name cp.processor = &fc cp.callback.sort() @@ -65,7 +67,7 @@ func (cp *callback_processor) Remove(name string) { cp.callback.sort() } -func (cp *callback_processor) Replace(name string, fc func()) { +func (cp *callback_processor) Replace(name string, fc func(scope *Scope)) { cp.name = name cp.processor = &fc cp.replace = true @@ -81,7 +83,7 @@ func getRIndex(strs []string, str string) int { return -1 } -func sortProcessors(cps []*callback_processor) []*func() { +func sortProcessors(cps []*callback_processor) []*func(scope *Scope) { var sortCallbackProcessor func(c *callback_processor, force bool) var names, sortedNames = []string{}, []string{} @@ -137,8 +139,8 @@ func sortProcessors(cps []*callback_processor) []*func() { sortCallbackProcessor(cp, false) } - var funcs = []*func(){} - var sortedFuncs = []*func(){} + var funcs = []*func(scope *Scope){} + var sortedFuncs = []*func(scope *Scope){} for _, name := range sortedNames { index := getRIndex(names, name) if !cps[index].remove { diff --git a/callback_create.go b/callback_create.go new file mode 100644 index 00000000..10f7537e --- /dev/null +++ b/callback_create.go @@ -0,0 +1,41 @@ +package gorm + +func BeforeCreate(scope *Scope) { + scope.CallMethod("BeforeSave") + scope.CallMethod("BeforeCreate") +} + +func SaveBeforeAssociations(scope *Scope) { +} + +func Create(scope *Scope) { + if !scope.HasError() { + var id interface{} + if scope.Dialect().SupportLastInsertId() { + if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { + id, err = sql_result.LastInsertId() + scope.Err(err) + } + } else { + scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id)) + } + + scope.SetColumn(scope.PrimaryKey(), id) + } +} + +func AfterCreate(scope *Scope) { + scope.CallMethod("AfterCreate") + scope.CallMethod("AfterSave") +} + +func SaveAfterAssociations(scope *Scope) { +} + +func init() { + DefaultCallback.Create().Register("before_create", BeforeCreate) + DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations) + DefaultCallback.Create().Register("create", Create) + DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations) + DefaultCallback.Create().Register("after_create", AfterCreate) +} diff --git a/callback_delete.go b/callback_delete.go new file mode 100644 index 00000000..aa13af4a --- /dev/null +++ b/callback_delete.go @@ -0,0 +1,33 @@ +package gorm + +import ( + "fmt" + "time" +) + +func BeforeDelete(scope *Scope) { + scope.CallMethod("BeforeDelete") +} + +func Delete(scope *Scope) { + if scope.HasError() { + return + } + + if !scope.Search.unscope && scope.HasColumn("DeletedAt") { + scope.Raw(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.TableName(), scope.AddToVars(time.Now()), scope.CombinedConditionSql())) + } else { + scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql())) + } + scope.Exec() +} + +func AfterDelete(scope *Scope) { + scope.CallMethod("AfterDelete") +} + +func init() { + DefaultCallback.Delete().Register("before_delete", BeforeDelete) + DefaultCallback.Delete().Register("delete", Delete) + DefaultCallback.Delete().Register("after_delete", AfterDelete) +} diff --git a/callback_test.go b/callback_test.go index 2f4156d7..6fa05f49 100644 --- a/callback_test.go +++ b/callback_test.go @@ -7,7 +7,7 @@ import ( "testing" ) -func equalFuncs(funcs []*func(), fnames []string) bool { +func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { var names []string for _, f := range funcs { fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") @@ -16,11 +16,11 @@ func equalFuncs(funcs []*func(), fnames []string) bool { return reflect.DeepEqual(names, fnames) } -func create() {} -func before_create1() {} -func before_create2() {} -func after_create1() {} -func after_create2() {} +func create(s *Scope) {} +func before_create1(s *Scope) {} +func before_create2(s *Scope) {} +func after_create1(s *Scope) {} +func after_create2(s *Scope) {} func TestRegisterCallback(t *testing.T) { var callback = &callback{processors: []*callback_processor{}} @@ -76,7 +76,7 @@ func TestRegisterCallbackWithComplexOrder2(t *testing.T) { } } -func replace_create() {} +func replace_create(s *Scope) {} func TestReplaceCallback(t *testing.T) { var callback = &callback{processors: []*callback_processor{}} diff --git a/callbacks/create.go b/callbacks/create.go index 8b2f4f1d..390a1ad0 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,12 +1,6 @@ -package callback +package callbacks -import ( - "fmt" - - "github.com/jinzhu/gorm" - - "time" -) +import "github.com/jinzhu/gorm" func Create(scope *gorm.Scope) { } @@ -15,32 +9,6 @@ func init() { gorm.DefaultCallback.Create().Before().Register(Create) } -func query(db *DB) { -} - -func save(db *DB) { -} - -func create(db *DB) { -} - -func update(db *DB) { -} - -func Delete(scope *Scope) { - scope.CallMethod("BeforeDelete") - - if !scope.HasError() { - if !scope.Search.unscope && scope.HasColumn("DeletedAt") { - scope.Raw(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.Table(), scope.AddToVars(time.Now()), scope.CombinedSql())) - } else { - scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.Table(), scope.CombinedSql())) - } - scope.Exec() - scope.CallMethod("AfterDelete") - } -} - func init() { DefaultCallback.Create().Before("Delete").After("Lalala").Register("delete", Delete) DefaultCallback.Update().Before("Delete").After("Lalala").Remove("replace", Delete) diff --git a/main.go b/main.go index 214a9c31..87a0d7e5 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,7 @@ import ( type DB struct { Value interface{} - Callbacks *callback + callback *callback Error error db sqlCommon parent *DB @@ -22,7 +22,7 @@ type DB struct { func Open(driver, source string) (DB, error) { var err error - db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger} + db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback} db.db, err = sql.Open(driver, source) db.parent = &db return db, err diff --git a/private.go b/private.go index 7b72b5fd..1a7bcb2c 100644 --- a/private.go +++ b/private.go @@ -1,11 +1,7 @@ package gorm import ( - "fmt" - "os" "regexp" - "runtime" - "strings" "time" ) @@ -55,16 +51,6 @@ func (s *DB) hasError() bool { return s.Error != nil } -func fileWithLineNum() string { - for i := 1; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { - return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) - } - } - return "" -} - func (s *DB) print(v ...interface{}) { s.parent.logger.(logger).Print(v...) } diff --git a/scope.go b/scope.go new file mode 100644 index 00000000..f094d770 --- /dev/null +++ b/scope.go @@ -0,0 +1,61 @@ +package gorm + +import "github.com/jinzhu/gorm/dialect" + +type Scope struct { + Search *search + Sql string + SqlVars []interface{} + db *DB +} + +func (scope *Scope) DB() sqlCommon { + return scope.db.db +} + +func (scope *Scope) Dialect() dialect.Dialect { + return scope.db.parent.dialect +} + +func (scope *Scope) Err(err error) error { + if err != nil { + scope.db.err(err) + } + return err +} + +func (scope *Scope) HasError() bool { + return true +} + +func (scope *Scope) PrimaryKey() string { + return "" +} + +func (scope *Scope) HasColumn(name string) bool { + return false +} + +func (scope *Scope) SetColumn(column string, value interface{}) { +} + +func (scope *Scope) CallMethod(name string) { +} + +func (scope *Scope) CombinedConditionSql() string { + return "" +} + +func (scope *Scope) AddToVars(value interface{}) string { + return "" +} + +func (scope *Scope) TableName() string { + return "" +} + +func (scope *Scope) Raw(sql string, values ...interface{}) { +} + +func (scope *Scope) Exec() { +} diff --git a/utils.go b/utils.go index 947cccad..6df5a86a 100644 --- a/utils.go +++ b/utils.go @@ -3,7 +3,11 @@ package gorm import ( "bytes" "database/sql" + "fmt" + "os" "reflect" + "regexp" + "runtime" "strconv" "strings" "sync" @@ -86,6 +90,16 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) { return } +func fileWithLineNum() string { + for i := 1; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { + return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) + } + } + return "" +} + func setFieldValue(field reflect.Value, value interface{}) bool { if field.IsValid() && field.CanAddr() { switch field.Kind() {