Refactor callbacks

This commit is contained in:
Jinzhu 2016-01-17 15:30:42 +08:00
parent 09f46f01b9
commit de73d30503
7 changed files with 59 additions and 48 deletions

View File

@ -23,7 +23,7 @@ type Callback struct {
processors []*CallbackProcessor
}
// callbackProcessor contains all informations for a callback
// CallbackProcessor contains all informations for a callback
type CallbackProcessor struct {
name string // current callback's name
before string // register current callback before a callback
@ -79,7 +79,7 @@ func (c *Callback) Query() *CallbackProcessor {
return c.addProcessor("query")
}
// Query could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callback) RowQuery() *CallbackProcessor {
return c.addProcessor("row_query")
}
@ -125,6 +125,17 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
cp.parent.reorder()
}
// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, processor := range cp.parent.processors {
if processor.name == callbackName && processor.kind == cp.kind && !cp.remove {
return *cp.processor
}
}
return nil
}
// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {

View File

@ -5,12 +5,12 @@ import (
"strings"
)
func BeforeCreate(scope *Scope) {
func beforeCreateCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeCreate")
}
func UpdateTimeStampWhenCreate(scope *Scope) {
func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
scope.SetColumn("CreatedAt", now)
@ -18,7 +18,7 @@ func UpdateTimeStampWhenCreate(scope *Scope) {
}
}
func Create(scope *Scope) {
func createCallback(scope *Scope) {
defer scope.trace(NowFunc())
if !scope.HasError() {
@ -102,25 +102,25 @@ func Create(scope *Scope) {
}
}
func ForceReloadAfterCreate(scope *Scope) {
func forceReloadAfterCreateCallback(scope *Scope) {
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
scope.DB().New().Select(columns.([]string)).First(scope.Value)
}
}
func AfterCreate(scope *Scope) {
func afterCreateCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterCreate")
scope.CallMethodWithErrorCheck("AfterSave")
}
func init() {
defaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
defaultCallback.Create().Register("gorm:before_create", BeforeCreate)
defaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
defaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
defaultCallback.Create().Register("gorm:create", Create)
defaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
defaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
defaultCallback.Create().Register("gorm:after_create", AfterCreate)
defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
defaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
defaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
defaultCallback.Create().Register("gorm:update_time_stamp_when_create", updateTimeStampForCreateCallback)
defaultCallback.Create().Register("gorm:create", createCallback)
defaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
defaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
defaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}

View File

@ -2,11 +2,11 @@ package gorm
import "fmt"
func BeforeDelete(scope *Scope) {
func beforeDeleteCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeDelete")
}
func Delete(scope *Scope) {
func deleteCallback(scope *Scope) {
if !scope.HasError() {
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
scope.Raw(
@ -23,14 +23,14 @@ func Delete(scope *Scope) {
}
}
func AfterDelete(scope *Scope) {
func afterDeleteCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterDelete")
}
func init() {
defaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
defaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
defaultCallback.Delete().Register("gorm:delete", Delete)
defaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
defaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
defaultCallback.Delete().Register("gorm:delete", deleteCallback)
defaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}

View File

@ -6,7 +6,7 @@ import (
"reflect"
)
func Query(scope *Scope) {
func queryCallback(scope *Scope) {
defer scope.trace(NowFunc())
var (
@ -78,12 +78,12 @@ func Query(scope *Scope) {
}
}
func AfterQuery(scope *Scope) {
func afterQueryCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterFind")
}
func init() {
defaultCallback.Query().Register("gorm:query", Query)
defaultCallback.Query().Register("gorm:after_query", AfterQuery)
defaultCallback.Query().Register("gorm:preload", Preload)
defaultCallback.Query().Register("gorm:query", queryCallback)
defaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
defaultCallback.Query().Register("gorm:preload", preloadCallback)
}

View File

@ -7,7 +7,7 @@ import (
"strings"
)
func Preload(scope *Scope) {
func preloadCallback(scope *Scope) {
if scope.Search.preload == nil || scope.HasError() {
return
}

View File

@ -2,15 +2,15 @@ package gorm
import "reflect"
func BeginTransaction(scope *Scope) {
func beginTransactionCallback(scope *Scope) {
scope.Begin()
}
func CommitOrRollbackTransaction(scope *Scope) {
func commitOrRollbackTransactionCallback(scope *Scope) {
scope.CommitOrRollback()
}
func SaveBeforeAssociations(scope *Scope) {
func saveBeforeAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) {
}
}
func SaveAfterAssociations(scope *Scope) {
func saveAfterAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}

View File

@ -5,7 +5,7 @@ import (
"strings"
)
func AssignUpdateAttributes(scope *Scope) {
func assignUpdateAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
protected, ok := scope.Get("gorm:ignore_protected_attrs")
@ -24,20 +24,20 @@ func AssignUpdateAttributes(scope *Scope) {
}
}
func BeforeUpdate(scope *Scope) {
func beforeUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeUpdate")
}
}
func UpdateTimeStampWhenUpdate(scope *Scope) {
func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
}
}
func Update(scope *Scope) {
func updateCallback(scope *Scope) {
if !scope.HasError() {
var sqls []string
@ -75,7 +75,7 @@ func Update(scope *Scope) {
}
}
func AfterUpdate(scope *Scope) {
func afterUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("AfterUpdate")
scope.CallMethodWithErrorCheck("AfterSave")
@ -83,13 +83,13 @@ func AfterUpdate(scope *Scope) {
}
func init() {
defaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
defaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
defaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
defaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
defaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
defaultCallback.Update().Register("gorm:update", Update)
defaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
defaultCallback.Update().Register("gorm:after_update", AfterUpdate)
defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
defaultCallback.Update().Register("gorm:assign_update_attributes", assignUpdateAttributesCallback)
defaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
defaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
defaultCallback.Update().Register("gorm:update_time_stamp_when_update", updateTimeStampForUpdateCallback)
defaultCallback.Update().Register("gorm:update", updateCallback)
defaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
defaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}