forked from mirror/gorm
Fix CallbackProcessor.Get() for removed or replaced same name callback (#2548)
This commit is contained in:
parent
b954854116
commit
d5cafb5db1
10
callback.go
10
callback.go
|
@ -135,11 +135,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
|
|||
// db.Callback().Create().Get("gorm:create")
|
||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||
for _, p := range cp.parent.processors {
|
||||
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
||||
return *p.processor
|
||||
if p.name == callbackName && p.kind == cp.kind {
|
||||
if p.remove {
|
||||
callback = nil
|
||||
} else {
|
||||
callback = *p.processor
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// getRIndex get right index from string slice
|
||||
|
|
|
@ -2,11 +2,10 @@ package gorm_test
|
|||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func (s *Product) BeforeCreate() (err error) {
|
||||
|
@ -175,3 +174,46 @@ func TestCallbacksWithErrors(t *testing.T) {
|
|||
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCallback(t *testing.T) {
|
||||
scope := DB.NewScope(nil)
|
||||
|
||||
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||
t.Errorf("`gorm:test_callback` should be nil")
|
||||
}
|
||||
|
||||
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
|
||||
callback := DB.Callback().Create().Get("gorm:test_callback")
|
||||
if callback == nil {
|
||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||
}
|
||||
callback(scope)
|
||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
|
||||
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
|
||||
}
|
||||
|
||||
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
|
||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||
if callback == nil {
|
||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||
}
|
||||
callback(scope)
|
||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
|
||||
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
|
||||
}
|
||||
|
||||
DB.Callback().Create().Remove("gorm:test_callback")
|
||||
if DB.Callback().Create().Get("gorm:test_callback") != nil {
|
||||
t.Errorf("`gorm:test_callback` should be nil")
|
||||
}
|
||||
|
||||
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
|
||||
callback = DB.Callback().Create().Get("gorm:test_callback")
|
||||
if callback == nil {
|
||||
t.Errorf("`gorm:test_callback` should be non-nil")
|
||||
}
|
||||
callback(scope)
|
||||
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
|
||||
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue