mirror of https://github.com/go-gorm/gorm.git
Fix CallbackProcessor.Get() for removed or replaced same name callback (#2548)
This commit is contained in:
parent
b954854116
commit
d5cafb5db1
10
callback.go
10
callback.go
|
@ -135,11 +135,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
|
||||||
// db.Callback().Create().Get("gorm:create")
|
// db.Callback().Create().Get("gorm:create")
|
||||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||||
for _, p := range cp.parent.processors {
|
for _, p := range cp.parent.processors {
|
||||||
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
if p.name == callbackName && p.kind == cp.kind {
|
||||||
return *p.processor
|
if p.remove {
|
||||||
|
callback = nil
|
||||||
|
} else {
|
||||||
|
callback = *p.processor
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRIndex get right index from string slice
|
// getRIndex get right index from string slice
|
||||||
|
|
|
@ -2,11 +2,10 @@ package gorm_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
|
||||||
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Product) BeforeCreate() (err error) {
|
func (s *Product) BeforeCreate() (err error) {
|
||||||
|
@ -175,3 +174,46 @@ func TestCallbacksWithErrors(t *testing.T) {
|
||||||
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetCallback(t *testing.T) {
|
||||||
|
scope := DB.NewScope(nil)
|
||||||
|
|
||||||
|
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
||||||
|
callback := DB.Callback().Create().Get("gorm:test_callback")
|
||||||
|
if callback == nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
|
}
|
||||||
|
callback(scope)
|
||||||
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
||||||
|
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
||||||
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
|
if callback == nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
|
}
|
||||||
|
callback(scope)
|
||||||
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
||||||
|
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Callback().Create().Remove("gorm:test_callback")
|
||||||
|
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
||||||
|
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||||
|
if callback == nil {
|
||||||
|
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||||
|
}
|
||||||
|
callback(scope)
|
||||||
|
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
||||||
|
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue