Fix CallbackProcessor.Get() for removed or replaced same name callback (#2548)

This commit is contained in:
Shunsuke Otani 2019-09-12 23:16:05 +09:00 committed by Jinzhu
parent b954854116
commit d5cafb5db1
2 changed files with 52 additions and 6 deletions

View File

@ -135,11 +135,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
// db.Callback().Create().Get("gorm:create") // db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, p := range cp.parent.processors { for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind && !cp.remove { if p.name == callbackName && p.kind == cp.kind {
return *p.processor if p.remove {
callback = nil
} else {
callback = *p.processor
}
} }
} }
return nil return
} }
// getRIndex get right index from string slice // getRIndex get right index from string slice

View File

@ -2,11 +2,10 @@ package gorm_test
import ( import (
"errors" "errors"
"github.com/jinzhu/gorm"
"reflect" "reflect"
"testing" "testing"
"github.com/jinzhu/gorm"
) )
func (s *Product) BeforeCreate() (err error) { func (s *Product) BeforeCreate() (err error) {
@ -175,3 +174,46 @@ func TestCallbacksWithErrors(t *testing.T) {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
} }
} }
func TestGetCallback(t *testing.T) {
scope := DB.NewScope(nil)
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
}
DB.Callback().Create().Remove("gorm:test_callback")
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}
DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
}