diff --git a/callback.go b/callback.go index 3329ba99..1cdeeffb 100644 --- a/callback.go +++ b/callback.go @@ -1,5 +1,7 @@ package gorm +import "fmt" + type callback struct { creates []*func() updates []*func() @@ -13,6 +15,7 @@ type callback_processor struct { before string after string replace bool + remove bool typ string processor *func() callback *callback @@ -57,7 +60,9 @@ func (cp *callback_processor) Register(name string, fc func()) { } func (cp *callback_processor) Remove(name string) { - cp.Replace(name, func() {}) + cp.name = name + cp.remove = true + cp.callback.sort() } func (cp *callback_processor) Replace(name string, fc func()) { @@ -67,10 +72,10 @@ func (cp *callback_processor) Replace(name string, fc func()) { cp.callback.sort() } -func getIndex(strs []string, str string) int { - for index, value := range strs { - if str == value { - return index +func getRIndex(strs []string, str string) int { + for i := len(strs) - 1; i >= 0; i-- { + if strs[i] == str { + return i } } return -1 @@ -81,18 +86,27 @@ func sortProcessors(cps []*callback_processor) []*func() { var names, sortedNames = []string{}, []string{} for _, cp := range cps { + if index := getRIndex(names, cp.name); index > -1 { + if cp.replace { + fmt.Printf("[info] replacing callback `%v` from %v\n", cp.name, fileWithLineNum()) + } else if cp.remove { + fmt.Printf("[info] removing callback `%v` from %v\n", cp.name, fileWithLineNum()) + } else { + fmt.Println("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + } + } names = append(names, cp.name) } sortCallbackProcessor = func(c *callback_processor, force bool) { - if getIndex(sortedNames, c.name) > -1 { + if getRIndex(sortedNames, c.name) > -1 { return } if len(c.before) > 0 { - if index := getIndex(sortedNames, c.before); index > -1 { + if index := getRIndex(sortedNames, c.before); index > -1 { sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getIndex(names, c.before); index > -1 { + } else if index := getRIndex(names, c.before); index > -1 { sortedNames = append(sortedNames, c.name) sortCallbackProcessor(cps[index], true) } else { @@ -101,9 +115,9 @@ func sortProcessors(cps []*callback_processor) []*func() { } if len(c.after) > 0 { - if index := getIndex(sortedNames, c.after); index > -1 { + if index := getRIndex(sortedNames, c.after); index > -1 { sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getIndex(names, c.after); index > -1 { + } else if index := getRIndex(names, c.after); index > -1 { cp := cps[index] if len(cp.before) == 0 { cp.before = c.name @@ -114,7 +128,7 @@ func sortProcessors(cps []*callback_processor) []*func() { } } - if getIndex(sortedNames, c.name) == -1 && force { + if getRIndex(sortedNames, c.name) == -1 && force { sortedNames = append(sortedNames, c.name) } } @@ -126,13 +140,17 @@ func sortProcessors(cps []*callback_processor) []*func() { var funcs = []*func(){} var sortedFuncs = []*func(){} for _, name := range sortedNames { - index := getIndex(names, name) - sortedFuncs = append(sortedFuncs, cps[index].processor) + index := getRIndex(names, name) + if !cps[index].remove { + sortedFuncs = append(sortedFuncs, cps[index].processor) + } } for _, cp := range cps { - if sindex := getIndex(sortedNames, cp.name); sindex == -1 { - funcs = append(funcs, cp.processor) + if sindex := getRIndex(sortedNames, cp.name); sindex == -1 { + if !cp.remove { + funcs = append(funcs, cp.processor) + } } } diff --git a/callback_test.go b/callback_test.go index 2749c945..2f4156d7 100644 --- a/callback_test.go +++ b/callback_test.go @@ -75,3 +75,31 @@ func TestRegisterCallbackWithComplexOrder2(t *testing.T) { t.Errorf("register callback with order") } } + +func replace_create() {} + +func TestReplaceCallback(t *testing.T) { + var callback = &callback{processors: []*callback_processor{}} + + callback.Create().Before("after_create1").After("before_create1").Register("create", create) + callback.Create().Register("before_create1", before_create1) + callback.Create().Register("after_create1", after_create1) + callback.Create().Replace("create", replace_create) + + if !equalFuncs(callback.creates, []string{"before_create1", "replace_create", "after_create1"}) { + t.Errorf("replace callback") + } +} + +func TestRemoveCallback(t *testing.T) { + var callback = &callback{processors: []*callback_processor{}} + + callback.Create().Before("after_create1").After("before_create1").Register("create", create) + callback.Create().Register("before_create1", before_create1) + callback.Create().Register("after_create1", after_create1) + callback.Create().Remove("create") + + if !equalFuncs(callback.creates, []string{"before_create1", "after_create1"}) { + t.Errorf("remove callback") + } +} diff --git a/private.go b/private.go index c107c371..7b72b5fd 100644 --- a/private.go +++ b/private.go @@ -56,7 +56,7 @@ func (s *DB) hasError() bool { } func fileWithLineNum() string { - for i := 5; i < 15; i++ { + for i := 1; i < 15; i++ { _, file, line, ok := runtime.Caller(i) if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)