From 26195e6d16cbb086423303d8178b78852ef12e2a Mon Sep 17 00:00:00 2001 From: snackmgmg <16898622+snackmgmg@users.noreply.github.com> Date: Tue, 26 Mar 2024 12:33:36 +0900 Subject: [PATCH] fix: remove `callback` from `callbacks` if `Remove()` called (#6916) * fix: remove callback from callbacks if Remove() called * reduce number of loops * remove unnecessary blank line --- callbacks.go | 19 ++++++++++++++++ tests/callbacks_test.go | 48 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index 195d1720..50b5b0e9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error { func (p *processor) compile() (err error) { var callbacks []*callback + removedMap := map[string]bool{} for _, callback := range p.callbacks { if callback.match == nil || callback.match(p.db) { callbacks = append(callbacks, callback) } + if callback.remove { + removedMap[callback.name] = true + } + } + + if len(removedMap) > 0 { + callbacks = removeCallbacks(callbacks, removedMap) } p.callbacks = callbacks @@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { return } + +func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { + callbacks := make([]*callback, 0, len(cs)) + for _, callback := range cs { + if nameMap[callback.name] { + continue + } + callbacks = append(callbacks, callback) + } + return callbacks +} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 4479da4c..f77209f1 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) { }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, - results: []string{"c1", "c5", "c3", "c4"}, + results: []string{"c1", "c3", "c4", "c5"}, }, { callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, @@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) { t.Errorf("callbacks tests failed, got %v", msg) } } + +func TestCallbacksGet(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) { + t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1) + } + + createCallback.Remove("c1") + if cb := createCallback.Get("c2"); cb != nil { + t.Errorf("callbacks test failed. got: %p, want: nil", cb) + } +} + +func TestCallbacksRemove(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + createCallback.After("*").Register("c2", c2) + createCallback.Before("c4").Register("c3", c3) + createCallback.After("c2").Register("c4", c4) + + // callbacks: []string{"c1", "c3", "c4", "c2"} + createCallback.Remove("c1") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c4") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c2") + if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c3") + if ok, msg := assertCallbacks(createCallback, []string{}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +}