Refactor callbacks

This commit is contained in:
Jinzhu 2016-01-17 15:30:42 +08:00
parent 09f46f01b9
commit de73d30503
7 changed files with 59 additions and 48 deletions

View File

@ -23,7 +23,7 @@ type Callback struct {
processors []*CallbackProcessor processors []*CallbackProcessor
} }
// callbackProcessor contains all informations for a callback // CallbackProcessor contains all informations for a callback
type CallbackProcessor struct { type CallbackProcessor struct {
name string // current callback's name name string // current callback's name
before string // register current callback before a callback before string // register current callback before a callback
@ -79,7 +79,7 @@ func (c *Callback) Query() *CallbackProcessor {
return c.addProcessor("query") return c.addProcessor("query")
} }
// Query could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callback) RowQuery() *CallbackProcessor { func (c *Callback) RowQuery() *CallbackProcessor {
return c.addProcessor("row_query") return c.addProcessor("row_query")
} }
@ -125,6 +125,17 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
cp.parent.reorder() cp.parent.reorder()
} }
// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, processor := range cp.parent.processors {
if processor.name == callbackName && processor.kind == cp.kind && !cp.remove {
return *cp.processor
}
}
return nil
}
// getRIndex get right index from string slice // getRIndex get right index from string slice
func getRIndex(strs []string, str string) int { func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- { for i := len(strs) - 1; i >= 0; i-- {

View File

@ -5,12 +5,12 @@ import (
"strings" "strings"
) )
func BeforeCreate(scope *Scope) { func beforeCreateCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeSave") scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeCreate") scope.CallMethodWithErrorCheck("BeforeCreate")
} }
func UpdateTimeStampWhenCreate(scope *Scope) { func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
now := NowFunc() now := NowFunc()
scope.SetColumn("CreatedAt", now) scope.SetColumn("CreatedAt", now)
@ -18,7 +18,7 @@ func UpdateTimeStampWhenCreate(scope *Scope) {
} }
} }
func Create(scope *Scope) { func createCallback(scope *Scope) {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
if !scope.HasError() { if !scope.HasError() {
@ -102,25 +102,25 @@ func Create(scope *Scope) {
} }
} }
func ForceReloadAfterCreate(scope *Scope) { func forceReloadAfterCreateCallback(scope *Scope) {
if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok { if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
scope.DB().New().Select(columns.([]string)).First(scope.Value) scope.DB().New().Select(columns.([]string)).First(scope.Value)
} }
} }
func AfterCreate(scope *Scope) { func afterCreateCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterCreate") scope.CallMethodWithErrorCheck("AfterCreate")
scope.CallMethodWithErrorCheck("AfterSave") scope.CallMethodWithErrorCheck("AfterSave")
} }
func init() { func init() {
defaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction) defaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Create().Register("gorm:before_create", BeforeCreate) defaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
defaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations) defaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
defaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate) defaultCallback.Create().Register("gorm:update_time_stamp_when_create", updateTimeStampForCreateCallback)
defaultCallback.Create().Register("gorm:create", Create) defaultCallback.Create().Register("gorm:create", createCallback)
defaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate) defaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
defaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations) defaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
defaultCallback.Create().Register("gorm:after_create", AfterCreate) defaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) defaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
} }

View File

@ -2,11 +2,11 @@ package gorm
import "fmt" import "fmt"
func BeforeDelete(scope *Scope) { func beforeDeleteCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeDelete") scope.CallMethodWithErrorCheck("BeforeDelete")
} }
func Delete(scope *Scope) { func deleteCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
scope.Raw( scope.Raw(
@ -23,14 +23,14 @@ func Delete(scope *Scope) {
} }
} }
func AfterDelete(scope *Scope) { func afterDeleteCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterDelete") scope.CallMethodWithErrorCheck("AfterDelete")
} }
func init() { func init() {
defaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction) defaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Delete().Register("gorm:before_delete", BeforeDelete) defaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
defaultCallback.Delete().Register("gorm:delete", Delete) defaultCallback.Delete().Register("gorm:delete", deleteCallback)
defaultCallback.Delete().Register("gorm:after_delete", AfterDelete) defaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) defaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
} }

View File

@ -6,7 +6,7 @@ import (
"reflect" "reflect"
) )
func Query(scope *Scope) { func queryCallback(scope *Scope) {
defer scope.trace(NowFunc()) defer scope.trace(NowFunc())
var ( var (
@ -78,12 +78,12 @@ func Query(scope *Scope) {
} }
} }
func AfterQuery(scope *Scope) { func afterQueryCallback(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterFind") scope.CallMethodWithErrorCheck("AfterFind")
} }
func init() { func init() {
defaultCallback.Query().Register("gorm:query", Query) defaultCallback.Query().Register("gorm:query", queryCallback)
defaultCallback.Query().Register("gorm:after_query", AfterQuery) defaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
defaultCallback.Query().Register("gorm:preload", Preload) defaultCallback.Query().Register("gorm:preload", preloadCallback)
} }

View File

@ -7,7 +7,7 @@ import (
"strings" "strings"
) )
func Preload(scope *Scope) { func preloadCallback(scope *Scope) {
if scope.Search.preload == nil || scope.HasError() { if scope.Search.preload == nil || scope.HasError() {
return return
} }

View File

@ -2,15 +2,15 @@ package gorm
import "reflect" import "reflect"
func BeginTransaction(scope *Scope) { func beginTransactionCallback(scope *Scope) {
scope.Begin() scope.Begin()
} }
func CommitOrRollbackTransaction(scope *Scope) { func commitOrRollbackTransactionCallback(scope *Scope) {
scope.CommitOrRollback() scope.CommitOrRollback()
} }
func SaveBeforeAssociations(scope *Scope) { func saveBeforeAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() { if !scope.shouldSaveAssociations() {
return return
} }
@ -32,7 +32,7 @@ func SaveBeforeAssociations(scope *Scope) {
} }
} }
func SaveAfterAssociations(scope *Scope) { func saveAfterAssociationsCallback(scope *Scope) {
if !scope.shouldSaveAssociations() { if !scope.shouldSaveAssociations() {
return return
} }

View File

@ -5,7 +5,7 @@ import (
"strings" "strings"
) )
func AssignUpdateAttributes(scope *Scope) { func assignUpdateAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 { if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
protected, ok := scope.Get("gorm:ignore_protected_attrs") protected, ok := scope.Get("gorm:ignore_protected_attrs")
@ -24,20 +24,20 @@ func AssignUpdateAttributes(scope *Scope) {
} }
} }
func BeforeUpdate(scope *Scope) { func beforeUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("BeforeSave") scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeUpdate") scope.CallMethodWithErrorCheck("BeforeUpdate")
} }
} }
func UpdateTimeStampWhenUpdate(scope *Scope) { func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc()) scope.SetColumn("UpdatedAt", NowFunc())
} }
} }
func Update(scope *Scope) { func updateCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
var sqls []string var sqls []string
@ -75,7 +75,7 @@ func Update(scope *Scope) {
} }
} }
func AfterUpdate(scope *Scope) { func afterUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok { if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("AfterUpdate") scope.CallMethodWithErrorCheck("AfterUpdate")
scope.CallMethodWithErrorCheck("AfterSave") scope.CallMethodWithErrorCheck("AfterSave")
@ -83,13 +83,13 @@ func AfterUpdate(scope *Scope) {
} }
func init() { func init() {
defaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes) defaultCallback.Update().Register("gorm:assign_update_attributes", assignUpdateAttributesCallback)
defaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction) defaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
defaultCallback.Update().Register("gorm:before_update", BeforeUpdate) defaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
defaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations) defaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
defaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate) defaultCallback.Update().Register("gorm:update_time_stamp_when_update", updateTimeStampForUpdateCallback)
defaultCallback.Update().Register("gorm:update", Update) defaultCallback.Update().Register("gorm:update", updateCallback)
defaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations) defaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
defaultCallback.Update().Register("gorm:after_update", AfterUpdate) defaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) defaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
} }