Add tests for replace, remove callback

This commit is contained in:
Jinzhu 2014-01-25 20:04:01 +08:00
parent a8b13f015f
commit 192ed06ad8
3 changed files with 62 additions and 16 deletions

View File

@ -1,5 +1,7 @@
package gorm package gorm
import "fmt"
type callback struct { type callback struct {
creates []*func() creates []*func()
updates []*func() updates []*func()
@ -13,6 +15,7 @@ type callback_processor struct {
before string before string
after string after string
replace bool replace bool
remove bool
typ string typ string
processor *func() processor *func()
callback *callback callback *callback
@ -57,7 +60,9 @@ func (cp *callback_processor) Register(name string, fc func()) {
} }
func (cp *callback_processor) Remove(name string) { func (cp *callback_processor) Remove(name string) {
cp.Replace(name, func() {}) cp.name = name
cp.remove = true
cp.callback.sort()
} }
func (cp *callback_processor) Replace(name string, fc func()) { func (cp *callback_processor) Replace(name string, fc func()) {
@ -67,10 +72,10 @@ func (cp *callback_processor) Replace(name string, fc func()) {
cp.callback.sort() cp.callback.sort()
} }
func getIndex(strs []string, str string) int { func getRIndex(strs []string, str string) int {
for index, value := range strs { for i := len(strs) - 1; i >= 0; i-- {
if str == value { if strs[i] == str {
return index return i
} }
} }
return -1 return -1
@ -81,18 +86,27 @@ func sortProcessors(cps []*callback_processor) []*func() {
var names, sortedNames = []string{}, []string{} var names, sortedNames = []string{}, []string{}
for _, cp := range cps { for _, cp := range cps {
if index := getRIndex(names, cp.name); index > -1 {
if cp.replace {
fmt.Printf("[info] replacing callback `%v` from %v\n", cp.name, fileWithLineNum())
} else if cp.remove {
fmt.Printf("[info] removing callback `%v` from %v\n", cp.name, fileWithLineNum())
} else {
fmt.Println("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
}
}
names = append(names, cp.name) names = append(names, cp.name)
} }
sortCallbackProcessor = func(c *callback_processor, force bool) { sortCallbackProcessor = func(c *callback_processor, force bool) {
if getIndex(sortedNames, c.name) > -1 { if getRIndex(sortedNames, c.name) > -1 {
return return
} }
if len(c.before) > 0 { if len(c.before) > 0 {
if index := getIndex(sortedNames, c.before); index > -1 { if index := getRIndex(sortedNames, c.before); index > -1 {
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getIndex(names, c.before); index > -1 { } else if index := getRIndex(names, c.before); index > -1 {
sortedNames = append(sortedNames, c.name) sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index], true) sortCallbackProcessor(cps[index], true)
} else { } else {
@ -101,9 +115,9 @@ func sortProcessors(cps []*callback_processor) []*func() {
} }
if len(c.after) > 0 { if len(c.after) > 0 {
if index := getIndex(sortedNames, c.after); index > -1 { if index := getRIndex(sortedNames, c.after); index > -1 {
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getIndex(names, c.after); index > -1 { } else if index := getRIndex(names, c.after); index > -1 {
cp := cps[index] cp := cps[index]
if len(cp.before) == 0 { if len(cp.before) == 0 {
cp.before = c.name cp.before = c.name
@ -114,7 +128,7 @@ func sortProcessors(cps []*callback_processor) []*func() {
} }
} }
if getIndex(sortedNames, c.name) == -1 && force { if getRIndex(sortedNames, c.name) == -1 && force {
sortedNames = append(sortedNames, c.name) sortedNames = append(sortedNames, c.name)
} }
} }
@ -126,15 +140,19 @@ func sortProcessors(cps []*callback_processor) []*func() {
var funcs = []*func(){} var funcs = []*func(){}
var sortedFuncs = []*func(){} var sortedFuncs = []*func(){}
for _, name := range sortedNames { for _, name := range sortedNames {
index := getIndex(names, name) index := getRIndex(names, name)
if !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor) sortedFuncs = append(sortedFuncs, cps[index].processor)
} }
}
for _, cp := range cps { for _, cp := range cps {
if sindex := getIndex(sortedNames, cp.name); sindex == -1 { if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
if !cp.remove {
funcs = append(funcs, cp.processor) funcs = append(funcs, cp.processor)
} }
} }
}
return append(sortedFuncs, funcs...) return append(sortedFuncs, funcs...)
} }

View File

@ -75,3 +75,31 @@ func TestRegisterCallbackWithComplexOrder2(t *testing.T) {
t.Errorf("register callback with order") t.Errorf("register callback with order")
} }
} }
func replace_create() {}
func TestReplaceCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", before_create1)
callback.Create().Register("after_create1", after_create1)
callback.Create().Replace("create", replace_create)
if !equalFuncs(callback.creates, []string{"before_create1", "replace_create", "after_create1"}) {
t.Errorf("replace callback")
}
}
func TestRemoveCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", before_create1)
callback.Create().Register("after_create1", after_create1)
callback.Create().Remove("create")
if !equalFuncs(callback.creates, []string{"before_create1", "after_create1"}) {
t.Errorf("remove callback")
}
}

View File

@ -56,7 +56,7 @@ func (s *DB) hasError() bool {
} }
func fileWithLineNum() string { func fileWithLineNum() string {
for i := 5; i < 15; i++ { for i := 1; i < 15; i++ {
_, file, line, ok := runtime.Caller(i) _, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line) return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)