gorm/tests/callbacks_test.go

195 lines
5.3 KiB
Go

package tests_test
import (
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
. "gorm.io/gorm/utils/tests"
"reflect"
"runtime"
"strings"
"testing"
)
func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) {
var (
got []string
funcs = reflect.ValueOf(v).Elem().FieldByName("fns")
)
for i := 0; i < funcs.Len(); i++ {
got = append(got, getFuncName(funcs.Index(i)))
}
return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got)
}
func getFuncName(fc interface{}) string {
reflectValue, ok := fc.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(fc)
}
fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".")
return fnames[len(fnames)-1]
}
func c1(*gorm.DB) {}
func c2(*gorm.DB) {}
func c3(*gorm.DB) {}
func c4(*gorm.DB) {}
func c5(*gorm.DB) {}
func TestCallbacks(t *testing.T) {
type callback struct {
name string
before string
after string
remove bool
replace bool
err string
match func(*gorm.DB) bool
h func(*gorm.DB)
}
datas := []struct {
callbacks []callback
err string
results []string
}{
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c2", "c3", "c4", "c5"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}},
results: []string{"c1", "c2", "c3", "c5", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}},
results: []string{"c1", "c5", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
results: []string{"c3", "c1", "c5", "c2", "c4"},
},
{
callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
results: []string{"c3", "c1", "c5", "c2", "c4"},
},
{
callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}},
err: "conflicting",
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c5", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
results: []string{"c1", "c4", "c3"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c3", "c5", "c1", "c2", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
}
for idx, data := range datas {
db, err := gorm.Open(nil, nil)
callbacks := db.Callback()
for _, c := range data.callbacks {
var v interface{} = callbacks.Create()
callMethod := func(s interface{}, name string, args ...interface{}) {
var argValues []reflect.Value
for _, arg := range args {
argValues = append(argValues, reflect.ValueOf(arg))
}
results := reflect.ValueOf(s).MethodByName(name).Call(argValues)
if len(results) > 0 {
v = results[0].Interface()
}
}
if c.name == "" {
c.name = getFuncName(c.h)
}
if c.before != "" {
callMethod(v, "Before", c.before)
}
if c.after != "" {
callMethod(v, "After", c.after)
}
if c.match != nil {
callMethod(v, "Match", c.match)
}
if c.remove {
callMethod(v, "Remove", c.name)
} else if c.replace {
callMethod(v, "Replace", c.name, c.h)
} else {
callMethod(v, "Register", c.name, c.h)
}
if e, ok := v.(error); !ok || e != nil {
err = e
}
}
if len(data.err) > 0 && err == nil {
t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err)
} else if len(data.err) == 0 && err != nil {
t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err)
}
if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok {
t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg)
}
}
}
func TestLoadOrStoreVisitMap(t *testing.T) {
var vistMap callbacks.VisitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vistMap = make(callbacks.VisitMap)
loaded = callbacks.LoadOrStoreVisitMap(&vistMap, &t1)
AssertEqual(t, loaded, false)
loaded = callbacks.LoadOrStoreVisitMap(&vistMap, &t1)
AssertEqual(t, loaded, true)
// t1 already exist but t2 not
loaded = callbacks.LoadOrStoreVisitMap(&vistMap, []*testM{&t1, &t2, &t3})
AssertEqual(t, loaded, false)
loaded = callbacks.LoadOrStoreVisitMap(&vistMap, []*testM{&t2, &t3})
AssertEqual(t, loaded, true)
}