gorm/tests/callbacks_test.go

255 lines
7.2 KiB
Go
Raw Normal View History

package tests_test
2020-01-31 03:29:35 +03:00
import (
"fmt"
"reflect"
"runtime"
"strings"
"testing"
2020-06-02 04:16:07 +03:00
"gorm.io/gorm"
2020-01-31 03:29:35 +03:00
)
func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) {
var (
got []string
funcs = reflect.ValueOf(v).Elem().FieldByName("fns")
)
for i := 0; i < funcs.Len(); i++ {
got = append(got, getFuncName(funcs.Index(i)))
}
return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
}
func getFuncName(fc interface{}) string {
reflectValue, ok := fc.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(fc)
}
fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".")
return fnames[len(fnames)-1]
}
func c1(*gorm.DB) {}
func c2(*gorm.DB) {}
func c3(*gorm.DB) {}
func c4(*gorm.DB) {}
func c5(*gorm.DB) {}
func c6(*gorm.DB) {}
2020-01-31 03:29:35 +03:00
func TestCallbacks(t *testing.T) {
type callback struct {
name string
before string
after string
remove bool
replace bool
err string
match func(*gorm.DB) bool
h func(*gorm.DB)
}
datas := []struct {
callbacks []callback
err string
results []string
}{
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c2", "c3", "c4", "c5"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c5", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
results: []string{"c3", "c1", "c5", "c2", "c4"},
},
{
callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
results: []string{"c3", "c1", "c5", "c2", "c4"},
},
{
callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
err: "conflicting",
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c3", "c4", "c5"},
2020-01-31 03:29:35 +03:00
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
results: []string{"c1", "c4", "c3"},
},
2020-08-03 16:48:36 +03:00
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c3", "c5", "c1", "c2", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
2020-08-03 16:48:36 +03:00
},
2020-01-31 03:29:35 +03:00
}
for idx, data := range datas {
2020-02-02 14:32:27 +03:00
db, err := gorm.Open(nil, nil)
if err != nil {
t.Fatal(err)
}
2020-02-02 14:32:27 +03:00
callbacks := db.Callback()
2020-01-31 03:29:35 +03:00
for _, c := range data.callbacks {
var v interface{} = callbacks.Create()
callMethod := func(s interface{}, name string, args ...interface{}) {
var argValues []reflect.Value
for _, arg := range args {
argValues = append(argValues, reflect.ValueOf(arg))
}
results := reflect.ValueOf(s).MethodByName(name).Call(argValues)
if len(results) > 0 {
v = results[0].Interface()
}
}
if c.name == "" {
c.name = getFuncName(c.h)
}
if c.before != "" {
callMethod(v, "Before", c.before)
}
if c.after != "" {
callMethod(v, "After", c.after)
}
if c.match != nil {
callMethod(v, "Match", c.match)
}
if c.remove {
callMethod(v, "Remove", c.name)
} else if c.replace {
callMethod(v, "Replace", c.name, c.h)
} else {
callMethod(v, "Register", c.name, c.h)
}
if e, ok := v.(error); !ok || e != nil {
err = e
}
}
if len(data.err) > 0 && err == nil {
t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err)
} else if len(data.err) == 0 && err != nil {
t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err)
}
if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok {
t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
}
}
}
func TestPluginCallbacks(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("plugin_1_fn1", c1)
createCallback.After("*").Register("plugin_1_fn2", c2)
if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
// plugin 2
createCallback.Before("*").Register("plugin_2_fn1", c3)
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.After("*").Register("plugin_2_fn2", c4)
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
// plugin 3
createCallback.Before("*").Register("plugin_3_fn1", c5)
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.After("*").Register("plugin_3_fn2", c6)
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}
func TestCallbacksGet(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("c1", c1)
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
}
createCallback.Remove("c1")
if cb := createCallback.Get("c2"); cb != nil {
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
}
}
func TestCallbacksRemove(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("c1", c1)
createCallback.After("*").Register("c2", c2)
createCallback.Before("c4").Register("c3", c3)
createCallback.After("c2").Register("c4", c4)
// callbacks: []string{"c1", "c3", "c4", "c2"}
createCallback.Remove("c1")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c4")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c2")
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c3")
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}