mirror of https://github.com/go-gorm/gorm.git
Refactor callbacks
This commit is contained in:
parent
e509b3100d
commit
5959c81be6
285
callbacks.go
285
callbacks.go
|
@ -2,26 +2,36 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/jinzhu/gorm/logger"
|
"github.com/jinzhu/gorm/logger"
|
||||||
"github.com/jinzhu/gorm/utils"
|
"github.com/jinzhu/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Callbacks gorm callbacks manager
|
func InitializeCallbacks() *callbacks {
|
||||||
type Callbacks struct {
|
return &callbacks{
|
||||||
creates []func(*DB)
|
processors: map[string]*processor{
|
||||||
queries []func(*DB)
|
"create": &processor{},
|
||||||
updates []func(*DB)
|
"query": &processor{},
|
||||||
deletes []func(*DB)
|
"update": &processor{},
|
||||||
row []func(*DB)
|
"delete": &processor{},
|
||||||
raw []func(*DB)
|
"row": &processor{},
|
||||||
db *DB
|
"raw": &processor{},
|
||||||
processors []*processor
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// callbacks gorm callbacks manager
|
||||||
|
type callbacks struct {
|
||||||
|
processors map[string]*processor
|
||||||
}
|
}
|
||||||
|
|
||||||
type processor struct {
|
type processor struct {
|
||||||
kind string
|
db *DB
|
||||||
|
fns []func(*DB)
|
||||||
|
callbacks []*callback
|
||||||
|
}
|
||||||
|
|
||||||
|
type callback struct {
|
||||||
name string
|
name string
|
||||||
before string
|
before string
|
||||||
after string
|
after string
|
||||||
|
@ -29,79 +39,111 @@ type processor struct {
|
||||||
replace bool
|
replace bool
|
||||||
match func(*DB) bool
|
match func(*DB) bool
|
||||||
handler func(*DB)
|
handler func(*DB)
|
||||||
callbacks *Callbacks
|
processor *processor
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *Callbacks) Create() *processor {
|
func (cs *callbacks) Create() *processor {
|
||||||
return &processor{callbacks: cs, kind: "create"}
|
return cs.processors["create"]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *Callbacks) Query() *processor {
|
func (cs *callbacks) Query() *processor {
|
||||||
return &processor{callbacks: cs, kind: "query"}
|
return cs.processors["query"]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *Callbacks) Update() *processor {
|
func (cs *callbacks) Update() *processor {
|
||||||
return &processor{callbacks: cs, kind: "update"}
|
return cs.processors["update"]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *Callbacks) Delete() *processor {
|
func (cs *callbacks) Delete() *processor {
|
||||||
return &processor{callbacks: cs, kind: "delete"}
|
return cs.processors["delete"]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *Callbacks) Row() *processor {
|
func (cs *callbacks) Row() *processor {
|
||||||
return &processor{callbacks: cs, kind: "row"}
|
return cs.processors["row"]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *Callbacks) Raw() *processor {
|
func (cs *callbacks) Raw() *processor {
|
||||||
return &processor{callbacks: cs, kind: "raw"}
|
return cs.processors["raw"]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) Before(name string) *processor {
|
func (p *processor) Execute(db *DB) {
|
||||||
p.before = name
|
for _, f := range p.fns {
|
||||||
return p
|
f(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) After(name string) *processor {
|
|
||||||
p.after = name
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *processor) Match(fc func(*DB) bool) *processor {
|
|
||||||
p.match = fc
|
|
||||||
return p
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) Get(name string) func(*DB) {
|
func (p *processor) Get(name string) func(*DB) {
|
||||||
for i := len(p.callbacks.processors) - 1; i >= 0; i-- {
|
for i := len(p.callbacks) - 1; i >= 0; i-- {
|
||||||
if v := p.callbacks.processors[i]; v.name == name && v.kind == v.kind && !v.remove {
|
if v := p.callbacks[i]; v.name == name && !v.remove {
|
||||||
return v.handler
|
return v.handler
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) Register(name string, fn func(*DB)) {
|
func (p *processor) Before(name string) *callback {
|
||||||
p.name = name
|
return &callback{before: name, processor: p}
|
||||||
p.handler = fn
|
|
||||||
p.callbacks.processors = append(p.callbacks.processors, p)
|
|
||||||
p.callbacks.compile(p.callbacks.db)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) Remove(name string) {
|
func (p *processor) After(name string) *callback {
|
||||||
logger.Default.Info("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
return &callback{after: name, processor: p}
|
||||||
p.name = name
|
|
||||||
p.remove = true
|
|
||||||
p.callbacks.processors = append(p.callbacks.processors, p)
|
|
||||||
p.callbacks.compile(p.callbacks.db)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) Replace(name string, fn func(*DB)) {
|
func (p *processor) Match(fc func(*DB) bool) *callback {
|
||||||
logger.Default.Info("[info] replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
return &callback{match: fc, processor: p}
|
||||||
p.name = name
|
}
|
||||||
p.handler = fn
|
|
||||||
p.replace = true
|
func (p *processor) Register(name string, fn func(*DB)) error {
|
||||||
p.callbacks.processors = append(p.callbacks.processors, p)
|
return (&callback{processor: p}).Register(name, fn)
|
||||||
p.callbacks.compile(p.callbacks.db)
|
}
|
||||||
|
|
||||||
|
func (p *processor) Remove(name string) error {
|
||||||
|
return (&callback{processor: p}).Remove(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) Replace(name string, fn func(*DB)) error {
|
||||||
|
return (&callback{processor: p}).Replace(name, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) compile(db *DB) (err error) {
|
||||||
|
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
||||||
|
logger.Default.Error("Got error when compile callbacks, got %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *callback) Before(name string) *callback {
|
||||||
|
c.before = name
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *callback) After(name string) *callback {
|
||||||
|
c.after = name
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *callback) Register(name string, fn func(*DB)) error {
|
||||||
|
c.name = name
|
||||||
|
c.handler = fn
|
||||||
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||||
|
return c.processor.compile(c.processor.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *callback) Remove(name string) error {
|
||||||
|
logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||||
|
c.name = name
|
||||||
|
c.remove = true
|
||||||
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||||
|
return c.processor.compile(c.processor.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *callback) Replace(name string, fn func(*DB)) error {
|
||||||
|
logger.Default.Info("replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||||
|
c.name = name
|
||||||
|
c.handler = fn
|
||||||
|
c.replace = true
|
||||||
|
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||||
|
return c.processor.compile(c.processor.db)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRIndex get right index from string slice
|
// getRIndex get right index from string slice
|
||||||
|
@ -114,98 +156,81 @@ func getRIndex(strs []string, str string) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func sortProcessors(ps []*processor) []func(*DB) {
|
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||||
var (
|
var (
|
||||||
allNames, sortedNames []string
|
names, sorted []string
|
||||||
sortProcessor func(*processor) error
|
sortCallback func(*callback) error
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, p := range ps {
|
for _, c := range cs {
|
||||||
// show warning message the callback name already exists
|
// show warning message the callback name already exists
|
||||||
if idx := getRIndex(allNames, p.name); idx > -1 && !p.replace && !p.remove && !ps[idx].remove {
|
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
||||||
log.Printf("[warning] duplicated callback `%v` from %v\n", p.name, utils.FileWithLineNum())
|
logger.Default.Warn("duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
|
||||||
}
|
}
|
||||||
allNames = append(allNames, p.name)
|
names = append(names, c.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
sortProcessor = func(p *processor) error {
|
sortCallback = func(c *callback) error {
|
||||||
if getRIndex(sortedNames, p.name) == -1 { // if not sorted
|
if c.before != "" { // if defined before callback
|
||||||
if p.before != "" { // if defined before callback
|
if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
||||||
if sortedIdx := getRIndex(sortedNames, p.before); sortedIdx != -1 {
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
if curIdx := getRIndex(sortedNames, p.name); curIdx != -1 || true {
|
// if before callback already sorted, append current callback just after it
|
||||||
// if before callback already sorted, append current callback just after it
|
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||||
sortedNames = append(sortedNames[:sortedIdx], append([]string{p.name}, sortedNames[sortedIdx:]...)...)
|
} else if curIdx > sortedIdx {
|
||||||
} else if curIdx > sortedIdx {
|
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
|
||||||
return fmt.Errorf("conflicting callback %v with before %v", p.name, p.before)
|
|
||||||
}
|
|
||||||
} else if idx := getRIndex(allNames, p.before); idx != -1 {
|
|
||||||
// if before callback exists
|
|
||||||
ps[idx].after = p.name
|
|
||||||
}
|
}
|
||||||
|
} else if idx := getRIndex(names, c.before); idx != -1 {
|
||||||
|
// if before callback exists
|
||||||
|
cs[idx].after = c.name
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if p.after != "" { // if defined after callback
|
if c.after != "" { // if defined after callback
|
||||||
if sortedIdx := getRIndex(sortedNames, p.after); sortedIdx != -1 {
|
if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
||||||
|
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||||
// if after callback sorted, append current callback to last
|
// if after callback sorted, append current callback to last
|
||||||
sortedNames = append(sortedNames, p.name)
|
sorted = append(sorted, c.name)
|
||||||
} else if idx := getRIndex(allNames, p.after); idx != -1 {
|
} else if curIdx < sortedIdx {
|
||||||
// if after callback exists but haven't sorted
|
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
|
||||||
// set after callback's before callback to current callback
|
}
|
||||||
if after := ps[idx]; after.before == "" {
|
} else if idx := getRIndex(names, c.after); idx != -1 {
|
||||||
after.before = p.name
|
// if after callback exists but haven't sorted
|
||||||
sortProcessor(after)
|
// set after callback's before callback to current callback
|
||||||
}
|
after := cs[idx]
|
||||||
|
|
||||||
|
if after.before == "" {
|
||||||
|
after.before = c.name
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sortCallback(after); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sortCallback(c); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// if current callback haven't been sorted, append it to last
|
// if current callback haven't been sorted, append it to last
|
||||||
if getRIndex(sortedNames, p.name) == -1 {
|
if getRIndex(sorted, c.name) == -1 {
|
||||||
sortedNames = append(sortedNames, p.name)
|
sorted = append(sorted, c.name)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range ps {
|
for _, c := range cs {
|
||||||
sortProcessor(p)
|
if err = sortCallback(c); err != nil {
|
||||||
}
|
return
|
||||||
|
|
||||||
var fns []func(*DB)
|
|
||||||
for _, name := range sortedNames {
|
|
||||||
if idx := getRIndex(allNames, name); !ps[idx].remove {
|
|
||||||
fns = append(fns, ps[idx].handler)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fns
|
for _, name := range sorted {
|
||||||
}
|
if idx := getRIndex(names, name); !cs[idx].remove {
|
||||||
|
fns = append(fns, cs[idx].handler)
|
||||||
// compile processors
|
}
|
||||||
func (cs *Callbacks) compile(db *DB) {
|
}
|
||||||
processors := map[string][]*processor{}
|
|
||||||
for _, p := range cs.processors {
|
return
|
||||||
if p.name != "" {
|
|
||||||
if p.match == nil || p.match(db) {
|
|
||||||
processors[p.kind] = append(processors[p.kind], p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, ps := range processors {
|
|
||||||
switch name {
|
|
||||||
case "create":
|
|
||||||
cs.creates = sortProcessors(ps)
|
|
||||||
case "query":
|
|
||||||
cs.queries = sortProcessors(ps)
|
|
||||||
case "update":
|
|
||||||
cs.updates = sortProcessors(ps)
|
|
||||||
case "delete":
|
|
||||||
cs.deletes = sortProcessors(ps)
|
|
||||||
case "row":
|
|
||||||
cs.row = sortProcessors(ps)
|
|
||||||
case "raw":
|
|
||||||
cs.raw = sortProcessors(ps)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,131 +0,0 @@
|
||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func assertCallbacks(funcs []func(*DB), fnames []string) (result bool, msg string) {
|
|
||||||
var got []string
|
|
||||||
|
|
||||||
for _, f := range funcs {
|
|
||||||
got = append(got, getFuncName(f))
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getFuncName(fc func(*DB)) string {
|
|
||||||
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(fc).Pointer()).Name(), ".")
|
|
||||||
return fnames[len(fnames)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
func c1(*DB) {}
|
|
||||||
func c2(*DB) {}
|
|
||||||
func c3(*DB) {}
|
|
||||||
func c4(*DB) {}
|
|
||||||
func c5(*DB) {}
|
|
||||||
|
|
||||||
func TestCallbacks(t *testing.T) {
|
|
||||||
type callback struct {
|
|
||||||
name string
|
|
||||||
before string
|
|
||||||
after string
|
|
||||||
remove bool
|
|
||||||
replace bool
|
|
||||||
err error
|
|
||||||
match func(*DB) bool
|
|
||||||
h func(*DB)
|
|
||||||
}
|
|
||||||
|
|
||||||
datas := []struct {
|
|
||||||
callbacks []callback
|
|
||||||
results []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
|
|
||||||
results: []string{"c1", "c2", "c3", "c4", "c5"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
|
|
||||||
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
|
|
||||||
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
|
|
||||||
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
|
|
||||||
results: []string{"c1", "c5", "c2", "c3", "c4"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
|
|
||||||
results: []string{"c1", "c3", "c5", "c2", "c4"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
|
||||||
results: []string{"c1", "c5", "c3", "c4"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
|
||||||
results: []string{"c1", "c4", "c3"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
|
||||||
// var callback2 = &Callback{logger: defaultLogger}
|
|
||||||
|
|
||||||
// callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
|
||||||
// callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
|
||||||
// callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
|
||||||
// callback2.Delete().Register("after_create1", afterCreate1)
|
|
||||||
// callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
|
||||||
|
|
||||||
// if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
|
||||||
// t.Errorf("register callback with order")
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
for idx, data := range datas {
|
|
||||||
callbacks := &Callbacks{}
|
|
||||||
|
|
||||||
for _, c := range data.callbacks {
|
|
||||||
p := callbacks.Create()
|
|
||||||
|
|
||||||
if c.name == "" {
|
|
||||||
c.name = getFuncName(c.h)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.before != "" {
|
|
||||||
p = p.Before(c.before)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.after != "" {
|
|
||||||
p = p.After(c.after)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.match != nil {
|
|
||||||
p = p.Match(c.match)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.remove {
|
|
||||||
p.Remove(c.name)
|
|
||||||
} else if c.replace {
|
|
||||||
p.Replace(c.name, c.h)
|
|
||||||
} else {
|
|
||||||
p.Register(c.name, c.h)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok, msg := assertCallbacks(callbacks.creates, data.results); !ok {
|
|
||||||
t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,158 @@
|
||||||
|
package gorm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) {
|
||||||
|
var (
|
||||||
|
got []string
|
||||||
|
funcs = reflect.ValueOf(v).Elem().FieldByName("fns")
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < funcs.Len(); i++ {
|
||||||
|
got = append(got, getFuncName(funcs.Index(i)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFuncName(fc interface{}) string {
|
||||||
|
reflectValue, ok := fc.(reflect.Value)
|
||||||
|
if !ok {
|
||||||
|
reflectValue = reflect.ValueOf(fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".")
|
||||||
|
return fnames[len(fnames)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func c1(*gorm.DB) {}
|
||||||
|
func c2(*gorm.DB) {}
|
||||||
|
func c3(*gorm.DB) {}
|
||||||
|
func c4(*gorm.DB) {}
|
||||||
|
func c5(*gorm.DB) {}
|
||||||
|
|
||||||
|
func TestCallbacks(t *testing.T) {
|
||||||
|
type callback struct {
|
||||||
|
name string
|
||||||
|
before string
|
||||||
|
after string
|
||||||
|
remove bool
|
||||||
|
replace bool
|
||||||
|
err string
|
||||||
|
match func(*gorm.DB) bool
|
||||||
|
h func(*gorm.DB)
|
||||||
|
}
|
||||||
|
|
||||||
|
datas := []struct {
|
||||||
|
callbacks []callback
|
||||||
|
err string
|
||||||
|
results []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
|
||||||
|
results: []string{"c1", "c2", "c3", "c4", "c5"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
|
||||||
|
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
|
||||||
|
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
|
||||||
|
results: []string{"c1", "c2", "c3", "c5", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
|
||||||
|
results: []string{"c1", "c5", "c2", "c3", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
|
||||||
|
results: []string{"c3", "c1", "c5", "c2", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
|
||||||
|
results: []string{"c3", "c1", "c5", "c2", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
|
||||||
|
err: "conflicting",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
||||||
|
results: []string{"c1", "c5", "c3", "c4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||||
|
results: []string{"c1", "c4", "c3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, data := range datas {
|
||||||
|
var err error
|
||||||
|
callbacks := gorm.InitializeCallbacks()
|
||||||
|
|
||||||
|
for _, c := range data.callbacks {
|
||||||
|
var v interface{} = callbacks.Create()
|
||||||
|
callMethod := func(s interface{}, name string, args ...interface{}) {
|
||||||
|
var argValues []reflect.Value
|
||||||
|
for _, arg := range args {
|
||||||
|
argValues = append(argValues, reflect.ValueOf(arg))
|
||||||
|
}
|
||||||
|
|
||||||
|
results := reflect.ValueOf(s).MethodByName(name).Call(argValues)
|
||||||
|
if len(results) > 0 {
|
||||||
|
v = results[0].Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.name == "" {
|
||||||
|
c.name = getFuncName(c.h)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.before != "" {
|
||||||
|
callMethod(v, "Before", c.before)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.after != "" {
|
||||||
|
callMethod(v, "After", c.after)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.match != nil {
|
||||||
|
callMethod(v, "Match", c.match)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.remove {
|
||||||
|
callMethod(v, "Remove", c.name)
|
||||||
|
} else if c.replace {
|
||||||
|
callMethod(v, "Replace", c.name, c.h)
|
||||||
|
} else {
|
||||||
|
callMethod(v, "Register", c.name, c.h)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e, ok := v.(error); !ok || e != nil {
|
||||||
|
err = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data.err) > 0 && err == nil {
|
||||||
|
t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err)
|
||||||
|
} else if len(data.err) == 0 && err != nil {
|
||||||
|
t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok {
|
||||||
|
t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue