diff --git a/callbacks.go b/callbacks.go index d53e8049..a7f30612 100644 --- a/callbacks.go +++ b/callbacks.go @@ -2,26 +2,36 @@ package gorm import ( "fmt" - "log" "github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/utils" ) -// Callbacks gorm callbacks manager -type Callbacks struct { - creates []func(*DB) - queries []func(*DB) - updates []func(*DB) - deletes []func(*DB) - row []func(*DB) - raw []func(*DB) - db *DB - processors []*processor +func InitializeCallbacks() *callbacks { + return &callbacks{ + processors: map[string]*processor{ + "create": &processor{}, + "query": &processor{}, + "update": &processor{}, + "delete": &processor{}, + "row": &processor{}, + "raw": &processor{}, + }, + } +} + +// callbacks gorm callbacks manager +type callbacks struct { + processors map[string]*processor } type processor struct { - kind string + db *DB + fns []func(*DB) + callbacks []*callback +} + +type callback struct { name string before string after string @@ -29,79 +39,111 @@ type processor struct { replace bool match func(*DB) bool handler func(*DB) - callbacks *Callbacks + processor *processor } -func (cs *Callbacks) Create() *processor { - return &processor{callbacks: cs, kind: "create"} +func (cs *callbacks) Create() *processor { + return cs.processors["create"] } -func (cs *Callbacks) Query() *processor { - return &processor{callbacks: cs, kind: "query"} +func (cs *callbacks) Query() *processor { + return cs.processors["query"] } -func (cs *Callbacks) Update() *processor { - return &processor{callbacks: cs, kind: "update"} +func (cs *callbacks) Update() *processor { + return cs.processors["update"] } -func (cs *Callbacks) Delete() *processor { - return &processor{callbacks: cs, kind: "delete"} +func (cs *callbacks) Delete() *processor { + return cs.processors["delete"] } -func (cs *Callbacks) Row() *processor { - return &processor{callbacks: cs, kind: "row"} +func (cs *callbacks) Row() *processor { + return cs.processors["row"] } -func (cs *Callbacks) Raw() *processor { - return &processor{callbacks: cs, kind: "raw"} +func (cs *callbacks) Raw() *processor { + return cs.processors["raw"] } -func (p *processor) Before(name string) *processor { - p.before = name - return p -} - -func (p *processor) After(name string) *processor { - p.after = name - return p -} - -func (p *processor) Match(fc func(*DB) bool) *processor { - p.match = fc - return p +func (p *processor) Execute(db *DB) { + for _, f := range p.fns { + f(db) + } } func (p *processor) Get(name string) func(*DB) { - for i := len(p.callbacks.processors) - 1; i >= 0; i-- { - if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove { + for i := len(p.callbacks) - 1; i >= 0; i-- { + if v := p.callbacks[i]; v.name == name && !v.remove { return v.handler } } return nil } -func (p *processor) Register(name string, fn func(*DB)) { - p.name = name - p.handler = fn - p.callbacks.processors = append(p.callbacks.processors, p) - p.callbacks.compile(p.callbacks.db) +func (p *processor) Before(name string) *callback { + return &callback{before: name, processor: p} } -func (p *processor) Remove(name string) { - logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) - p.name = name - p.remove = true - p.callbacks.processors = append(p.callbacks.processors, p) - p.callbacks.compile(p.callbacks.db) +func (p *processor) After(name string) *callback { + return &callback{after: name, processor: p} } -func (p *processor) Replace(name string, fn func(*DB)) { - logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) - p.name = name - p.handler = fn - p.replace = true - p.callbacks.processors = append(p.callbacks.processors, p) - p.callbacks.compile(p.callbacks.db) +func (p *processor) Match(fc func(*DB) bool) *callback { + return &callback{match: fc, processor: p} +} + +func (p *processor) Register(name string, fn func(*DB)) error { + return (&callback{processor: p}).Register(name, fn) +} + +func (p *processor) Remove(name string) error { + return (&callback{processor: p}).Remove(name) +} + +func (p *processor) Replace(name string, fn func(*DB)) error { + return (&callback{processor: p}).Replace(name, fn) +} + +func (p *processor) compile(db *DB) (err error) { + if p.fns, err = sortCallbacks(p.callbacks); err != nil { + logger.Default.Error("Got error when compile callbacks, got %v", err) + } + return +} + +func (c *callback) Before(name string) *callback { + c.before = name + return c +} + +func (c *callback) After(name string) *callback { + c.after = name + return c +} + +func (c *callback) Register(name string, fn func(*DB)) error { + c.name = name + c.handler = fn + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile(c.processor.db) +} + +func (c *callback) Remove(name string) error { + logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.name = name + c.remove = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile(c.processor.db) +} + +func (c *callback) Replace(name string, fn func(*DB)) error { + logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.name = name + c.handler = fn + c.replace = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile(c.processor.db) } // getRIndex get right index from string slice @@ -114,98 +156,81 @@ func getRIndex(strs []string, str string) int { return -1 } -func sortProcessors(ps []*processor) []func(*DB) { +func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { var ( - allNames, sortedNames []string - sortProcessor func(*processor) error + names, sorted []string + sortCallback func(*callback) error ) - for _, p := range ps { + for _, c := range cs { // show warning message the callback name already exists - if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove { - log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum()) + if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { + logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) } - allNames = append(allNames, p.name) + names = append(names, c.name) } - sortProcessor = func(p *processor) error { - if getRIndex(sortedNames, p.name) == -1 { // if not sorted - if p.before != "" { // if defined before callback - if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 { - if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...) - } else if curIdx > sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before) - } - } else if idx := getRIndex(allNames, p.before); idx != -1 { - // if before callback exists - ps[idx].after = p.name + sortCallback = func(c *callback) error { + if c.before != "" { // if defined before callback + if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if before callback already sorted, append current callback just after it + sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) + } else if curIdx > sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) } + } else if idx := getRIndex(names, c.before); idx != -1 { + // if before callback exists + cs[idx].after = c.name } + } - if p.after != "" { // if defined after callback - if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 { + if c.after != "" { // if defined after callback + if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if after callback sorted, append current callback to last - sortedNames = append(sortedNames, p.name) - } else if idx := getRIndex(allNames, p.after); idx != -1 { - // if after callback exists but haven't sorted - // set after callback's before callback to current callback - if after := ps[idx]; after.before == "" { - after.before = p.name - sortProcessor(after) - } + sorted = append(sorted, c.name) + } else if curIdx < sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) + } + } else if idx := getRIndex(names, c.after); idx != -1 { + // if after callback exists but haven't sorted + // set after callback's before callback to current callback + after := cs[idx] + + if after.before == "" { + after.before = c.name + } + + if err := sortCallback(after); err != nil { + return err + } + + if err := sortCallback(c); err != nil { + return err } } + } - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, p.name) == -1 { - sortedNames = append(sortedNames, p.name) - } + // if current callback haven't been sorted, append it to last + if getRIndex(sorted, c.name) == -1 { + sorted = append(sorted, c.name) } return nil } - for _, p := range ps { - sortProcessor(p) - } - - var fns []func(*DB) - for _, name := range sortedNames { - if idx := getRIndex(allNames, name); !ps[idx].remove { - fns = append(fns, ps[idx].handler) + for _, c := range cs { + if err = sortCallback(c); err != nil { + return } } - return fns -} - -// compile processors -func (cs *Callbacks) compile(db *DB) { - processors := map[string][]*processor{} - for _, p := range cs.processors { - if p.name != "" { - if p.match == nil || p.match(db) { - processors[p.kind] = append(processors[p.kind], p) - } - } - } - - for name, ps := range processors { - switch name { - case "create": - cs.creates = sortProcessors(ps) - case "query": - cs.queries = sortProcessors(ps) - case "update": - cs.updates = sortProcessors(ps) - case "delete": - cs.deletes = sortProcessors(ps) - case "row": - cs.row = sortProcessors(ps) - case "raw": - cs.raw = sortProcessors(ps) - } - } + for _, name := range sorted { + if idx := getRIndex(names, name); !cs[idx].remove { + fns = append(fns, cs[idx].handler) + } + } + + return } diff --git a/callbacks_test.go b/callbacks_test.go deleted file mode 100644 index 547cdca1..00000000 --- a/callbacks_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "runtime" - "strings" - "testing" -) - -func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) { - var got []string - - for _, f := range funcs { - got = append(got, getFuncName(f)) - } - - return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) -} - -func getFuncName(fc func(*DB)) string { - fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".") - return fnames[len(fnames)-1] -} - -func c1(*DB) {} -func c2(*DB) {} -func c3(*DB) {} -func c4(*DB) {} -func c5(*DB) {} - -func TestCallbacks(t *testing.T) { - type callback struct { - name string - before string - after string - remove bool - replace bool - err error - match func(*DB) bool - h func(*DB) - } - - datas := []struct { - callbacks []callback - results []string - }{ - { - callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, - results: []string{"c1", "c2", "c3", "c4", "c5"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, - results: []string{"c1", "c2", "c3", "c5", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, - results: []string{"c1", "c2", "c3", "c5", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, - results: []string{"c1", "c2", "c3", "c5", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, - results: []string{"c1", "c5", "c2", "c3", "c4"}, - }, - { - callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, - results: []string{"c1", "c3", "c5", "c2", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, - results: []string{"c1", "c5", "c3", "c4"}, - }, - { - callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, - results: []string{"c1", "c4", "c3"}, - }, - } - - // func TestRegisterCallbackWithComplexOrder(t *testing.T) { - // var callback2 = &Callback{logger: defaultLogger} - - // callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) - // callback2.Delete().Before("create").Register("before_create1", beforeCreate1) - // callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - // callback2.Delete().Register("after_create1", afterCreate1) - // callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - - // if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - // t.Errorf("register callback with order") - // } - // } - - for idx, data := range datas { - callbacks := &Callbacks{} - - for _, c := range data.callbacks { - p := callbacks.Create() - - if c.name == "" { - c.name = getFuncName(c.h) - } - - if c.before != "" { - p = p.Before(c.before) - } - - if c.after != "" { - p = p.After(c.after) - } - - if c.match != nil { - p = p.Match(c.match) - } - - if c.remove { - p.Remove(c.name) - } else if c.replace { - p.Replace(c.name, c.h) - } else { - p.Register(c.name, c.h) - } - } - - if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok { - t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) - } - } -} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go new file mode 100644 index 00000000..878384a7 --- /dev/null +++ b/tests/callbacks_test.go @@ -0,0 +1,158 @@ +package gorm_test + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "testing" + + "github.com/jinzhu/gorm" +) + +func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) { + var ( + got []string + funcs = reflect.ValueOf(v).Elem().FieldByName("fns") + ) + + for i := 0; i < funcs.Len(); i++ { + got = append(got, getFuncName(funcs.Index(i))) + } + + return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) +} + +func getFuncName(fc interface{}) string { + reflectValue, ok := fc.(reflect.Value) + if !ok { + reflectValue = reflect.ValueOf(fc) + } + + fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".") + return fnames[len(fnames)-1] +} + +func c1(*gorm.DB) {} +func c2(*gorm.DB) {} +func c3(*gorm.DB) {} +func c4(*gorm.DB) {} +func c5(*gorm.DB) {} + +func TestCallbacks(t *testing.T) { + type callback struct { + name string + before string + after string + remove bool + replace bool + err string + match func(*gorm.DB) bool + h func(*gorm.DB) + } + + datas := []struct { + callbacks []callback + err string + results []string + }{ + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c4", "c5"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c5", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c3", "c1", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c3", "c1", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + err: "conflicting", + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, + results: []string{"c1", "c5", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, + results: []string{"c1", "c4", "c3"}, + }, + } + + for idx, data := range datas { + var err error + callbacks := gorm.InitializeCallbacks() + + for _, c := range data.callbacks { + var v interface{} = callbacks.Create() + callMethod := func(s interface{}, name string, args ...interface{}) { + var argValues []reflect.Value + for _, arg := range args { + argValues = append(argValues, reflect.ValueOf(arg)) + } + + results := reflect.ValueOf(s).MethodByName(name).Call(argValues) + if len(results) > 0 { + v = results[0].Interface() + } + } + + if c.name == "" { + c.name = getFuncName(c.h) + } + + if c.before != "" { + callMethod(v, "Before", c.before) + } + + if c.after != "" { + callMethod(v, "After", c.after) + } + + if c.match != nil { + callMethod(v, "Match", c.match) + } + + if c.remove { + callMethod(v, "Remove", c.name) + } else if c.replace { + callMethod(v, "Replace", c.name, c.h) + } else { + callMethod(v, "Register", c.name, c.h) + } + + if e, ok := v.(error); !ok || e != nil { + err = e + } + } + + if len(data.err) > 0 && err == nil { + t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err) + } else if len(data.err) == 0 && err != nil { + t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err) + } + + if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok { + t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) + } + } +}