Set nopLogger to DefaultCallback for avoid nil pointer dereference (#2742)

This commit is contained in:
Shunsuke Otani 2019-12-05 23:57:15 +09:00 committed by Jinzhu
parent 0aba7ff3a0
commit e8c07b5531
3 changed files with 36 additions and 7 deletions

View File

@ -3,7 +3,7 @@ package gorm
import "fmt"
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
var DefaultCallback = &Callback{logger: nopLogger{}}
// Callback is a struct that contains all CRUD callbacks
// Field `creates` contains callbacks will be call when creating object
@ -101,12 +101,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
}
}
if cp.logger != nil {
// note cp.logger will be nil during the default gorm callback registrations
// as they occur within init() blocks. However, any user-registered callbacks
// will happen after cp.logger exists (as the default logger or user-specified).
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
}
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)

View File

@ -217,3 +217,33 @@ func TestGetCallback(t *testing.T) {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
}
func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
// nop
})
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
}
gorm.DefaultCallback.Create().Remove(createCallbackName)
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
}
updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 1)
})
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 2)
})
scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
}
}

View File

@ -135,3 +135,7 @@ type Logger struct {
func (logger Logger) Print(values ...interface{}) {
logger.Println(LogFormatter(values...)...)
}
type nopLogger struct{}
func (nopLogger) Print(values ...interface{}) {}