mirror of https://github.com/go-gorm/gorm.git
fix: remove `callback` from `callbacks` if `Remove()` called (#6916)
* fix: remove callback from callbacks if Remove() called * reduce number of loops * remove unnecessary blank line
This commit is contained in:
parent
956f7ce843
commit
26195e6d16
19
callbacks.go
19
callbacks.go
|
@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
|
|||
|
||||
func (p *processor) compile() (err error) {
|
||||
var callbacks []*callback
|
||||
removedMap := map[string]bool{}
|
||||
for _, callback := range p.callbacks {
|
||||
if callback.match == nil || callback.match(p.db) {
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
if callback.remove {
|
||||
removedMap[callback.name] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(removedMap) > 0 {
|
||||
callbacks = removeCallbacks(callbacks, removedMap)
|
||||
}
|
||||
p.callbacks = callbacks
|
||||
|
||||
|
@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
|
||||
callbacks := make([]*callback, 0, len(cs))
|
||||
for _, callback := range cs {
|
||||
if nameMap[callback.name] {
|
||||
continue
|
||||
}
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
return callbacks
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
|
|||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
||||
results: []string{"c1", "c5", "c3", "c4"},
|
||||
results: []string{"c1", "c3", "c4", "c5"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||
|
@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) {
|
|||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksGet(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("c1", c1)
|
||||
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
|
||||
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
|
||||
}
|
||||
|
||||
createCallback.Remove("c1")
|
||||
if cb := createCallback.Get("c2"); cb != nil {
|
||||
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksRemove(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("c1", c1)
|
||||
createCallback.After("*").Register("c2", c2)
|
||||
createCallback.Before("c4").Register("c3", c3)
|
||||
createCallback.After("c2").Register("c4", c4)
|
||||
|
||||
// callbacks: []string{"c1", "c3", "c4", "c2"}
|
||||
createCallback.Remove("c1")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c4")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c2")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c3")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue