diff --git a/callback.go b/callback.go index 55cd807e..6d0ff0f3 100644 --- a/callback.go +++ b/callback.go @@ -23,7 +23,7 @@ type Callback struct { processors []*CallbackProcessor } -// callbackProcessor contains all informations for a callback +// CallbackProcessor contains all informations for a callback type CallbackProcessor struct { name string // current callback's name before string // register current callback before a callback @@ -79,7 +79,7 @@ func (c *Callback) Query() *CallbackProcessor { return c.addProcessor("query") } -// Query could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage +// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage func (c *Callback) RowQuery() *CallbackProcessor { return c.addProcessor("row_query") } @@ -125,6 +125,17 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S cp.parent.reorder() } +// Get registered callback +// db.Callback().Create().Get("gorm:create") +func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { + for _, processor := range cp.parent.processors { + if processor.name == callbackName && processor.kind == cp.kind && !cp.remove { + return *cp.processor + } + } + return nil +} + // getRIndex get right index from string slice func getRIndex(strs []string, str string) int { for i := len(strs) - 1; i >= 0; i-- { diff --git a/callback_create.go b/callback_create.go index c52b9c85..3003b3a5 100644 --- a/callback_create.go +++ b/callback_create.go @@ -5,12 +5,12 @@ import ( "strings" ) -func BeforeCreate(scope *Scope) { +func beforeCreateCallback(scope *Scope) { scope.CallMethodWithErrorCheck("BeforeSave") scope.CallMethodWithErrorCheck("BeforeCreate") } -func UpdateTimeStampWhenCreate(scope *Scope) { +func updateTimeStampForCreateCallback(scope *Scope) { if !scope.HasError() { now := NowFunc() scope.SetColumn("CreatedAt", now) @@ -18,7 +18,7 @@ func UpdateTimeStampWhenCreate(scope *Scope) { } } -func Create(scope *Scope) { +func createCallback(scope *Scope) { defer scope.trace(NowFunc()) if !scope.HasError() { @@ -102,25 +102,25 @@ func Create(scope *Scope) { } } -func ForceReloadAfterCreate(scope *Scope) { +func forceReloadAfterCreateCallback(scope *Scope) { if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { scope.DB().New().Select(columns.([]string)).First(scope.Value) } } -func AfterCreate(scope *Scope) { +func afterCreateCallback(scope *Scope) { scope.CallMethodWithErrorCheck("AfterCreate") scope.CallMethodWithErrorCheck("AfterSave") } func init() { - defaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction) - defaultCallback.Create().Register("gorm:before_create", BeforeCreate) - defaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) - defaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) - defaultCallback.Create().Register("gorm:create", Create) - defaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) - defaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) - defaultCallback.Create().Register("gorm:after_create", AfterCreate) - defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + defaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) + defaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) + defaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) + defaultCallback.Create().Register("gorm:update_time_stamp_when_create", updateTimeStampForCreateCallback) + defaultCallback.Create().Register("gorm:create", createCallback) + defaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) + defaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) + defaultCallback.Create().Register("gorm:after_create", afterCreateCallback) + defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } diff --git a/callback_delete.go b/callback_delete.go index dca6ee21..7616cb87 100644 --- a/callback_delete.go +++ b/callback_delete.go @@ -2,11 +2,11 @@ package gorm import "fmt" -func BeforeDelete(scope *Scope) { +func beforeDeleteCallback(scope *Scope) { scope.CallMethodWithErrorCheck("BeforeDelete") } -func Delete(scope *Scope) { +func deleteCallback(scope *Scope) { if !scope.HasError() { if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { scope.Raw( @@ -23,14 +23,14 @@ func Delete(scope *Scope) { } } -func AfterDelete(scope *Scope) { +func afterDeleteCallback(scope *Scope) { scope.CallMethodWithErrorCheck("AfterDelete") } func init() { - defaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction) - defaultCallback.Delete().Register("gorm:before_delete", BeforeDelete) - defaultCallback.Delete().Register("gorm:delete", Delete) - defaultCallback.Delete().Register("gorm:after_delete", AfterDelete) - defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + defaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) + defaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) + defaultCallback.Delete().Register("gorm:delete", deleteCallback) + defaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) + defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) } diff --git a/callback_query.go b/callback_query.go index 05ac8880..707d83f5 100644 --- a/callback_query.go +++ b/callback_query.go @@ -6,7 +6,7 @@ import ( "reflect" ) -func Query(scope *Scope) { +func queryCallback(scope *Scope) { defer scope.trace(NowFunc()) var ( @@ -78,12 +78,12 @@ func Query(scope *Scope) { } } -func AfterQuery(scope *Scope) { +func afterQueryCallback(scope *Scope) { scope.CallMethodWithErrorCheck("AfterFind") } func init() { - defaultCallback.Query().Register("gorm:query", Query) - defaultCallback.Query().Register("gorm:after_query", AfterQuery) - defaultCallback.Query().Register("gorm:preload", Preload) + defaultCallback.Query().Register("gorm:query", queryCallback) + defaultCallback.Query().Register("gorm:after_query", afterQueryCallback) + defaultCallback.Query().Register("gorm:preload", preloadCallback) } diff --git a/preload.go b/callback_query_preload.go similarity index 99% rename from preload.go rename to callback_query_preload.go index 692280ef..5dc91de9 100644 --- a/preload.go +++ b/callback_query_preload.go @@ -7,7 +7,7 @@ import ( "strings" ) -func Preload(scope *Scope) { +func preloadCallback(scope *Scope) { if scope.Search.preload == nil || scope.HasError() { return } diff --git a/callback_shared.go b/callback_shared.go index 547059e3..a525b709 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -2,15 +2,15 @@ package gorm import "reflect" -func BeginTransaction(scope *Scope) { +func beginTransactionCallback(scope *Scope) { scope.Begin() } -func CommitOrRollbackTransaction(scope *Scope) { +func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func SaveBeforeAssociations(scope *Scope) { +func saveBeforeAssociationsCallback(scope *Scope) { if !scope.shouldSaveAssociations() { return } @@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) { } } -func SaveAfterAssociations(scope *Scope) { +func saveAfterAssociationsCallback(scope *Scope) { if !scope.shouldSaveAssociations() { return } diff --git a/callback_update.go b/callback_update.go index e7884450..f8ded58b 100644 --- a/callback_update.go +++ b/callback_update.go @@ -5,7 +5,7 @@ import ( "strings" ) -func AssignUpdateAttributes(scope *Scope) { +func assignUpdateAttributesCallback(scope *Scope) { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if maps := convertInterfaceToMap(attrs); len(maps) > 0 { protected, ok := scope.Get("gorm:ignore_protected_attrs") @@ -24,20 +24,20 @@ func AssignUpdateAttributes(scope *Scope) { } } -func BeforeUpdate(scope *Scope) { +func beforeUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { scope.CallMethodWithErrorCheck("BeforeSave") scope.CallMethodWithErrorCheck("BeforeUpdate") } } -func UpdateTimeStampWhenUpdate(scope *Scope) { +func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { scope.SetColumn("UpdatedAt", NowFunc()) } } -func Update(scope *Scope) { +func updateCallback(scope *Scope) { if !scope.HasError() { var sqls []string @@ -75,7 +75,7 @@ func Update(scope *Scope) { } } -func AfterUpdate(scope *Scope) { +func afterUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { scope.CallMethodWithErrorCheck("AfterUpdate") scope.CallMethodWithErrorCheck("AfterSave") @@ -83,13 +83,13 @@ func AfterUpdate(scope *Scope) { } func init() { - defaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes) - defaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction) - defaultCallback.Update().Register("gorm:before_update", BeforeUpdate) - defaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations) - defaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate) - defaultCallback.Update().Register("gorm:update", Update) - defaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations) - defaultCallback.Update().Register("gorm:after_update", AfterUpdate) - defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + defaultCallback.Update().Register("gorm:assign_update_attributes", assignUpdateAttributesCallback) + defaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) + defaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) + defaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) + defaultCallback.Update().Register("gorm:update_time_stamp_when_update", updateTimeStampForUpdateCallback) + defaultCallback.Update().Register("gorm:update", updateCallback) + defaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) + defaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) + defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) }