mirror of https://github.com/go-gorm/gorm.git
Refactor Callback
This commit is contained in:
parent
dc23ae63bf
commit
f1237e4fe9
125
callback.go
125
callback.go
|
@ -4,34 +4,45 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type callback struct {
|
// defaultCallbacks hold default callbacks defined by gorm
|
||||||
|
var defaultCallbacks = &Callbacks{}
|
||||||
|
|
||||||
|
// Callbacks contains callbacks that used when CURD objects
|
||||||
|
// Field `creates` hold callbacks will be call when creating object
|
||||||
|
// Field `updates` hold callbacks will be call when updating object
|
||||||
|
// Field `deletes` hold callbacks will be call when deleting object
|
||||||
|
// Field `queries` hold callbacks will be call when querying object with query methods like Find, First, Related, Association...
|
||||||
|
// Field `rowQueries` hold callbacks will be call when querying object with Row, Rows...
|
||||||
|
// Field `processors` hold all callback processors, will be used to generate above callbacks in order
|
||||||
|
type Callbacks struct {
|
||||||
creates []*func(scope *Scope)
|
creates []*func(scope *Scope)
|
||||||
updates []*func(scope *Scope)
|
updates []*func(scope *Scope)
|
||||||
deletes []*func(scope *Scope)
|
deletes []*func(scope *Scope)
|
||||||
queries []*func(scope *Scope)
|
queries []*func(scope *Scope)
|
||||||
rowQueries []*func(scope *Scope)
|
rowQueries []*func(scope *Scope)
|
||||||
processors []*callbackProcessor
|
processors []*CallbackProcessor
|
||||||
}
|
}
|
||||||
|
|
||||||
type callbackProcessor struct {
|
// callbackProcessor contains all informations for a callback
|
||||||
name string
|
type CallbackProcessor struct {
|
||||||
before string
|
name string // current callback's name
|
||||||
after string
|
before string // register current callback before a callback
|
||||||
replace bool
|
after string // register current callback after a callback
|
||||||
remove bool
|
replace bool // replace callbacks with same name
|
||||||
typ string
|
remove bool // delete callbacks with same name
|
||||||
processor *func(scope *Scope)
|
kind string // callback type: create, update, delete, query, row_query
|
||||||
callback *callback
|
processor *func(scope *Scope) // callback handler
|
||||||
|
parent *Callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) addProcessor(typ string) *callbackProcessor {
|
func (c *Callbacks) addProcessor(kind string) *CallbackProcessor {
|
||||||
cp := &callbackProcessor{typ: typ, callback: c}
|
cp := &CallbackProcessor{kind: kind, parent: c}
|
||||||
c.processors = append(c.processors, cp)
|
c.processors = append(c.processors, cp)
|
||||||
return cp
|
return cp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) clone() *callback {
|
func (c *Callbacks) clone() *Callbacks {
|
||||||
return &callback{
|
return &Callbacks{
|
||||||
creates: c.creates,
|
creates: c.creates,
|
||||||
updates: c.updates,
|
updates: c.updates,
|
||||||
deletes: c.deletes,
|
deletes: c.deletes,
|
||||||
|
@ -40,57 +51,81 @@ func (c *callback) clone() *callback {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Create() *callbackProcessor {
|
// Create could be used to register callbacks for creating object
|
||||||
|
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
|
||||||
|
// // business logic
|
||||||
|
// ...
|
||||||
|
//
|
||||||
|
// // set error if some thing wrong happened, will rollback the creating
|
||||||
|
// scope.Err(errors.New("error"))
|
||||||
|
// })
|
||||||
|
func (c *Callbacks) Create() *CallbackProcessor {
|
||||||
return c.addProcessor("create")
|
return c.addProcessor("create")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Update() *callbackProcessor {
|
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
||||||
|
func (c *Callbacks) Update() *CallbackProcessor {
|
||||||
return c.addProcessor("update")
|
return c.addProcessor("update")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Delete() *callbackProcessor {
|
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
||||||
|
func (c *Callbacks) Delete() *CallbackProcessor {
|
||||||
return c.addProcessor("delete")
|
return c.addProcessor("delete")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) Query() *callbackProcessor {
|
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
||||||
|
// refer `Create` for usage
|
||||||
|
func (c *Callbacks) Query() *CallbackProcessor {
|
||||||
return c.addProcessor("query")
|
return c.addProcessor("query")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) RowQuery() *callbackProcessor {
|
// Query could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
||||||
|
func (c *Callbacks) RowQuery() *CallbackProcessor {
|
||||||
return c.addProcessor("row_query")
|
return c.addProcessor("row_query")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
|
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
||||||
cp.before = name
|
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
|
||||||
|
cp.after = callbackName
|
||||||
return cp
|
return cp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) After(name string) *callbackProcessor {
|
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
|
||||||
cp.after = name
|
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||||
|
cp.before = callbackName
|
||||||
return cp
|
return cp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
|
// Register a new callback, refer `Callbacks.Create`
|
||||||
cp.name = name
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||||
cp.processor = &fc
|
cp.name = callbackName
|
||||||
cp.callback.sort()
|
cp.processor = &callback
|
||||||
|
cp.parent.reorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) Remove(name string) {
|
// Remove a registered callback
|
||||||
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
|
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||||
cp.name = name
|
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||||
|
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||||
|
cp.name = callbackName
|
||||||
cp.remove = true
|
cp.remove = true
|
||||||
cp.callback.sort()
|
cp.parent.reorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
|
// Replace a registered callback with new callback
|
||||||
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
|
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||||
cp.name = name
|
// scope.SetColumn("Created", now)
|
||||||
cp.processor = &fc
|
// scope.SetColumn("Updated", now)
|
||||||
|
// })
|
||||||
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
|
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||||
|
cp.name = callbackName
|
||||||
|
cp.processor = &callback
|
||||||
cp.replace = true
|
cp.replace = true
|
||||||
cp.callback.sort()
|
cp.parent.reorder()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRIndex get right index from string slice
|
||||||
func getRIndex(strs []string, str string) int {
|
func getRIndex(strs []string, str string) int {
|
||||||
for i := len(strs) - 1; i >= 0; i-- {
|
for i := len(strs) - 1; i >= 0; i-- {
|
||||||
if strs[i] == str {
|
if strs[i] == str {
|
||||||
|
@ -100,8 +135,9 @@ func getRIndex(strs []string, str string) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
|
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||||
var sortCallbackProcessor func(c *callbackProcessor)
|
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||||
|
var sortCallbackProcessor func(c *CallbackProcessor)
|
||||||
var names, sortedNames = []string{}, []string{}
|
var names, sortedNames = []string{}, []string{}
|
||||||
|
|
||||||
for _, cp := range cps {
|
for _, cp := range cps {
|
||||||
|
@ -113,7 +149,7 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
|
||||||
names = append(names, cp.name)
|
names = append(names, cp.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
sortCallbackProcessor = func(c *callbackProcessor) {
|
sortCallbackProcessor = func(c *CallbackProcessor) {
|
||||||
if getRIndex(sortedNames, c.name) > -1 {
|
if getRIndex(sortedNames, c.name) > -1 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -172,11 +208,12 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
|
||||||
return append(sortedFuncs, funcs...)
|
return append(sortedFuncs, funcs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *callback) sort() {
|
// reorder all registered processors, and reset CURD callbacks
|
||||||
var creates, updates, deletes, queries, rowQueries []*callbackProcessor
|
func (c *Callbacks) reorder() {
|
||||||
|
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
|
||||||
|
|
||||||
for _, processor := range c.processors {
|
for _, processor := range c.processors {
|
||||||
switch processor.typ {
|
switch processor.kind {
|
||||||
case "create":
|
case "create":
|
||||||
creates = append(creates, processor)
|
creates = append(creates, processor)
|
||||||
case "update":
|
case "update":
|
||||||
|
@ -196,5 +233,3 @@ func (c *callback) sort() {
|
||||||
c.queries = sortProcessors(queries)
|
c.queries = sortProcessors(queries)
|
||||||
c.rowQueries = sortProcessors(rowQueries)
|
c.rowQueries = sortProcessors(rowQueries)
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
|
|
||||||
|
|
|
@ -114,13 +114,13 @@ func AfterCreate(scope *Scope) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
|
defaultCallbacks.Create().Register("gorm:begin_transaction", BeginTransaction)
|
||||||
DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
|
defaultCallbacks.Create().Register("gorm:before_create", BeforeCreate)
|
||||||
DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
defaultCallbacks.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||||
DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
|
defaultCallbacks.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
|
||||||
DefaultCallback.Create().Register("gorm:create", Create)
|
defaultCallbacks.Create().Register("gorm:create", Create)
|
||||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
|
defaultCallbacks.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
|
||||||
DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
|
defaultCallbacks.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||||
DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
|
defaultCallbacks.Create().Register("gorm:after_create", AfterCreate)
|
||||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
defaultCallbacks.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,9 +28,9 @@ func AfterDelete(scope *Scope) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
|
defaultCallbacks.Delete().Register("gorm:begin_transaction", BeginTransaction)
|
||||||
DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
|
defaultCallbacks.Delete().Register("gorm:before_delete", BeforeDelete)
|
||||||
DefaultCallback.Delete().Register("gorm:delete", Delete)
|
defaultCallbacks.Delete().Register("gorm:delete", Delete)
|
||||||
DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
|
defaultCallbacks.Delete().Register("gorm:after_delete", AfterDelete)
|
||||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
defaultCallbacks.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,7 +83,7 @@ func AfterQuery(scope *Scope) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Query().Register("gorm:query", Query)
|
defaultCallbacks.Query().Register("gorm:query", Query)
|
||||||
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
|
defaultCallbacks.Query().Register("gorm:after_query", AfterQuery)
|
||||||
DefaultCallback.Query().Register("gorm:preload", Preload)
|
defaultCallbacks.Query().Register("gorm:preload", Preload)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,62 +23,62 @@ func afterCreate1(s *Scope) {}
|
||||||
func afterCreate2(s *Scope) {}
|
func afterCreate2(s *Scope) {}
|
||||||
|
|
||||||
func TestRegisterCallback(t *testing.T) {
|
func TestRegisterCallback(t *testing.T) {
|
||||||
var callback = &callback{processors: []*callbackProcessor{}}
|
var callbacks = &Callbacks{}
|
||||||
|
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callbacks.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("before_create2", beforeCreate2)
|
callbacks.Create().Register("before_create2", beforeCreate2)
|
||||||
callback.Create().Register("create", create)
|
callbacks.Create().Register("create", create)
|
||||||
callback.Create().Register("after_create1", afterCreate1)
|
callbacks.Create().Register("after_create1", afterCreate1)
|
||||||
callback.Create().Register("after_create2", afterCreate2)
|
callbacks.Create().Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||||
t.Errorf("register callback")
|
t.Errorf("register callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithOrder(t *testing.T) {
|
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||||
var callback1 = &callback{processors: []*callbackProcessor{}}
|
var callbacks1 = &Callbacks{}
|
||||||
callback1.Create().Register("before_create1", beforeCreate1)
|
callbacks1.Create().Register("before_create1", beforeCreate1)
|
||||||
callback1.Create().Register("create", create)
|
callbacks1.Create().Register("create", create)
|
||||||
callback1.Create().Register("after_create1", afterCreate1)
|
callbacks1.Create().Register("after_create1", afterCreate1)
|
||||||
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
|
callbacks1.Create().Before("after_create1").Register("after_create2", afterCreate2)
|
||||||
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
if !equalFuncs(callbacks1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &callback{processors: []*callbackProcessor{}}
|
var callbacks2 = &Callbacks{}
|
||||||
|
|
||||||
callback2.Update().Register("create", create)
|
callbacks2.Update().Register("create", create)
|
||||||
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
callbacks2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||||
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
|
callbacks2.Update().After("after_create2").Register("after_create1", afterCreate1)
|
||||||
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
|
callbacks2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
|
||||||
callback2.Update().Register("after_create2", afterCreate2)
|
callbacks2.Update().Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
if !equalFuncs(callbacks2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
var callback1 = &callback{processors: []*callbackProcessor{}}
|
var callbacks1 = &Callbacks{}
|
||||||
|
|
||||||
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
callbacks1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback1.Query().Register("before_create1", beforeCreate1)
|
callbacks1.Query().Register("before_create1", beforeCreate1)
|
||||||
callback1.Query().Register("after_create1", afterCreate1)
|
callbacks1.Query().Register("after_create1", afterCreate1)
|
||||||
|
|
||||||
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
|
if !equalFuncs(callbacks1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
|
|
||||||
var callback2 = &callback{processors: []*callbackProcessor{}}
|
var callbacks2 = &Callbacks{}
|
||||||
|
|
||||||
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
callbacks2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
callbacks2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||||
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
callbacks2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
||||||
callback2.Delete().Register("after_create1", afterCreate1)
|
callbacks2.Delete().Register("after_create1", afterCreate1)
|
||||||
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
callbacks2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
if !equalFuncs(callbacks2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||||
t.Errorf("register callback with order")
|
t.Errorf("register callback with order")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -86,27 +86,27 @@ func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
func replaceCreate(s *Scope) {}
|
func replaceCreate(s *Scope) {}
|
||||||
|
|
||||||
func TestReplaceCallback(t *testing.T) {
|
func TestReplaceCallback(t *testing.T) {
|
||||||
var callback = &callback{processors: []*callbackProcessor{}}
|
var callbacks = &Callbacks{}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callbacks.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callbacks.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("after_create1", afterCreate1)
|
callbacks.Create().Register("after_create1", afterCreate1)
|
||||||
callback.Create().Replace("create", replaceCreate)
|
callbacks.Create().Replace("create", replaceCreate)
|
||||||
|
|
||||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
|
if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
|
||||||
t.Errorf("replace callback")
|
t.Errorf("replace callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveCallback(t *testing.T) {
|
func TestRemoveCallback(t *testing.T) {
|
||||||
var callback = &callback{processors: []*callbackProcessor{}}
|
var callbacks = &Callbacks{}
|
||||||
|
|
||||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
callbacks.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
callback.Create().Register("before_create1", beforeCreate1)
|
callbacks.Create().Register("before_create1", beforeCreate1)
|
||||||
callback.Create().Register("after_create1", afterCreate1)
|
callbacks.Create().Register("after_create1", afterCreate1)
|
||||||
callback.Create().Remove("create")
|
callbacks.Create().Remove("create")
|
||||||
|
|
||||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
|
if !equalFuncs(callbacks.creates, []string{"beforeCreate1", "afterCreate1"}) {
|
||||||
t.Errorf("remove callback")
|
t.Errorf("remove callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,13 +83,13 @@ func AfterUpdate(scope *Scope) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
|
defaultCallbacks.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
|
||||||
DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
|
defaultCallbacks.Update().Register("gorm:begin_transaction", BeginTransaction)
|
||||||
DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
|
defaultCallbacks.Update().Register("gorm:before_update", BeforeUpdate)
|
||||||
DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
defaultCallbacks.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
|
||||||
DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
|
defaultCallbacks.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
|
||||||
DefaultCallback.Update().Register("gorm:update", Update)
|
defaultCallbacks.Update().Register("gorm:update", Update)
|
||||||
DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
|
defaultCallbacks.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
|
||||||
DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
|
defaultCallbacks.Update().Register("gorm:after_update", AfterUpdate)
|
||||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
defaultCallbacks.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||||
}
|
}
|
||||||
|
|
44
main.go
44
main.go
|
@ -23,7 +23,7 @@ type DB struct {
|
||||||
Value interface{}
|
Value interface{}
|
||||||
Error error
|
Error error
|
||||||
RowsAffected int64
|
RowsAffected int64
|
||||||
callback *callback
|
callbacks *Callbacks
|
||||||
db sqlCommon
|
db sqlCommon
|
||||||
parent *DB
|
parent *DB
|
||||||
search *search
|
search *search
|
||||||
|
@ -65,12 +65,12 @@ func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
db = DB{
|
db = DB{
|
||||||
dialect: NewDialect(dialect),
|
dialect: NewDialect(dialect),
|
||||||
logger: defaultLogger,
|
logger: defaultLogger,
|
||||||
callback: DefaultCallback,
|
callbacks: defaultCallbacks,
|
||||||
source: source,
|
source: source,
|
||||||
values: map[string]interface{}{},
|
values: map[string]interface{}{},
|
||||||
db: dbSql,
|
db: dbSql,
|
||||||
}
|
}
|
||||||
db.parent = &db
|
db.parent = &db
|
||||||
|
|
||||||
|
@ -111,9 +111,9 @@ func (s *DB) CommonDB() sqlCommon {
|
||||||
return s.db
|
return s.db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Callback() *callback {
|
func (s *DB) Callback() *Callbacks {
|
||||||
s.parent.callback = s.parent.callback.clone()
|
s.parent.callbacks = s.parent.callbacks.clone()
|
||||||
return s.parent.callback
|
return s.parent.callbacks
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) SetLogger(l logger) {
|
func (s *DB) SetLogger(l logger) {
|
||||||
|
@ -201,22 +201,22 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.clone().NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.clone().NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||||
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Scan(dest interface{}) *DB {
|
func (s *DB) Scan(dest interface{}) *DB {
|
||||||
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
|
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Row() *sql.Row {
|
func (s *DB) Row() *sql.Row {
|
||||||
|
@ -258,9 +258,9 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||||
if !result.RecordNotFound() {
|
if !result.RecordNotFound() {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callback.creates).db.Error)
|
c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error)
|
||||||
} else if len(c.search.assignAttrs) > 0 {
|
} else if len(c.search.assignAttrs) > 0 {
|
||||||
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callback.updates).db.Error)
|
c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error)
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
@ -273,7 +273,7 @@ func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||||
return s.clone().NewScope(s.Value).
|
return s.clone().NewScope(s.Value).
|
||||||
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callback.updates).db
|
callCallbacks(s.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||||
|
@ -285,24 +285,24 @@ func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||||
Set("gorm:update_column", true).
|
Set("gorm:update_column", true).
|
||||||
Set("gorm:save_associations", false).
|
Set("gorm:save_associations", false).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callback.updates).db
|
callCallbacks(s.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Save(value interface{}) *DB {
|
func (s *DB) Save(value interface{}) *DB {
|
||||||
scope := s.clone().NewScope(value)
|
scope := s.clone().NewScope(value)
|
||||||
if scope.PrimaryKeyZero() {
|
if scope.PrimaryKeyZero() {
|
||||||
return scope.callCallbacks(s.parent.callback.creates).db
|
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||||
}
|
}
|
||||||
return scope.callCallbacks(s.parent.callback.updates).db
|
return scope.callCallbacks(s.parent.callbacks.updates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Create(value interface{}) *DB {
|
func (s *DB) Create(value interface{}) *DB {
|
||||||
scope := s.clone().NewScope(value)
|
scope := s.clone().NewScope(value)
|
||||||
return scope.callCallbacks(s.parent.callback.creates).db
|
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
||||||
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
|
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
||||||
|
|
|
@ -377,14 +377,14 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
||||||
|
|
||||||
func (scope *Scope) row() *sql.Row {
|
func (scope *Scope) row() *sql.Row {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
scope.callCallbacks(scope.db.parent.callback.rowQueries)
|
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||||
scope.prepareQuerySql()
|
scope.prepareQuerySql()
|
||||||
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
|
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||||
defer scope.trace(NowFunc())
|
defer scope.trace(NowFunc())
|
||||||
scope.callCallbacks(scope.db.parent.callback.rowQueries)
|
scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
|
||||||
scope.prepareQuerySql()
|
scope.prepareQuerySql()
|
||||||
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue