Refactor callbacks

This commit is contained in:
Jinzhu 2020-01-31 08:29:35 +08:00
parent e509b3100d
commit 5959c81be6
3 changed files with 313 additions and 261 deletions

View File

@ -2,26 +2,36 @@ package gorm
import ( import (
"fmt" "fmt"
"log"
"github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/logger"
"github.com/jinzhu/gorm/utils" "github.com/jinzhu/gorm/utils"
) )
// Callbacks gorm callbacks manager func InitializeCallbacks() *callbacks {
type Callbacks struct { return &callbacks{
creates []func(*DB) processors: map[string]*processor{
queries []func(*DB) "create": &processor{},
updates []func(*DB) "query": &processor{},
deletes []func(*DB) "update": &processor{},
row []func(*DB) "delete": &processor{},
raw []func(*DB) "row": &processor{},
db *DB "raw": &processor{},
processors []*processor },
}
}
// callbacks gorm callbacks manager
type callbacks struct {
processors map[string]*processor
} }
type processor struct { type processor struct {
kind string db *DB
fns []func(*DB)
callbacks []*callback
}
type callback struct {
name string name string
before string before string
after string after string
@ -29,79 +39,111 @@ type processor struct {
replace bool replace bool
match func(*DB) bool match func(*DB) bool
handler func(*DB) handler func(*DB)
callbacks *Callbacks processor *processor
} }
func (cs *Callbacks) Create() *processor { func (cs *callbacks) Create() *processor {
return &processor{callbacks: cs, kind: "create"} return cs.processors["create"]
} }
func (cs *Callbacks) Query() *processor { func (cs *callbacks) Query() *processor {
return &processor{callbacks: cs, kind: "query"} return cs.processors["query"]
} }
func (cs *Callbacks) Update() *processor { func (cs *callbacks) Update() *processor {
return &processor{callbacks: cs, kind: "update"} return cs.processors["update"]
} }
func (cs *Callbacks) Delete() *processor { func (cs *callbacks) Delete() *processor {
return &processor{callbacks: cs, kind: "delete"} return cs.processors["delete"]
} }
func (cs *Callbacks) Row() *processor { func (cs *callbacks) Row() *processor {
return &processor{callbacks: cs, kind: "row"} return cs.processors["row"]
} }
func (cs *Callbacks) Raw() *processor { func (cs *callbacks) Raw() *processor {
return &processor{callbacks: cs, kind: "raw"} return cs.processors["raw"]
} }
func (p *processor) Before(name string) *processor { func (p *processor) Execute(db *DB) {
p.before = name for _, f := range p.fns {
return p f(db)
} }
func (p *processor) After(name string) *processor {
p.after = name
return p
}
func (p *processor) Match(fc func(*DB) bool) *processor {
p.match = fc
return p
} }
func (p *processor) Get(name string) func(*DB) { func (p *processor) Get(name string) func(*DB) {
for i := len(p.callbacks.processors) - 1; i >= 0; i-- { for i := len(p.callbacks) - 1; i >= 0; i-- {
if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove { if v := p.callbacks[i]; v.name == name && !v.remove {
return v.handler return v.handler
} }
} }
return nil return nil
} }
func (p *processor) Register(name string, fn func(*DB)) { func (p *processor) Before(name string) *callback {
p.name = name return &callback{before: name, processor: p}
p.handler = fn
p.callbacks.processors = append(p.callbacks.processors, p)
p.callbacks.compile(p.callbacks.db)
} }
func (p *processor) Remove(name string) { func (p *processor) After(name string) *callback {
logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum()) return &callback{after: name, processor: p}
p.name = name
p.remove = true
p.callbacks.processors = append(p.callbacks.processors, p)
p.callbacks.compile(p.callbacks.db)
} }
func (p *processor) Replace(name string, fn func(*DB)) { func (p *processor) Match(fc func(*DB) bool) *callback {
logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) return &callback{match: fc, processor: p}
p.name = name }
p.handler = fn
p.replace = true func (p *processor) Register(name string, fn func(*DB)) error {
p.callbacks.processors = append(p.callbacks.processors, p) return (&callback{processor: p}).Register(name, fn)
p.callbacks.compile(p.callbacks.db) }
func (p *processor) Remove(name string) error {
return (&callback{processor: p}).Remove(name)
}
func (p *processor) Replace(name string, fn func(*DB)) error {
return (&callback{processor: p}).Replace(name, fn)
}
func (p *processor) compile(db *DB) (err error) {
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
logger.Default.Error("Got error when compile callbacks, got %v", err)
}
return
}
func (c *callback) Before(name string) *callback {
c.before = name
return c
}
func (c *callback) After(name string) *callback {
c.after = name
return c
}
func (c *callback) Register(name string, fn func(*DB)) error {
c.name = name
c.handler = fn
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile(c.processor.db)
}
func (c *callback) Remove(name string) error {
logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile(c.processor.db)
}
func (c *callback) Replace(name string, fn func(*DB)) error {
logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.handler = fn
c.replace = true
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile(c.processor.db)
} }
// getRIndex get right index from string slice // getRIndex get right index from string slice
@ -114,98 +156,81 @@ func getRIndex(strs []string, str string) int {
return -1 return -1
} }
func sortProcessors(ps []*processor) []func(*DB) { func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
var ( var (
allNames, sortedNames []string names, sorted []string
sortProcessor func(*processor) error sortCallback func(*callback) error
) )
for _, p := range ps { for _, c := range cs {
// show warning message the callback name already exists // show warning message the callback name already exists
if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove { if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum()) logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
} }
allNames = append(allNames, p.name) names = append(names, c.name)
} }
sortProcessor = func(p *processor) error { sortCallback = func(c *callback) error {
if getRIndex(sortedNames, p.name) == -1 { // if not sorted if c.before != "" { // if defined before callback
if p.before != "" { // if defined before callback if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true {
// if before callback already sorted, append current callback just after it // if before callback already sorted, append current callback just after it
sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...) sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx { } else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before) return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
} }
} else if idx := getRIndex(allNames, p.before); idx != -1 { } else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists // if before callback exists
ps[idx].after = p.name cs[idx].after = c.name
} }
} }
if p.after != "" { // if defined after callback if c.after != "" { // if defined after callback
if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 { if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if after callback sorted, append current callback to last // if after callback sorted, append current callback to last
sortedNames = append(sortedNames, p.name) sorted = append(sorted, c.name)
} else if idx := getRIndex(allNames, p.after); idx != -1 { } else if curIdx < sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
}
} else if idx := getRIndex(names, c.after); idx != -1 {
// if after callback exists but haven't sorted // if after callback exists but haven't sorted
// set after callback's before callback to current callback // set after callback's before callback to current callback
if after := ps[idx]; after.before == "" { after := cs[idx]
after.before = p.name
sortProcessor(after) if after.before == "" {
after.before = c.name
}
if err := sortCallback(after); err != nil {
return err
}
if err := sortCallback(c); err != nil {
return err
} }
} }
} }
// if current callback haven't been sorted, append it to last // if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, p.name) == -1 { if getRIndex(sorted, c.name) == -1 {
sortedNames = append(sortedNames, p.name) sorted = append(sorted, c.name)
}
} }
return nil return nil
} }
for _, p := range ps { for _, c := range cs {
sortProcessor(p) if err = sortCallback(c); err != nil {
} return
var fns []func(*DB)
for _, name := range sortedNames {
if idx := getRIndex(allNames, name); !ps[idx].remove {
fns = append(fns, ps[idx].handler)
} }
} }
return fns for _, name := range sorted {
} if idx := getRIndex(names, name); !cs[idx].remove {
fns = append(fns, cs[idx].handler)
// compile processors }
func (cs *Callbacks) compile(db *DB) { }
processors := map[string][]*processor{}
for _, p := range cs.processors { return
if p.name != "" {
if p.match == nil || p.match(db) {
processors[p.kind] = append(processors[p.kind], p)
}
}
}
for name, ps := range processors {
switch name {
case "create":
cs.creates = sortProcessors(ps)
case "query":
cs.queries = sortProcessors(ps)
case "update":
cs.updates = sortProcessors(ps)
case "delete":
cs.deletes = sortProcessors(ps)
case "row":
cs.row = sortProcessors(ps)
case "raw":
cs.raw = sortProcessors(ps)
}
}
} }

