diff --git a/callback.go b/callback.go index 82eb2121..3329ba99 100644 --- a/callback.go +++ b/callback.go @@ -1,10 +1,10 @@ package gorm type callback struct { - create []func() - update []func() - delete []func() - query []func() + creates []*func() + updates []*func() + deletes []*func() + queries []*func() processors []*callback_processor } @@ -14,7 +14,7 @@ type callback_processor struct { after string replace bool typ string - processor func() + processor *func() callback *callback } @@ -40,7 +40,106 @@ func (c *callback) Query() *callback_processor { return c.addProcessor("query") } -func (c *callback) Sort() { +func (cp *callback_processor) Before(name string) *callback_processor { + cp.before = name + return cp +} + +func (cp *callback_processor) After(name string) *callback_processor { + cp.after = name + return cp +} + +func (cp *callback_processor) Register(name string, fc func()) { + cp.name = name + cp.processor = &fc + cp.callback.sort() +} + +func (cp *callback_processor) Remove(name string) { + cp.Replace(name, func() {}) +} + +func (cp *callback_processor) Replace(name string, fc func()) { + cp.name = name + cp.processor = &fc + cp.replace = true + cp.callback.sort() +} + +func getIndex(strs []string, str string) int { + for index, value := range strs { + if str == value { + return index + } + } + return -1 +} + +func sortProcessors(cps []*callback_processor) []*func() { + var sortCallbackProcessor func(c *callback_processor, force bool) + var names, sortedNames = []string{}, []string{} + + for _, cp := range cps { + names = append(names, cp.name) + } + + sortCallbackProcessor = func(c *callback_processor, force bool) { + if getIndex(sortedNames, c.name) > -1 { + return + } + + if len(c.before) > 0 { + if index := getIndex(sortedNames, c.before); index > -1 { + sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) + } else if index := getIndex(names, c.before); index > -1 { + sortedNames = append(sortedNames, c.name) + sortCallbackProcessor(cps[index], true) + } else { + sortedNames = append(sortedNames, c.name) + } + } + + if len(c.after) > 0 { + if index := getIndex(sortedNames, c.after); index > -1 { + sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) + } else if index := getIndex(names, c.after); index > -1 { + cp := cps[index] + if len(cp.before) == 0 { + cp.before = c.name + } + sortCallbackProcessor(cp, true) + } else { + sortedNames = append(sortedNames, c.name) + } + } + + if getIndex(sortedNames, c.name) == -1 && force { + sortedNames = append(sortedNames, c.name) + } + } + + for _, cp := range cps { + sortCallbackProcessor(cp, false) + } + + var funcs = []*func(){} + var sortedFuncs = []*func(){} + for _, name := range sortedNames { + index := getIndex(names, name) + sortedFuncs = append(sortedFuncs, cps[index].processor) + } + + for _, cp := range cps { + if sindex := getIndex(sortedNames, cp.name); sindex == -1 { + funcs = append(funcs, cp.processor) + } + } + + return append(sortedFuncs, funcs...) +} + +func (c *callback) sort() { creates, updates, deletes, queries := []*callback_processor{}, []*callback_processor{}, []*callback_processor{}, []*callback_processor{} for _, processor := range c.processors { @@ -55,33 +154,11 @@ func (c *callback) Sort() { queries = append(queries, processor) } } -} -func (cp *callback_processor) Before(name string) *callback_processor { - cp.before = name - return cp -} - -func (cp *callback_processor) After(name string) *callback_processor { - cp.after = name - return cp -} - -func (cp *callback_processor) Register(name string, fc func()) { - cp.name = name - cp.processor = fc - cp.callback.Sort() -} - -func (cp *callback_processor) Remove(name string) { - cp.Replace(name, func() {}) -} - -func (cp *callback_processor) Replace(name string, fc func()) { - cp.name = name - cp.processor = fc - cp.replace = true - cp.callback.Sort() + c.creates = sortProcessors(creates) + c.updates = sortProcessors(updates) + c.deletes = sortProcessors(deletes) + c.queries = sortProcessors(queries) } var DefaultCallback = &callback{processors: []*callback_processor{}} diff --git a/callback_test.go b/callback_test.go new file mode 100644 index 00000000..2749c945 --- /dev/null +++ b/callback_test.go @@ -0,0 +1,77 @@ +package gorm + +import ( + "reflect" + "runtime" + "strings" + "testing" +) + +func equalFuncs(funcs []*func(), fnames []string) bool { + var names []string + for _, f := range funcs { + fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") + names = append(names, fnames[len(fnames)-1]) + } + return reflect.DeepEqual(names, fnames) +} + +func create() {} +func before_create1() {} +func before_create2() {} +func after_create1() {} +func after_create2() {} + +func TestRegisterCallback(t *testing.T) { + var callback = &callback{processors: []*callback_processor{}} + + callback.Create().Register("before_create1", before_create1) + callback.Create().Register("before_create2", before_create2) + callback.Create().Register("create", create) + callback.Create().Register("after_create1", after_create1) + callback.Create().Register("after_create2", after_create2) + + if !equalFuncs(callback.creates, []string{"before_create1", "before_create2", "create", "after_create1", "after_create2"}) { + t.Errorf("register callback") + } +} + +func TestRegisterCallbackWithBasicOrder(t *testing.T) { + var callback = &callback{processors: []*callback_processor{}} + + callback.Update().Register("create", create) + callback.Update().Before("create").Register("before_create1", before_create1) + callback.Update().After("after_create2").Register("after_create1", after_create1) + callback.Update().Before("before_create1").Register("before_create2", before_create2) + callback.Update().Register("after_create2", after_create2) + + if !equalFuncs(callback.updates, []string{"before_create2", "before_create1", "create", "after_create2", "after_create1"}) { + t.Errorf("register callback with order") + } +} + +func TestRegisterCallbackWithComplexOrder1(t *testing.T) { + var callback = &callback{processors: []*callback_processor{}} + + callback.Query().Before("after_create1").After("before_create1").Register("create", create) + callback.Query().Register("before_create1", before_create1) + callback.Query().Register("after_create1", after_create1) + + if !equalFuncs(callback.queries, []string{"before_create1", "create", "after_create1"}) { + t.Errorf("register callback with order") + } +} + +func TestRegisterCallbackWithComplexOrder2(t *testing.T) { + var callback = &callback{processors: []*callback_processor{}} + + callback.Delete().Before("after_create1").After("before_create1").Register("create", create) + callback.Delete().Before("create").Register("before_create1", before_create1) + callback.Delete().After("before_create1").Register("before_create2", before_create2) + callback.Delete().Register("after_create1", after_create1) + callback.Delete().After("after_create1").Register("after_create2", after_create2) + + if !equalFuncs(callback.deletes, []string{"before_create1", "before_create2", "create", "after_create1", "after_create2"}) { + t.Errorf("register callback with order") + } +} diff --git a/gorm_test.go b/gorm_test.go index c2557c9b..e3eba586 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -92,8 +92,6 @@ var ( ) func init() { - db.Debug().Model(User{}).RemoveIndex("name") - var err error switch os.Getenv("GORM_DIALECT") { case "mysql":