diff --git a/callback.go b/callback.go index f45bb64c..55cd807e 100644 --- a/callback.go +++ b/callback.go @@ -4,17 +4,17 @@ import ( "fmt" ) -// defaultCallbacks hold default callbacks defined by gorm -var defaultCallbacks = &Callbacks{} +// defaultCallback hold default callbacks defined by gorm +var defaultCallback = &Callback{} -// Callbacks contains callbacks that used when CURD objects +// Callback contains callbacks that used when CURD objects // Field `creates` hold callbacks will be call when creating object // Field `updates` hold callbacks will be call when updating object // Field `deletes` hold callbacks will be call when deleting object // Field `queries` hold callbacks will be call when querying object with query methods like Find, First, Related, Association... // Field `rowQueries` hold callbacks will be call when querying object with Row, Rows... // Field `processors` hold all callback processors, will be used to generate above callbacks in order -type Callbacks struct { +type Callback struct { creates []*func(scope *Scope) updates []*func(scope *Scope) deletes []*func(scope *Scope) @@ -32,17 +32,17 @@ type CallbackProcessor struct { remove bool // delete callbacks with same name kind string // callback type: create, update, delete, query, row_query processor *func(scope *Scope) // callback handler - parent *Callbacks + parent *Callback } -func (c *Callbacks) addProcessor(kind string) *CallbackProcessor { +func (c *Callback) addProcessor(kind string) *CallbackProcessor { cp := &CallbackProcessor{kind: kind, parent: c} c.processors = append(c.processors, cp) return cp } -func (c *Callbacks) clone() *Callbacks { - return &Callbacks{ +func (c *Callback) clone() *Callback { + return &Callback{ creates: c.creates, updates: c.updates, deletes: c.deletes, @@ -59,28 +59,28 @@ func (c *Callbacks) clone() *Callbacks { // // set error if some thing wrong happened, will rollback the creating // scope.Err(errors.New("error")) // }) -func (c *Callbacks) Create() *CallbackProcessor { +func (c *Callback) Create() *CallbackProcessor { return c.addProcessor("create") } // Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callbacks) Update() *CallbackProcessor { +func (c *Callback) Update() *CallbackProcessor { return c.addProcessor("update") } // Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callbacks) Delete() *CallbackProcessor { +func (c *Callback) Delete() *CallbackProcessor { return c.addProcessor("delete") } // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... // refer `Create` for usage -func (c *Callbacks) Query() *CallbackProcessor { +func (c *Callback) Query() *CallbackProcessor { return c.addProcessor("query") } // Query could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callbacks) RowQuery() *CallbackProcessor { +func (c *Callback) RowQuery() *CallbackProcessor { return c.addProcessor("row_query") } @@ -209,7 +209,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { } // reorder all registered processors, and reset CURD callbacks -func (c *Callbacks) reorder() { +func (c *Callback) reorder() { var creates, updates, deletes, queries, rowQueries []*CallbackProcessor for _, processor := range c.processors { diff --git a/callback_create.go b/callback_create.go index 6f99c56b..c52b9c85 100644 --- a/callback_create.go +++ b/callback_create.go @@ -114,13 +114,13 @@ func AfterCreate(scope *Scope) { } func init() { - defaultCallbacks.Create().Register("gorm:begin_transaction", BeginTransaction) - defaultCallbacks.Create().Register("gorm:before_create", BeforeCreate) - defaultCallbacks.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) - defaultCallbacks.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) - defaultCallbacks.Create().Register("gorm:create", Create) - defaultCallbacks.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) - defaultCallbacks.Create().Register("gorm:save_after_associations", SaveAfterAssociations) - defaultCallbacks.Create().Register("gorm:after_create", AfterCreate) - defaultCallbacks.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + defaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction) + defaultCallback.Create().Register("gorm:before_create", BeforeCreate) + defaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) + defaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) + defaultCallback.Create().Register("gorm:create", Create) + defaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) + defaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) + defaultCallback.Create().Register("gorm:after_create", AfterCreate) + defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) } diff --git a/callback_delete.go b/callback_delete.go index 7ea001cc..dca6ee21 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -28,9 +28,9 @@ func AfterDelete(scope *Scope) { } func init() { - defaultCallbacks.Delete().Register("gorm:begin_transaction", BeginTransaction) - defaultCallbacks.Delete().Register("gorm:before_delete", BeforeDelete) - defaultCallbacks.Delete().Register("gorm:delete", Delete) - defaultCallbacks.Delete().Register("gorm:after_delete", AfterDelete) - defaultCallbacks.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + defaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction) + defaultCallback.Delete().Register("gorm:before_delete", BeforeDelete) + defaultCallback.Delete().Register("gorm:delete", Delete) + defaultCallback.Delete().Register("gorm:after_delete", AfterDelete) + defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) } diff --git a/callback_query.go b/callback_query.go index 2c9ba0d1..05ac8880 100644 --- a/callback_query.go +++ b/callback_query.go @@ -83,7 +83,7 @@ func AfterQuery(scope *Scope) { } func init() { - defaultCallbacks.Query().Register("gorm:query", Query) - defaultCallbacks.Query().Register("gorm:after_query", AfterQuery) - defaultCallbacks.Query().Register("gorm:preload", Preload) + defaultCallback.Query().Register("gorm:query", Query) + defaultCallback.Query().Register("gorm:after_query", AfterQuery) + defaultCallback.Query().Register("gorm:preload", Preload) } diff --git a/callback_test.go b/callback_test.go index bb189543..13ca3f42 100644 --- a/callback_test.go +++ b/callback_test.go @@ -23,62 +23,62 @@ func afterCreate1(s *Scope) {} func afterCreate2(s *Scope) {} func TestRegisterCallback(t *testing.T) { - var callbacks = &Callbacks{} + var callback = &Callback{} - callbacks.Create().Register("before_create1", beforeCreate1) - callbacks.Create().Register("before_create2", beforeCreate2) - callbacks.Create().Register("create", create) - callbacks.Create().Register("after_create1", afterCreate1) - callbacks.Create().Register("after_create2", afterCreate2) + callback.Create().Register("before_create1", beforeCreate1) + callback.Create().Register("before_create2", beforeCreate2) + callback.Create().Register("create", create) + callback.Create().Register("after_create1", afterCreate1) + callback.Create().Register("after_create2", afterCreate2) - if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { + if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { t.Errorf("register callback") } } func TestRegisterCallbackWithOrder(t *testing.T) { - var callbacks1 = &Callbacks{} - callbacks1.Create().Register("before_create1", beforeCreate1) - callbacks1.Create().Register("create", create) - callbacks1.Create().Register("after_create1", afterCreate1) - callbacks1.Create().Before("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callbacks1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { + var callback1 = &Callback{} + callback1.Create().Register("before_create1", beforeCreate1) + callback1.Create().Register("create", create) + callback1.Create().Register("after_create1", afterCreate1) + callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) + if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { t.Errorf("register callback with order") } - var callbacks2 = &Callbacks{} + var callback2 = &Callback{} - callbacks2.Update().Register("create", create) - callbacks2.Update().Before("create").Register("before_create1", beforeCreate1) - callbacks2.Update().After("after_create2").Register("after_create1", afterCreate1) - callbacks2.Update().Before("before_create1").Register("before_create2", beforeCreate2) - callbacks2.Update().Register("after_create2", afterCreate2) + callback2.Update().Register("create", create) + callback2.Update().Before("create").Register("before_create1", beforeCreate1) + callback2.Update().After("after_create2").Register("after_create1", afterCreate1) + callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) + callback2.Update().Register("after_create2", afterCreate2) - if !equalFuncs(callbacks2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { + if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { t.Errorf("register callback with order") } } func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callbacks1 = &Callbacks{} + var callback1 = &Callback{} - callbacks1.Query().Before("after_create1").After("before_create1").Register("create", create) - callbacks1.Query().Register("before_create1", beforeCreate1) - callbacks1.Query().Register("after_create1", afterCreate1) + callback1.Query().Before("after_create1").After("before_create1").Register("create", create) + callback1.Query().Register("before_create1", beforeCreate1) + callback1.Query().Register("after_create1", afterCreate1) - if !equalFuncs(callbacks1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { + if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { t.Errorf("register callback with order") } - var callbacks2 = &Callbacks{} + var callback2 = &Callback{} - callbacks2.Delete().Before("after_create1").After("before_create1").Register("create", create) - callbacks2.Delete().Before("create").Register("before_create1", beforeCreate1) - callbacks2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - callbacks2.Delete().Register("after_create1", afterCreate1) - callbacks2.Delete().After("after_create1").Register("after_create2", afterCreate2) + callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) + callback2.Delete().Before("create").Register("before_create1", beforeCreate1) + callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) + callback2.Delete().Register("after_create1", afterCreate1) + callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callbacks2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { + if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { t.Errorf("register callback with order") } } @@ -86,27 +86,27 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) { func replaceCreate(s *Scope) {} func TestReplaceCallback(t *testing.T) { - var callbacks = &Callbacks{} + var callback = &Callback{} - callbacks.Create().Before("after_create1").After("before_create1").Register("create", create) - callbacks.Create().Register("before_create1", beforeCreate1) - callbacks.Create().Register("after_create1", afterCreate1) - callbacks.Create().Replace("create", replaceCreate) + callback.Create().Before("after_create1").After("before_create1").Register("create", create) + callback.Create().Register("before_create1", beforeCreate1) + callback.Create().Register("after_create1", afterCreate1) + callback.Create().Replace("create", replaceCreate) - if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { + if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { t.Errorf("replace callback") } } func TestRemoveCallback(t *testing.T) { - var callbacks = &Callbacks{} + var callback = &Callback{} - callbacks.Create().Before("after_create1").After("before_create1").Register("create", create) - callbacks.Create().Register("before_create1", beforeCreate1) - callbacks.Create().Register("after_create1", afterCreate1) - callbacks.Create().Remove("create") + callback.Create().Before("after_create1").After("before_create1").Register("create", create) + callback.Create().Register("before_create1", beforeCreate1) + callback.Create().Register("after_create1", afterCreate1) + callback.Create().Remove("create") - if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "afterCreate1"}) { + if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { t.Errorf("remove callback") } } diff --git a/callback_update.go b/callback_update.go index a2b6d48e..e7884450 100644 --- a/callback_update.go +++ b/callback_update.go @@ -83,13 +83,13 @@ func AfterUpdate(scope *Scope) { } func init() { - defaultCallbacks.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes) - defaultCallbacks.Update().Register("gorm:begin_transaction", BeginTransaction) - defaultCallbacks.Update().Register("gorm:before_update", BeforeUpdate) - defaultCallbacks.Update().Register("gorm:save_before_associations", SaveBeforeAssociations) - defaultCallbacks.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate) - defaultCallbacks.Update().Register("gorm:update", Update) - defaultCallbacks.Update().Register("gorm:save_after_associations", SaveAfterAssociations) - defaultCallbacks.Update().Register("gorm:after_update", AfterUpdate) - defaultCallbacks.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + defaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes) + defaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction) + defaultCallback.Update().Register("gorm:before_update", BeforeUpdate) + defaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations) + defaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate) + defaultCallback.Update().Register("gorm:update", Update) + defaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations) + defaultCallback.Update().Register("gorm:after_update", AfterUpdate) + defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) } diff --git a/main.go b/main.go index bc5f4735..d580c5a5 100644 --- a/main.go +++ b/main.go @@ -23,7 +23,7 @@ type DB struct { Value interface{} Error error RowsAffected int64 - callbacks *Callbacks + callbacks *Callback db sqlCommon parent *DB search *search @@ -67,7 +67,7 @@ func Open(dialect string, args ...interface{}) (*DB, error) { db = DB{ dialect: NewDialect(dialect), logger: defaultLogger, - callbacks: defaultCallbacks, + callbacks: defaultCallback, source: source, values: map[string]interface{}{}, db: dbSql, @@ -111,7 +111,7 @@ func (s *DB) CommonDB() sqlCommon { return s.db } -func (s *DB) Callback() *Callbacks { +func (s *DB) Callback() *Callback { s.parent.callbacks = s.parent.callbacks.clone() return s.parent.callbacks }