View File

@ -1,131 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"runtime"
"strings"
"testing"
)
func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) {
var got []string
for _, f := range funcs {
got = append(got, getFuncName(f))
}
return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
}
func getFuncName(fc func(*DB)) string {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".")
return fnames[len(fnames)-1]
}
func c1(*DB) {}
func c2(*DB) {}
func c3(*DB) {}
func c4(*DB) {}
func c5(*DB) {}
func TestCallbacks(t *testing.T) {
type callback struct {
name string
before string
after string
remove bool
replace bool
err error
match func(*DB) bool
h func(*DB)
}
datas := []struct {
callbacks []callback
results []string
}{
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c2", "c3", "c4", "c5"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c5", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
results: []string{"c1", "c3", "c5", "c2", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c5", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
results: []string{"c1", "c4", "c3"},
},
}
// func TestRegisterCallbackWithComplexOrder(t *testing.T) {
// var callback2 = &Callback{logger: defaultLogger}
// callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
// callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
// callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
// callback2.Delete().Register("after_create1", afterCreate1)
// callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
// if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
// t.Errorf("register callback with order")
// }
// }
for idx, data := range datas {
callbacks := &Callbacks{}
for _, c := range data.callbacks {
p := callbacks.Create()
if c.name == "" {
c.name = getFuncName(c.h)
}
if c.before != "" {
p = p.Before(c.before)
}
if c.after != "" {
p = p.After(c.after)
}
if c.match != nil {
p = p.Match(c.match)
}
if c.remove {
p.Remove(c.name)
} else if c.replace {
p.Replace(c.name, c.h)
} else {
p.Register(c.name, c.h)
}
}
if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok {
t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
}
}
}

158
tests/callbacks_test.go Normal file
View File

@ -0,0 +1,158 @@
package gorm_test
import (
"fmt"
"reflect"
"runtime"
"strings"
"testing"
"github.com/jinzhu/gorm"
)
func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) {
var (
got []string
funcs = reflect.ValueOf(v).Elem().FieldByName("fns")
)
for i := 0; i < funcs.Len(); i++ {
got = append(got, getFuncName(funcs.Index(i)))
}
return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
}
func getFuncName(fc interface{}) string {
reflectValue, ok := fc.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(fc)
}
fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".")
return fnames[len(fnames)-1]
}
func c1(*gorm.DB) {}
func c2(*gorm.DB) {}
func c3(*gorm.DB) {}
func c4(*gorm.DB) {}
func c5(*gorm.DB) {}
func TestCallbacks(t *testing.T) {
type callback struct {
name string
before string
after string
remove bool
replace bool
err string
match func(*gorm.DB) bool
h func(*gorm.DB)
}
datas := []struct {
callbacks []callback
err string
results []string
}{
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c2", "c3", "c4", "c5"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c5", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
results: []string{"c3", "c1", "c5", "c2", "c4"},
},
{
callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
results: []string{"c3", "c1", "c5", "c2", "c4"},
},
{
callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
err: "conflicting",
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c5", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
results: []string{"c1", "c4", "c3"},
},
}
for idx, data := range datas {
var err error
callbacks := gorm.InitializeCallbacks()
for _, c := range data.callbacks {
var v interface{} = callbacks.Create()
callMethod := func(s interface{}, name string, args ...interface{}) {
var argValues []reflect.Value
for _, arg := range args {
argValues = append(argValues, reflect.ValueOf(arg))
}
results := reflect.ValueOf(s).MethodByName(name).Call(argValues)
if len(results) > 0 {
v = results[0].Interface()
}
}
if c.name == "" {
c.name = getFuncName(c.h)
}
if c.before != "" {
callMethod(v, "Before", c.before)
}
if c.after != "" {
callMethod(v, "After", c.after)
}
if c.match != nil {
callMethod(v, "Match", c.match)
}
if c.remove {
callMethod(v, "Remove", c.name)
} else if c.replace {
callMethod(v, "Replace", c.name, c.h)
} else {
callMethod(v, "Register", c.name, c.h)
}
if e, ok := v.(error); !ok || e != nil {
err = e
}
}
if len(data.err) > 0 && err == nil {
t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err)
} else if len(data.err) == 0 && err != nil {
t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err)
}
if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok {
t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
}
}
}